Spaces:
Sleeping
Sleeping
Upload 15 files
Browse files- Dockerfile +23 -0
- best_corr_model.pt +3 -0
- data/.DS_Store +0 -0
- data/base2info.json +0 -0
- data/bs_record_energy_normalized_sampled.npz +3 -0
- data/spatial_features.npz +3 -0
- hierarchical_flow_matching_training_v4.py +935 -0
- hierarchical_flow_matching_v4.py +1019 -0
- index.html +147 -0
- multimodal_spatial_encoder_v4.py +645 -0
- prediction_backend.py +359 -0
- requirements.txt +8 -0
- script.js +954 -0
- server.py +269 -0
- style.css +617 -0
Dockerfile
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 1. Use official lightweight Python 3.10 image to minimize build size
|
| 2 |
+
FROM python:3.10-slim
|
| 3 |
+
|
| 4 |
+
# 2. Define the working directory inside the container
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# 3. Create a non-root user with UID 1000
|
| 8 |
+
RUN useradd -m -u 1000 user
|
| 9 |
+
USER user
|
| 10 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 11 |
+
|
| 12 |
+
# 4. Copy requirements file and install dependencies
|
| 13 |
+
COPY --chown=user requirements.txt .
|
| 14 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 15 |
+
|
| 16 |
+
# 5. Copy the entire project source code and data assets to the container
|
| 17 |
+
COPY --chown=user . /app
|
| 18 |
+
|
| 19 |
+
# 6. Expose port 7860
|
| 20 |
+
EXPOSE 7860
|
| 21 |
+
|
| 22 |
+
# 7. Execute the Flask server script
|
| 23 |
+
CMD ["python", "server.py"]
|
best_corr_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:83b578e901f1f3d11421431ec5286548ae200b0ed1165afee66b6f05ad19bd76
|
| 3 |
+
size 563607134
|
data/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
data/base2info.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/bs_record_energy_normalized_sampled.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:62a53c46d7cc8b0fca0fa9943eac045ca1853a4eb1df2a4f58c5c939179b66c7
|
| 3 |
+
size 74859436
|
data/spatial_features.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:395420bb4696091d85bae743e680b861326342965ef7c85a21eb22d9c295e4ea
|
| 3 |
+
size 294559
|
hierarchical_flow_matching_training_v4.py
ADDED
|
@@ -0,0 +1,935 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hierarchical Flow Matching Training Framework (V4) - Generative Fusion
|
| 3 |
+
======================================================================
|
| 4 |
+
|
| 5 |
+
Complete training pipeline with:
|
| 6 |
+
1. Three-level cascaded Flow Matching losses
|
| 7 |
+
2. Hierarchical multi-periodic supervision
|
| 8 |
+
3. Temporal structure preservation
|
| 9 |
+
4. Adaptive learning rate scheduling
|
| 10 |
+
|
| 11 |
+
[FUSION] Generative Mode: Implicit alignment via conditional flow matching.
|
| 12 |
+
Enhanced with explicit peak conditioning and auxiliary classification.
|
| 13 |
+
Physical Boundary Loss & Bias Correction.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import json
|
| 18 |
+
import numpy as np
|
| 19 |
+
from typing import Dict, Optional, Tuple, Literal
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
from torch.utils.data import DataLoader
|
| 26 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 27 |
+
|
| 28 |
+
from hierarchical_flow_matching_v4 import HierarchicalFlowMatchingV4
|
| 29 |
+
from multimodal_spatial_encoder_v4 import MultiModalSpatialEncoderV4
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# =============================================================================
|
| 33 |
+
# Hierarchical Multi-Periodic Loss Functions
|
| 34 |
+
# =============================================================================
|
| 35 |
+
|
| 36 |
+
class HierarchicalFlowMatchingLoss(nn.Module):
|
| 37 |
+
"""
|
| 38 |
+
Hierarchical Flow Matching loss with multi-periodic supervision.
|
| 39 |
+
|
| 40 |
+
Combines:
|
| 41 |
+
1. Level 1 (Daily) Flow Matching loss
|
| 42 |
+
2. Level 2 (Weekly) Flow Matching loss
|
| 43 |
+
3. Level 3 (Residual) Flow Matching loss [Peak Conditioned]
|
| 44 |
+
4. Temporal structure preservation loss
|
| 45 |
+
5. Multi-periodic consistency loss
|
| 46 |
+
6. Peak Hour Classification Loss
|
| 47 |
+
7. Physical Boundary Loss (No-Negative Constraint)
|
| 48 |
+
8. Bias Correction Loss (Global Mean Alignment)
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def __init__(self):
|
| 52 |
+
super().__init__()
|
| 53 |
+
|
| 54 |
+
# Helper: Physical Constraint
|
| 55 |
+
def compute_boundary_loss(self, predicted_x1: torch.Tensor) -> torch.Tensor:
|
| 56 |
+
"""
|
| 57 |
+
Penalize negative values in the estimated traffic.
|
| 58 |
+
Loss = ReLU(-x).mean() * scale
|
| 59 |
+
"""
|
| 60 |
+
return F.relu(-predicted_x1).mean() * 10.0
|
| 61 |
+
|
| 62 |
+
# Helper: Bias Constraint
|
| 63 |
+
def compute_bias_loss(self, generated: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
|
| 64 |
+
"""
|
| 65 |
+
Fix 'Parallel Lines' issue by forcing global mean alignment.
|
| 66 |
+
"""
|
| 67 |
+
gen_mean = generated.mean(dim=1) # [B]
|
| 68 |
+
real_mean = real.mean(dim=1) # [B]
|
| 69 |
+
return F.l1_loss(gen_mean, real_mean) * 20.0
|
| 70 |
+
|
| 71 |
+
def compute_level1_loss(
|
| 72 |
+
self,
|
| 73 |
+
model: HierarchicalFlowMatchingV4,
|
| 74 |
+
real_traffic: torch.Tensor,
|
| 75 |
+
spatial_cond_level1: torch.Tensor,
|
| 76 |
+
) -> Tuple[torch.Tensor, torch.Tensor]: # [MODIFIED] Returns tuple
|
| 77 |
+
"""
|
| 78 |
+
Level 1 (Day-Type Templates) Flow Matching loss.
|
| 79 |
+
"""
|
| 80 |
+
B = real_traffic.shape[0]
|
| 81 |
+
device = real_traffic.device
|
| 82 |
+
|
| 83 |
+
# 672 hourly samples = 28 days * 24 hours
|
| 84 |
+
steps_per_day = 24
|
| 85 |
+
n_days = 28
|
| 86 |
+
real_reshaped = real_traffic.reshape(B, n_days, steps_per_day) # [B, 28, 24]
|
| 87 |
+
|
| 88 |
+
# Assume sequence starts on Monday
|
| 89 |
+
day_of_week = torch.arange(n_days, device=device) % 7
|
| 90 |
+
weekday_idx = torch.where(day_of_week < 5)[0]
|
| 91 |
+
weekend_idx = torch.where(day_of_week >= 5)[0]
|
| 92 |
+
|
| 93 |
+
weekday_pattern = real_reshaped.index_select(1, weekday_idx).mean(dim=1) # [B, 24]
|
| 94 |
+
weekend_pattern = real_reshaped.index_select(1, weekend_idx).mean(dim=1) # [B, 24]
|
| 95 |
+
|
| 96 |
+
# Target is concatenation: [weekday(24), weekend(24)] -> [B, 48]
|
| 97 |
+
x1 = torch.cat([weekday_pattern, weekend_pattern], dim=1)
|
| 98 |
+
|
| 99 |
+
# Sample noise
|
| 100 |
+
x0 = torch.randn_like(x1)
|
| 101 |
+
|
| 102 |
+
# Sample time
|
| 103 |
+
t = torch.rand(B, 1, device=device)
|
| 104 |
+
|
| 105 |
+
# Interpolation
|
| 106 |
+
x_t = t * x1 + (1 - t) * x0
|
| 107 |
+
|
| 108 |
+
# Target velocity
|
| 109 |
+
v_target = x1 - x0
|
| 110 |
+
|
| 111 |
+
# Predict velocity
|
| 112 |
+
v_pred = model(x_t, t, spatial_cond_level1, level=1)
|
| 113 |
+
|
| 114 |
+
# Flow Matching loss
|
| 115 |
+
loss_fm = F.mse_loss(v_pred, v_target)
|
| 116 |
+
|
| 117 |
+
# Boundary Loss for Level 1
|
| 118 |
+
x1_est = x_t + (1 - t) * v_pred
|
| 119 |
+
loss_boundary = self.compute_boundary_loss(x1_est)
|
| 120 |
+
|
| 121 |
+
return loss_fm, loss_boundary
|
| 122 |
+
|
| 123 |
+
def compute_level2_loss(
|
| 124 |
+
self,
|
| 125 |
+
model: HierarchicalFlowMatchingV4,
|
| 126 |
+
real_traffic: torch.Tensor,
|
| 127 |
+
spatial_cond_level2: torch.Tensor,
|
| 128 |
+
spatial_cond_level1: torch.Tensor,
|
| 129 |
+
daily_pattern: Optional[torch.Tensor] = None,
|
| 130 |
+
use_teacher_forcing: bool = True,
|
| 131 |
+
n_steps_generate: int = 10,
|
| 132 |
+
) -> Tuple[torch.Tensor, torch.Tensor]: # [MODIFIED] Returns tuple
|
| 133 |
+
"""
|
| 134 |
+
Level 2 (Weekly Pattern, 168 hours) Flow Matching loss.
|
| 135 |
+
"""
|
| 136 |
+
B = real_traffic.shape[0]
|
| 137 |
+
device = real_traffic.device
|
| 138 |
+
|
| 139 |
+
steps_per_day = 24
|
| 140 |
+
n_days = 28
|
| 141 |
+
n_weeks = 4
|
| 142 |
+
real_reshaped = real_traffic.reshape(B, n_days, steps_per_day) # [B, 28, 24]
|
| 143 |
+
|
| 144 |
+
# Weekly pattern ground truth
|
| 145 |
+
weekly_days = []
|
| 146 |
+
for dow in range(7):
|
| 147 |
+
idx = torch.tensor([dow + 7 * w for w in range(n_weeks)], device=device, dtype=torch.long)
|
| 148 |
+
weekly_days.append(real_reshaped.index_select(1, idx).mean(dim=1)) # [B, 24]
|
| 149 |
+
weekly_pattern = torch.stack(weekly_days, dim=1).reshape(B, 7 * steps_per_day) # [B, 168]
|
| 150 |
+
|
| 151 |
+
x1 = weekly_pattern # target weekly pattern
|
| 152 |
+
|
| 153 |
+
# Get day-type templates (weekday/weekend)
|
| 154 |
+
if daily_pattern is None or not use_teacher_forcing:
|
| 155 |
+
with torch.no_grad():
|
| 156 |
+
daily_pattern = model.generate_daily_pattern(
|
| 157 |
+
spatial_cond_level1, n_steps=n_steps_generate
|
| 158 |
+
)
|
| 159 |
+
else:
|
| 160 |
+
# Teacher forcing
|
| 161 |
+
day_of_week = torch.arange(n_days, device=device) % 7
|
| 162 |
+
weekday_idx = torch.where(day_of_week < 5)[0]
|
| 163 |
+
weekend_idx = torch.where(day_of_week >= 5)[0]
|
| 164 |
+
weekday_pattern = real_reshaped.index_select(1, weekday_idx).mean(dim=1) # [B, 24]
|
| 165 |
+
weekend_pattern = real_reshaped.index_select(1, weekend_idx).mean(dim=1) # [B, 24]
|
| 166 |
+
daily_pattern = torch.cat([weekday_pattern, weekend_pattern], dim=1) # [B, 48]
|
| 167 |
+
|
| 168 |
+
# Sample noise
|
| 169 |
+
x0 = torch.randn_like(x1)
|
| 170 |
+
|
| 171 |
+
# Sample time
|
| 172 |
+
t = torch.rand(B, 1, device=device)
|
| 173 |
+
|
| 174 |
+
# Interpolation
|
| 175 |
+
x_t = t * x1 + (1 - t) * x0
|
| 176 |
+
|
| 177 |
+
# Target velocity
|
| 178 |
+
v_target = x1 - x0
|
| 179 |
+
|
| 180 |
+
# Predict velocity
|
| 181 |
+
v_pred = model(x_t, t, spatial_cond_level2, level=2, daily_pattern=daily_pattern)
|
| 182 |
+
|
| 183 |
+
# Flow Matching loss
|
| 184 |
+
loss_fm = F.mse_loss(v_pred, v_target)
|
| 185 |
+
|
| 186 |
+
# Boundary Loss for Level 2
|
| 187 |
+
x1_est = x_t + (1 - t) * v_pred
|
| 188 |
+
loss_boundary = self.compute_boundary_loss(x1_est)
|
| 189 |
+
|
| 190 |
+
return loss_fm, loss_boundary
|
| 191 |
+
|
| 192 |
+
def compute_level3_loss(
|
| 193 |
+
self,
|
| 194 |
+
model: HierarchicalFlowMatchingV4,
|
| 195 |
+
real_traffic: torch.Tensor,
|
| 196 |
+
spatial_cond_level3: torch.Tensor,
|
| 197 |
+
spatial_cond_level2: torch.Tensor,
|
| 198 |
+
spatial_cond_level1: torch.Tensor,
|
| 199 |
+
peak_hour_gt: torch.Tensor, # Explicit Peak GT
|
| 200 |
+
daily_pattern: Optional[torch.Tensor] = None,
|
| 201 |
+
weekly_trend: Optional[torch.Tensor] = None,
|
| 202 |
+
use_teacher_forcing: bool = True,
|
| 203 |
+
n_steps_generate: int = 10,
|
| 204 |
+
) -> Tuple[torch.Tensor, torch.Tensor]: # [MODIFIED] Returns tuple
|
| 205 |
+
"""
|
| 206 |
+
Level 3 (Residual over 672 hours) Flow Matching loss.
|
| 207 |
+
Models fine-grained hourly fluctuations after removing periodic trends.
|
| 208 |
+
"""
|
| 209 |
+
B = real_traffic.shape[0]
|
| 210 |
+
device = real_traffic.device
|
| 211 |
+
|
| 212 |
+
steps_per_day = 24
|
| 213 |
+
n_days = 28
|
| 214 |
+
n_weeks = 4
|
| 215 |
+
real_reshaped = real_traffic.reshape(B, n_days, steps_per_day) # [B, 28, 24]
|
| 216 |
+
|
| 217 |
+
# Ground-truth weekly pattern (168)
|
| 218 |
+
weekly_days = []
|
| 219 |
+
for dow in range(7):
|
| 220 |
+
idx = torch.tensor([dow + 7 * w for w in range(n_weeks)], device=device, dtype=torch.long)
|
| 221 |
+
weekly_days.append(real_reshaped.index_select(1, idx).mean(dim=1)) # [B, 24]
|
| 222 |
+
weekly_pattern_gt = torch.stack(weekly_days, dim=1).reshape(B, 7 * steps_per_day) # [B, 168]
|
| 223 |
+
|
| 224 |
+
# Get daily pattern and weekly trend
|
| 225 |
+
if use_teacher_forcing:
|
| 226 |
+
# Teacher forcing
|
| 227 |
+
day_of_week = torch.arange(n_days, device=device) % 7
|
| 228 |
+
weekday_idx = torch.where(day_of_week < 5)[0]
|
| 229 |
+
weekend_idx = torch.where(day_of_week >= 5)[0]
|
| 230 |
+
weekday_pattern = real_reshaped.index_select(1, weekday_idx).mean(dim=1) # [B, 24]
|
| 231 |
+
weekend_pattern = real_reshaped.index_select(1, weekend_idx).mean(dim=1) # [B, 24]
|
| 232 |
+
daily_pattern = torch.cat([weekday_pattern, weekend_pattern], dim=1) # [B, 48]
|
| 233 |
+
weekly_trend = weekly_pattern_gt
|
| 234 |
+
else:
|
| 235 |
+
if daily_pattern is None:
|
| 236 |
+
with torch.no_grad():
|
| 237 |
+
daily_pattern = model.generate_daily_pattern(
|
| 238 |
+
spatial_cond_level1, n_steps=n_steps_generate
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
if weekly_trend is None:
|
| 242 |
+
with torch.no_grad():
|
| 243 |
+
weekly_trend = model.generate_weekly_trend(
|
| 244 |
+
daily_pattern, spatial_cond_level2, n_steps=n_steps_generate
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
# Construct periodic component (coarse signal) from weekly pattern
|
| 248 |
+
coarse_signal = weekly_trend.repeat(1, n_weeks) # [B, 672]
|
| 249 |
+
|
| 250 |
+
# Target residual
|
| 251 |
+
x1 = real_traffic - coarse_signal # [B, 672]
|
| 252 |
+
|
| 253 |
+
# Sample noise
|
| 254 |
+
x0 = 0.1 * torch.randn_like(x1)
|
| 255 |
+
|
| 256 |
+
# Sample time
|
| 257 |
+
t = torch.rand(B, 1, device=device)
|
| 258 |
+
|
| 259 |
+
# Interpolation
|
| 260 |
+
x_t = t * x1 + (1 - t) * x0
|
| 261 |
+
|
| 262 |
+
# Target velocity
|
| 263 |
+
v_target = x1 - x0
|
| 264 |
+
|
| 265 |
+
# Predict velocity
|
| 266 |
+
# Pass peak_hour_gt to model
|
| 267 |
+
v_pred = model(
|
| 268 |
+
x_t, t, spatial_cond_level3, level=3,
|
| 269 |
+
daily_pattern=daily_pattern,
|
| 270 |
+
weekly_trend=weekly_trend,
|
| 271 |
+
coarse_signal=coarse_signal,
|
| 272 |
+
peak_hour=peak_hour_gt
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
# Flow Matching loss
|
| 276 |
+
loss_fm = F.mse_loss(v_pred, v_target)
|
| 277 |
+
|
| 278 |
+
# Boundary Loss for Level 3
|
| 279 |
+
# Ensure that (Coarse + Residual) >= 0
|
| 280 |
+
residual_est = x_t + (1 - t) * v_pred
|
| 281 |
+
final_traffic_est = coarse_signal + residual_est
|
| 282 |
+
loss_boundary = self.compute_boundary_loss(final_traffic_est)
|
| 283 |
+
|
| 284 |
+
return loss_fm, loss_boundary
|
| 285 |
+
|
| 286 |
+
def compute_temporal_structure_loss(
|
| 287 |
+
self,
|
| 288 |
+
generated: torch.Tensor,
|
| 289 |
+
real: torch.Tensor,
|
| 290 |
+
) -> torch.Tensor:
|
| 291 |
+
"""
|
| 292 |
+
Temporal structure preservation loss.
|
| 293 |
+
"""
|
| 294 |
+
d_gen = generated[..., 1:] - generated[..., :-1]
|
| 295 |
+
d_real = real[..., 1:] - real[..., :-1]
|
| 296 |
+
loss_deriv = F.mse_loss(d_gen, d_real)
|
| 297 |
+
return loss_deriv
|
| 298 |
+
|
| 299 |
+
def compute_multi_periodic_consistency_loss(
|
| 300 |
+
self,
|
| 301 |
+
generated: torch.Tensor,
|
| 302 |
+
real: torch.Tensor,
|
| 303 |
+
) -> torch.Tensor:
|
| 304 |
+
"""
|
| 305 |
+
Multi-periodic consistency loss.
|
| 306 |
+
"""
|
| 307 |
+
B = generated.shape[0]
|
| 308 |
+
device = generated.device
|
| 309 |
+
|
| 310 |
+
steps_per_day = 24
|
| 311 |
+
n_days = 28
|
| 312 |
+
n_weeks = 4
|
| 313 |
+
|
| 314 |
+
gen_days = generated.reshape(B, n_days, steps_per_day) # [B, 28, 24]
|
| 315 |
+
real_days = real.reshape(B, n_days, steps_per_day)
|
| 316 |
+
|
| 317 |
+
# Daily mean pattern
|
| 318 |
+
gen_daily = gen_days.mean(dim=1)
|
| 319 |
+
real_daily = real_days.mean(dim=1)
|
| 320 |
+
loss_daily = F.mse_loss(gen_daily, real_daily)
|
| 321 |
+
|
| 322 |
+
# Weekly pattern
|
| 323 |
+
def weekly_pattern(x_days: torch.Tensor) -> torch.Tensor:
|
| 324 |
+
days = []
|
| 325 |
+
for dow in range(7):
|
| 326 |
+
idx = torch.tensor([dow + 7 * w for w in range(n_weeks)], device=device, dtype=torch.long)
|
| 327 |
+
days.append(x_days.index_select(1, idx).mean(dim=1)) # [B, 24]
|
| 328 |
+
return torch.stack(days, dim=1).reshape(B, 7 * steps_per_day)
|
| 329 |
+
|
| 330 |
+
gen_weekly = weekly_pattern(gen_days)
|
| 331 |
+
real_weekly = weekly_pattern(real_days)
|
| 332 |
+
loss_weekly = F.mse_loss(gen_weekly, real_weekly)
|
| 333 |
+
|
| 334 |
+
return loss_daily + loss_weekly
|
| 335 |
+
|
| 336 |
+
# Pearson Correlation Loss
|
| 337 |
+
def compute_correlation_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 338 |
+
"""
|
| 339 |
+
Loss = 1 - Correlation. 强迫模型优化波形形状。
|
| 340 |
+
"""
|
| 341 |
+
# 1. Center the data
|
| 342 |
+
pred_mean = pred - pred.mean(dim=1, keepdim=True)
|
| 343 |
+
target_mean = target - target.mean(dim=1, keepdim=True)
|
| 344 |
+
|
| 345 |
+
# 2. Normalize
|
| 346 |
+
pred_norm = torch.norm(pred_mean, p=2, dim=1) + 1e-8
|
| 347 |
+
target_norm = torch.norm(target_mean, p=2, dim=1) + 1e-8
|
| 348 |
+
|
| 349 |
+
# 3. Calculate cosine similarity (i.e., correlation after mean-shifting)
|
| 350 |
+
cosine_sim = (pred_mean * target_mean).sum(dim=1) / (pred_norm * target_norm)
|
| 351 |
+
|
| 352 |
+
# 4. Loss = 1 - Correlation
|
| 353 |
+
return 1.0 - cosine_sim.mean()
|
| 354 |
+
|
| 355 |
+
def forward(
|
| 356 |
+
self,
|
| 357 |
+
model: HierarchicalFlowMatchingV4,
|
| 358 |
+
real_traffic: torch.Tensor,
|
| 359 |
+
spatial_cond: Dict[str, torch.Tensor] | torch.Tensor,
|
| 360 |
+
fusion_method: str = 'generative',
|
| 361 |
+
lambda_level1: float = 1.0,
|
| 362 |
+
lambda_level2: float = 1.0,
|
| 363 |
+
lambda_level3: float = 1.0,
|
| 364 |
+
lambda_temporal: float = 0.1,
|
| 365 |
+
lambda_periodic: float = 0.1,
|
| 366 |
+
lambda_corr: float = 0.5,
|
| 367 |
+
lambda_boundary: float = 1.0,
|
| 368 |
+
lambda_bias: float = 1.0,
|
| 369 |
+
teacher_forcing_ratio: float = 1.0,
|
| 370 |
+
n_steps_generate: int = 10,
|
| 371 |
+
**kwargs
|
| 372 |
+
) -> Dict[str, torch.Tensor]:
|
| 373 |
+
"""
|
| 374 |
+
Compute combined hierarchical loss.
|
| 375 |
+
"""
|
| 376 |
+
if isinstance(spatial_cond, torch.Tensor):
|
| 377 |
+
# Compatibility fallback
|
| 378 |
+
spatial_cond = {
|
| 379 |
+
'level1_cond': spatial_cond,
|
| 380 |
+
'level2_cond': spatial_cond,
|
| 381 |
+
'level3_cond': spatial_cond,
|
| 382 |
+
'pred_peak_logits': None
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
# ---------------------------------------------------------------------
|
| 386 |
+
# 1. Derive Ground Truth Peak Hour
|
| 387 |
+
# ---------------------------------------------------------------------
|
| 388 |
+
# Reshape to [B, days, 24] -> mean daily pattern -> argmax
|
| 389 |
+
B = real_traffic.shape[0]
|
| 390 |
+
avg_daily = real_traffic.reshape(B, -1, 24).mean(dim=1)
|
| 391 |
+
peak_hour_gt = avg_daily.argmax(dim=1) # [B] (0-23)
|
| 392 |
+
|
| 393 |
+
# 2. Auxiliary Classification Loss
|
| 394 |
+
pred_peak_logits = spatial_cond.get('pred_peak_logits', None)
|
| 395 |
+
if pred_peak_logits is not None:
|
| 396 |
+
loss_peak_cls = F.cross_entropy(pred_peak_logits, peak_hour_gt)
|
| 397 |
+
else:
|
| 398 |
+
loss_peak_cls = torch.tensor(0.0, device=real_traffic.device)
|
| 399 |
+
|
| 400 |
+
# Determine teacher forcing
|
| 401 |
+
use_tf = torch.rand(1).item() < teacher_forcing_ratio
|
| 402 |
+
|
| 403 |
+
# ---------------------------------------------------------------------
|
| 404 |
+
# 3. Compute Level Losses (FM + Boundary)
|
| 405 |
+
# ---------------------------------------------------------------------
|
| 406 |
+
loss_l1_fm, loss_l1_bound = self.compute_level1_loss(
|
| 407 |
+
model, real_traffic, spatial_cond['level1_cond']
|
| 408 |
+
)
|
| 409 |
+
loss_l2_fm, loss_l2_bound = self.compute_level2_loss(
|
| 410 |
+
model,
|
| 411 |
+
real_traffic,
|
| 412 |
+
spatial_cond_level2=spatial_cond['level2_cond'],
|
| 413 |
+
spatial_cond_level1=spatial_cond['level1_cond'],
|
| 414 |
+
use_teacher_forcing=use_tf,
|
| 415 |
+
n_steps_generate=n_steps_generate,
|
| 416 |
+
)
|
| 417 |
+
# Pass peak_hour_gt
|
| 418 |
+
loss_l3_fm, loss_l3_bound = self.compute_level3_loss(
|
| 419 |
+
model,
|
| 420 |
+
real_traffic,
|
| 421 |
+
spatial_cond_level3=spatial_cond['level3_cond'],
|
| 422 |
+
spatial_cond_level2=spatial_cond['level2_cond'],
|
| 423 |
+
spatial_cond_level1=spatial_cond['level1_cond'],
|
| 424 |
+
peak_hour_gt=peak_hour_gt, # Explicit GT
|
| 425 |
+
use_teacher_forcing=use_tf,
|
| 426 |
+
n_steps_generate=n_steps_generate,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# ---------------------------------------------------------------------
|
| 430 |
+
# 4. Aux Losses (Temporal, Periodic, Bias)
|
| 431 |
+
# ---------------------------------------------------------------------
|
| 432 |
+
loss_temporal = torch.tensor(0.0, device=real_traffic.device)
|
| 433 |
+
loss_periodic = torch.tensor(0.0, device=real_traffic.device)
|
| 434 |
+
loss_bias = torch.tensor(0.0, device=real_traffic.device)
|
| 435 |
+
loss_corr = torch.tensor(0.0, device=real_traffic.device)
|
| 436 |
+
|
| 437 |
+
# Increase Generation Sampling Frequency
|
| 438 |
+
# If lambda_corr is significant (>0.1), we increase the sampling probability from 30% to 60%
|
| 439 |
+
# This enables more frequent computation of the Correlation Loss
|
| 440 |
+
prob_threshold = 0.6 if lambda_corr > 0.1 else 0.3
|
| 441 |
+
|
| 442 |
+
# Only compute generation-based losses occasionally to save time
|
| 443 |
+
should_compute_gen_losses = (lambda_temporal > 0 or lambda_periodic > 0 or lambda_bias > 0 or lambda_corr > 0)
|
| 444 |
+
|
| 445 |
+
if should_compute_gen_losses and torch.rand(1).item() < 0.3:
|
| 446 |
+
with torch.no_grad():
|
| 447 |
+
# Must provide peak_hour for generation
|
| 448 |
+
generated, _ = model.generate_hierarchical(
|
| 449 |
+
spatial_cond,
|
| 450 |
+
peak_hour=peak_hour_gt,
|
| 451 |
+
n_steps_per_level=n_steps_generate
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
if lambda_corr > 0:
|
| 455 |
+
loss_corr = self.compute_correlation_loss(generated, real_traffic)
|
| 456 |
+
|
| 457 |
+
if lambda_temporal > 0:
|
| 458 |
+
loss_temporal = self.compute_temporal_structure_loss(generated, real_traffic)
|
| 459 |
+
|
| 460 |
+
if lambda_periodic > 0:
|
| 461 |
+
loss_periodic = self.compute_multi_periodic_consistency_loss(generated, real_traffic)
|
| 462 |
+
|
| 463 |
+
if lambda_bias > 0:
|
| 464 |
+
loss_bias = self.compute_bias_loss(generated, real_traffic)
|
| 465 |
+
|
| 466 |
+
# ---------------------------------------------------------------------
|
| 467 |
+
# 5. Combined Loss
|
| 468 |
+
# ---------------------------------------------------------------------
|
| 469 |
+
|
| 470 |
+
# FM Loss
|
| 471 |
+
loss_fm_total = (
|
| 472 |
+
lambda_level1 * loss_l1_fm +
|
| 473 |
+
lambda_level2 * loss_l2_fm +
|
| 474 |
+
lambda_level3 * loss_l3_fm
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# Boundary Loss
|
| 478 |
+
loss_boundary_total = lambda_boundary * (loss_l1_bound + loss_l2_bound + loss_l3_bound)
|
| 479 |
+
|
| 480 |
+
# Bias Loss
|
| 481 |
+
loss_bias_total = lambda_bias * loss_bias
|
| 482 |
+
|
| 483 |
+
# Peak Classification Weight (static 0.5 for now)
|
| 484 |
+
lambda_peak = 5.0
|
| 485 |
+
|
| 486 |
+
total_loss = (
|
| 487 |
+
loss_fm_total +
|
| 488 |
+
loss_boundary_total +
|
| 489 |
+
loss_bias_total +
|
| 490 |
+
lambda_temporal * loss_temporal +
|
| 491 |
+
lambda_periodic * loss_periodic +
|
| 492 |
+
lambda_peak * loss_peak_cls +
|
| 493 |
+
lambda_corr * loss_corr # 加入总 Loss
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
return {
|
| 497 |
+
'loss_level1': loss_l1_fm,
|
| 498 |
+
'loss_level2': loss_l2_fm,
|
| 499 |
+
'loss_level3': loss_l3_fm,
|
| 500 |
+
'loss_boundary': loss_boundary_total,
|
| 501 |
+
'loss_bias': loss_bias_total,
|
| 502 |
+
'loss_temporal': loss_temporal,
|
| 503 |
+
'loss_periodic': loss_periodic,
|
| 504 |
+
'loss_peak_cls': loss_peak_cls,
|
| 505 |
+
'loss_corr': loss_corr,
|
| 506 |
+
'loss_total': total_loss,
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
# =============================================================================
|
| 513 |
+
# Complete Hierarchical Flow Matching Model with Encoder
|
| 514 |
+
# =============================================================================
|
| 515 |
+
|
| 516 |
+
class HierarchicalFlowMatchingSystemV4(nn.Module):
|
| 517 |
+
"""
|
| 518 |
+
Complete system combining:
|
| 519 |
+
- Multi-modal spatial encoder
|
| 520 |
+
- Hierarchical Flow Matching model
|
| 521 |
+
"""
|
| 522 |
+
|
| 523 |
+
def __init__(
|
| 524 |
+
self,
|
| 525 |
+
spatial_dim: int = 192,
|
| 526 |
+
hidden_dim: int = 256,
|
| 527 |
+
poi_dim: int = 20,
|
| 528 |
+
n_layers_level3: int = 6,
|
| 529 |
+
fusion_method: Literal['generative', 'contrastive'] = 'generative' # Default
|
| 530 |
+
):
|
| 531 |
+
super().__init__()
|
| 532 |
+
self.fusion_method = fusion_method
|
| 533 |
+
self.spatial_dim = spatial_dim
|
| 534 |
+
|
| 535 |
+
# 1. Environment Encoder (Multi-modal)
|
| 536 |
+
self.spatial_encoder = MultiModalSpatialEncoderV4(spatial_dim, poi_dim)
|
| 537 |
+
|
| 538 |
+
# NOTE: No TrafficCLIPEncoder in Generative Mode
|
| 539 |
+
self.traffic_encoder = None
|
| 540 |
+
|
| 541 |
+
# 2. Flow Matching Generative Model
|
| 542 |
+
self.fm_model = HierarchicalFlowMatchingV4(spatial_dim, hidden_dim, n_layers_level3)
|
| 543 |
+
|
| 544 |
+
# 3. Loss
|
| 545 |
+
self.loss_fn = HierarchicalFlowMatchingLoss()
|
| 546 |
+
|
| 547 |
+
def forward(self, batch: Dict, mode: str = 'train', loss_cfg: Optional[Dict] = None) -> Dict:
|
| 548 |
+
"""
|
| 549 |
+
Args:
|
| 550 |
+
batch: dict with spatial and traffic data
|
| 551 |
+
mode: 'train' or 'generate'
|
| 552 |
+
Returns:
|
| 553 |
+
outputs: dict with losses or generated samples
|
| 554 |
+
"""
|
| 555 |
+
# Encode spatial features
|
| 556 |
+
spatial_cond_dict = self.spatial_encoder(batch)
|
| 557 |
+
loss_cfg = loss_cfg or {}
|
| 558 |
+
|
| 559 |
+
if mode == 'train':
|
| 560 |
+
real_traffic = batch['traffic_seq']
|
| 561 |
+
|
| 562 |
+
# Calculate Losses
|
| 563 |
+
losses = self.loss_fn(
|
| 564 |
+
model=self.fm_model,
|
| 565 |
+
real_traffic=real_traffic,
|
| 566 |
+
spatial_cond=spatial_cond_dict,
|
| 567 |
+
fusion_method=self.fusion_method,
|
| 568 |
+
**loss_cfg
|
| 569 |
+
)
|
| 570 |
+
return {'losses': losses}
|
| 571 |
+
|
| 572 |
+
elif mode == 'generate':
|
| 573 |
+
# Inference logic: Explicit Peak Conditioning
|
| 574 |
+
# 1. Use the auxiliary head to predict peak location
|
| 575 |
+
pred_logits = spatial_cond_dict['pred_peak_logits']
|
| 576 |
+
pred_peak_hour = pred_logits.argmax(dim=1) # [B]
|
| 577 |
+
|
| 578 |
+
# 2. Allow manual override if 'manual_peak_hour' is in batch
|
| 579 |
+
if 'manual_peak_hour' in batch:
|
| 580 |
+
pred_peak_hour = batch['manual_peak_hour']
|
| 581 |
+
|
| 582 |
+
# Generate hierarchical samples
|
| 583 |
+
generated, intermediates = self.fm_model.generate_hierarchical(
|
| 584 |
+
spatial_cond_dict,
|
| 585 |
+
peak_hour=pred_peak_hour,
|
| 586 |
+
n_steps_per_level=loss_cfg.get('n_steps_generate', 50),
|
| 587 |
+
)
|
| 588 |
+
return {'generated': generated, 'intermediates': intermediates, 'pred_peak_hour': pred_peak_hour}
|
| 589 |
+
|
| 590 |
+
else:
|
| 591 |
+
raise ValueError(f"Unknown mode: {mode}")
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
# =============================================================================
|
| 595 |
+
# Trainer
|
| 596 |
+
# =============================================================================
|
| 597 |
+
|
| 598 |
+
class HierarchicalFlowMatchingTrainerV4:
|
| 599 |
+
"""
|
| 600 |
+
Trainer for Hierarchical Flow Matching V4.
|
| 601 |
+
"""
|
| 602 |
+
|
| 603 |
+
def __init__(
|
| 604 |
+
self,
|
| 605 |
+
model: HierarchicalFlowMatchingSystemV4,
|
| 606 |
+
train_loader: DataLoader,
|
| 607 |
+
val_loader: DataLoader,
|
| 608 |
+
lr: float = 1e-4,
|
| 609 |
+
weight_decay: float = 0.01,
|
| 610 |
+
checkpoint_dir: str = "checkpoints_hfm_v4",
|
| 611 |
+
lambda_level1: float = 1.0,
|
| 612 |
+
lambda_level2: float = 1.0,
|
| 613 |
+
lambda_level3: float = 1.0,
|
| 614 |
+
lambda_temporal: float = 0.1,
|
| 615 |
+
lambda_periodic: float = 0.1,
|
| 616 |
+
lambda_boundary: float = 1.0,
|
| 617 |
+
lambda_bias: float = 1.0,
|
| 618 |
+
lambda_corr: float = 0.5,
|
| 619 |
+
warmup_epochs: int = 5,
|
| 620 |
+
):
|
| 621 |
+
self.model = model
|
| 622 |
+
self.train_loader = train_loader
|
| 623 |
+
self.val_loader = val_loader
|
| 624 |
+
self.checkpoint_dir = checkpoint_dir
|
| 625 |
+
|
| 626 |
+
# Loss weights
|
| 627 |
+
self.loss_cfg = {
|
| 628 |
+
'lambda_level1': lambda_level1,
|
| 629 |
+
'lambda_level2': lambda_level2,
|
| 630 |
+
'lambda_level3': lambda_level3,
|
| 631 |
+
'lambda_temporal': lambda_temporal,
|
| 632 |
+
'lambda_periodic': lambda_periodic,
|
| 633 |
+
'lambda_boundary': lambda_boundary,
|
| 634 |
+
'lambda_bias': lambda_bias,
|
| 635 |
+
'lambda_corr': lambda_corr,
|
| 636 |
+
'teacher_forcing_ratio': 1.0,
|
| 637 |
+
'n_steps_generate': 10
|
| 638 |
+
}
|
| 639 |
+
|
| 640 |
+
# Warmup
|
| 641 |
+
self.warmup_epochs = warmup_epochs
|
| 642 |
+
self.base_lr = lr
|
| 643 |
+
|
| 644 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 645 |
+
self.model = self.model.to(self.device)
|
| 646 |
+
|
| 647 |
+
# Optimizer
|
| 648 |
+
self.optimizer = torch.optim.AdamW(
|
| 649 |
+
self.model.parameters(),
|
| 650 |
+
lr=lr,
|
| 651 |
+
weight_decay=weight_decay,
|
| 652 |
+
betas=(0.9, 0.99),
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 656 |
+
|
| 657 |
+
# Training history
|
| 658 |
+
self.history = {
|
| 659 |
+
'train_loss': [],
|
| 660 |
+
'val_loss': [],
|
| 661 |
+
'train_loss_level1': [],
|
| 662 |
+
'train_loss_level2': [],
|
| 663 |
+
'train_loss_level3': [],
|
| 664 |
+
'train_loss_temporal': [],
|
| 665 |
+
'train_loss_periodic': [],
|
| 666 |
+
'train_loss_peak_cls': [],
|
| 667 |
+
'train_loss_boundary': [],
|
| 668 |
+
'train_loss_bias': [],
|
| 669 |
+
'train_loss_corr': [],
|
| 670 |
+
'val_mae': [],
|
| 671 |
+
'val_corr': [],
|
| 672 |
+
'val_var_ratio': [],
|
| 673 |
+
'lr': [],
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
def get_lr_scale(self, epoch: int, total_epochs: int) -> float:
|
| 677 |
+
"""Get learning rate scale with warmup and cosine decay."""
|
| 678 |
+
if epoch < self.warmup_epochs:
|
| 679 |
+
return (epoch + 1) / self.warmup_epochs
|
| 680 |
+
else:
|
| 681 |
+
progress = (epoch - self.warmup_epochs) / (total_epochs - self.warmup_epochs)
|
| 682 |
+
return 0.5 * (1 + np.cos(np.pi * progress))
|
| 683 |
+
|
| 684 |
+
def set_lr(self, scale: float):
|
| 685 |
+
"""Set learning rate."""
|
| 686 |
+
for param_group in self.optimizer.param_groups:
|
| 687 |
+
param_group['lr'] = self.base_lr * scale
|
| 688 |
+
|
| 689 |
+
def train_epoch(self, epoch: int, total_epochs: int) -> Dict[str, float]:
|
| 690 |
+
"""Train one epoch."""
|
| 691 |
+
self.model.train()
|
| 692 |
+
|
| 693 |
+
# Set learning rate
|
| 694 |
+
lr_scale = self.get_lr_scale(epoch, total_epochs)
|
| 695 |
+
self.set_lr(lr_scale)
|
| 696 |
+
current_lr = self.optimizer.param_groups[0]['lr']
|
| 697 |
+
|
| 698 |
+
# Teacher forcing ratio
|
| 699 |
+
self.loss_cfg['teacher_forcing_ratio'] = max(0.5, 1.0 - epoch / (2 * total_epochs))
|
| 700 |
+
|
| 701 |
+
total_loss = 0.0
|
| 702 |
+
loss_level1 = 0.0
|
| 703 |
+
loss_level2 = 0.0
|
| 704 |
+
loss_level3 = 0.0
|
| 705 |
+
loss_temporal = 0.0
|
| 706 |
+
loss_periodic = 0.0
|
| 707 |
+
loss_peak_cls = 0.0
|
| 708 |
+
loss_boundary = 0.0
|
| 709 |
+
loss_bias = 0.0
|
| 710 |
+
loss_corr = 0.0
|
| 711 |
+
|
| 712 |
+
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch + 1} [Train]")
|
| 713 |
+
for batch in pbar:
|
| 714 |
+
batch = {
|
| 715 |
+
k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
| 716 |
+
for k, v in batch.items()
|
| 717 |
+
}
|
| 718 |
+
|
| 719 |
+
# Forward pass
|
| 720 |
+
self.optimizer.zero_grad()
|
| 721 |
+
output = self.model(
|
| 722 |
+
batch,
|
| 723 |
+
mode='train',
|
| 724 |
+
loss_cfg=self.loss_cfg,
|
| 725 |
+
)
|
| 726 |
+
losses = output['losses']
|
| 727 |
+
|
| 728 |
+
# Total loss
|
| 729 |
+
total_batch_loss = losses['loss_total']
|
| 730 |
+
|
| 731 |
+
# Backward pass
|
| 732 |
+
total_batch_loss.backward()
|
| 733 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 734 |
+
self.optimizer.step()
|
| 735 |
+
|
| 736 |
+
# Accumulate losses
|
| 737 |
+
total_loss += total_batch_loss.item()
|
| 738 |
+
loss_level1 += losses['loss_level1'].item()
|
| 739 |
+
loss_level2 += losses['loss_level2'].item()
|
| 740 |
+
loss_level3 += losses['loss_level3'].item()
|
| 741 |
+
loss_temporal += losses.get('loss_temporal', torch.tensor(0.0)).item()
|
| 742 |
+
loss_periodic += losses.get('loss_periodic', torch.tensor(0.0)).item()
|
| 743 |
+
loss_peak_cls += losses.get('loss_peak_cls', torch.tensor(0.0)).item()
|
| 744 |
+
loss_boundary += losses.get('loss_boundary', torch.tensor(0.0)).item()
|
| 745 |
+
loss_bias += losses.get('loss_bias', torch.tensor(0.0)).item()
|
| 746 |
+
loss_corr += losses.get('loss_corr', torch.tensor(0.0)).item()
|
| 747 |
+
|
| 748 |
+
# Update progress bar
|
| 749 |
+
pbar.set_postfix({
|
| 750 |
+
'loss': total_loss / (len(pbar) + 1),
|
| 751 |
+
'corr': loss_corr / (len(pbar) + 1),
|
| 752 |
+
'peak': loss_peak_cls / (len(pbar) + 1),
|
| 753 |
+
'bnd': loss_boundary / (len(pbar) + 1),
|
| 754 |
+
'bias': loss_bias / (len(pbar) + 1),
|
| 755 |
+
'lr': f'{current_lr:.2e}',
|
| 756 |
+
})
|
| 757 |
+
|
| 758 |
+
n_batches = len(self.train_loader)
|
| 759 |
+
return {
|
| 760 |
+
'loss_total': total_loss / n_batches,
|
| 761 |
+
'loss_level1': loss_level1 / n_batches,
|
| 762 |
+
'loss_level2': loss_level2 / n_batches,
|
| 763 |
+
'loss_level3': loss_level3 / n_batches,
|
| 764 |
+
'loss_temporal': loss_temporal / n_batches,
|
| 765 |
+
'loss_periodic': loss_periodic / n_batches,
|
| 766 |
+
'loss_peak_cls': loss_peak_cls / n_batches,
|
| 767 |
+
'loss_boundary': loss_boundary / n_batches,
|
| 768 |
+
'loss_bias': loss_bias / n_batches,
|
| 769 |
+
'loss_corr': loss_corr / n_batches,
|
| 770 |
+
'lr': current_lr,
|
| 771 |
+
}
|
| 772 |
+
|
| 773 |
+
@torch.no_grad()
|
| 774 |
+
def validate(self, epoch: int) -> Dict[str, float]:
|
| 775 |
+
"""Validate."""
|
| 776 |
+
self.model.eval()
|
| 777 |
+
|
| 778 |
+
total_loss = 0.0
|
| 779 |
+
all_mae = []
|
| 780 |
+
all_corr = []
|
| 781 |
+
all_var_ratio = []
|
| 782 |
+
|
| 783 |
+
pbar = tqdm(self.val_loader, desc=f"Epoch {epoch + 1} [Val]")
|
| 784 |
+
for batch in pbar:
|
| 785 |
+
batch = {
|
| 786 |
+
k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
| 787 |
+
for k, v in batch.items()
|
| 788 |
+
}
|
| 789 |
+
|
| 790 |
+
# Loss
|
| 791 |
+
output = self.model(
|
| 792 |
+
batch,
|
| 793 |
+
mode='train',
|
| 794 |
+
loss_cfg=self.loss_cfg,
|
| 795 |
+
)
|
| 796 |
+
losses = output['losses']
|
| 797 |
+
total_loss += losses['loss_total'].item()
|
| 798 |
+
|
| 799 |
+
# Generate samples
|
| 800 |
+
# Note: generate now internally handles peak_hour logic in System.forward
|
| 801 |
+
gen_output = self.model(
|
| 802 |
+
batch,
|
| 803 |
+
mode='generate',
|
| 804 |
+
loss_cfg={'n_steps_generate': 50},
|
| 805 |
+
)
|
| 806 |
+
real = batch['traffic_seq'].cpu().numpy()
|
| 807 |
+
generated = gen_output['generated'].cpu().numpy()
|
| 808 |
+
|
| 809 |
+
# Metrics
|
| 810 |
+
mae = np.mean(np.abs(real - generated))
|
| 811 |
+
all_mae.append(mae)
|
| 812 |
+
|
| 813 |
+
# Variance ratio
|
| 814 |
+
real_var = np.var(real, axis=1).mean()
|
| 815 |
+
gen_var = np.var(generated, axis=1).mean()
|
| 816 |
+
var_ratio = gen_var / (real_var + 1e-8)
|
| 817 |
+
all_var_ratio.append(var_ratio)
|
| 818 |
+
|
| 819 |
+
# Correlation
|
| 820 |
+
for i in range(len(real)):
|
| 821 |
+
r_std = np.std(real[i])
|
| 822 |
+
g_std = np.std(generated[i])
|
| 823 |
+
if r_std > 1e-6 and g_std > 1e-6:
|
| 824 |
+
corr = np.corrcoef(real[i], generated[i])[0, 1]
|
| 825 |
+
if not np.isnan(corr):
|
| 826 |
+
all_corr.append(corr)
|
| 827 |
+
|
| 828 |
+
n_batches = len(self.val_loader)
|
| 829 |
+
return {
|
| 830 |
+
'loss_total': total_loss / n_batches,
|
| 831 |
+
'mae': np.mean(all_mae),
|
| 832 |
+
'correlation': np.mean(all_corr) if all_corr else 0.0,
|
| 833 |
+
'var_ratio': np.mean(all_var_ratio),
|
| 834 |
+
}
|
| 835 |
+
|
| 836 |
+
def train(self, epochs: int):
|
| 837 |
+
"""Full training loop."""
|
| 838 |
+
print("=" * 80)
|
| 839 |
+
print("Hierarchical Flow Matching V4 - Training")
|
| 840 |
+
print(f"Fusion Method: {self.model.fusion_method}")
|
| 841 |
+
print("=" * 80)
|
| 842 |
+
print(f"Device: {self.device}")
|
| 843 |
+
print(f"Epochs: {epochs}")
|
| 844 |
+
print(f"Base learning rate: {self.base_lr:.2e}")
|
| 845 |
+
print("=" * 80)
|
| 846 |
+
|
| 847 |
+
# [修改 1] 初始化两个最佳指标跟踪变量
|
| 848 |
+
best_val_loss = float('inf')
|
| 849 |
+
best_val_corr = -1.0 # 初始化相关性为 -1
|
| 850 |
+
|
| 851 |
+
for epoch in range(epochs):
|
| 852 |
+
# Train
|
| 853 |
+
train_losses = self.train_epoch(epoch, epochs)
|
| 854 |
+
|
| 855 |
+
# Validate
|
| 856 |
+
val_losses = self.validate(epoch)
|
| 857 |
+
|
| 858 |
+
# Print summary
|
| 859 |
+
print(f"\nEpoch {epoch + 1}/{epochs}")
|
| 860 |
+
print(f" Train Loss: {train_losses['loss_total']:.6f}")
|
| 861 |
+
print(f" Peak Cls Loss: {train_losses['loss_peak_cls']:.6f}")
|
| 862 |
+
print(f" Boundary Loss: {train_losses['loss_boundary']:.6f}")
|
| 863 |
+
print(f" Bias Loss: {train_losses['loss_bias']:.6f}")
|
| 864 |
+
print(f" Val Loss: {val_losses['loss_total']:.6f}")
|
| 865 |
+
print(f" Val MAE: {val_losses['mae']:.4f}")
|
| 866 |
+
print(f" Val Correlation: {val_losses['correlation']:.4f}")
|
| 867 |
+
print(f" Val Var Ratio: {val_losses['var_ratio']:.4f}")
|
| 868 |
+
|
| 869 |
+
# Save history
|
| 870 |
+
self.history['train_loss'].append(train_losses['loss_total'])
|
| 871 |
+
self.history['val_loss'].append(val_losses['loss_total'])
|
| 872 |
+
self.history['train_loss_level1'].append(train_losses['loss_level1'])
|
| 873 |
+
self.history['train_loss_level2'].append(train_losses['loss_level2'])
|
| 874 |
+
self.history['train_loss_level3'].append(train_losses['loss_level3'])
|
| 875 |
+
self.history['train_loss_temporal'].append(train_losses['loss_temporal'])
|
| 876 |
+
self.history['train_loss_periodic'].append(train_losses['loss_periodic'])
|
| 877 |
+
self.history['train_loss_peak_cls'].append(train_losses['loss_peak_cls'])
|
| 878 |
+
self.history['train_loss_boundary'].append(train_losses['loss_boundary'])
|
| 879 |
+
self.history['train_loss_bias'].append(train_losses['loss_bias'])
|
| 880 |
+
self.history['val_mae'].append(val_losses['mae'])
|
| 881 |
+
self.history['val_corr'].append(val_losses['correlation'])
|
| 882 |
+
self.history['val_var_ratio'].append(val_losses['var_ratio'])
|
| 883 |
+
self.history['lr'].append(train_losses['lr'])
|
| 884 |
+
|
| 885 |
+
# Logic A: Save the model with the lowest loss (as the mathematically optimal fallback)
|
| 886 |
+
if val_losses['loss_total'] < best_val_loss:
|
| 887 |
+
best_val_loss = val_losses['loss_total']
|
| 888 |
+
self.save_checkpoint(epoch, val_losses, filename='best_loss_model.pt')
|
| 889 |
+
print(f" ✓ [Best Loss model saved! Loss: {best_val_loss:.4f}]")
|
| 890 |
+
|
| 891 |
+
# Logic B: Save the model with the highest correlation (as the business-practical best)
|
| 892 |
+
if val_losses['correlation'] > best_val_corr:
|
| 893 |
+
best_val_corr = val_losses['correlation']
|
| 894 |
+
self.save_checkpoint(epoch, val_losses, filename='best_corr_model.pt')
|
| 895 |
+
print(f" ★ [Best Correlation model saved! Corr: {best_val_corr:.4f}]")
|
| 896 |
+
|
| 897 |
+
# Always save the latest version
|
| 898 |
+
self.save_checkpoint(epoch, val_losses, filename='latest_model.pt')
|
| 899 |
+
|
| 900 |
+
# Save history
|
| 901 |
+
self.save_history()
|
| 902 |
+
|
| 903 |
+
print("\n" + "=" * 80)
|
| 904 |
+
print("Training Completed!")
|
| 905 |
+
print(f"Best validation loss: {best_val_loss:.6f}")
|
| 906 |
+
print(f"Best validation corr: {best_val_corr:.4f}") # 打印最佳相关性
|
| 907 |
+
print(f"Checkpoints saved to: {self.checkpoint_dir}")
|
| 908 |
+
print("=" * 80)
|
| 909 |
+
|
| 910 |
+
def save_checkpoint(self, epoch: int, losses: Dict, filename: str = 'best_model.pt'):
|
| 911 |
+
"""Save checkpoint."""
|
| 912 |
+
checkpoint = {
|
| 913 |
+
'epoch': epoch,
|
| 914 |
+
'model_state_dict': self.model.state_dict(),
|
| 915 |
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
| 916 |
+
'losses': losses,
|
| 917 |
+
'history': self.history,
|
| 918 |
+
}
|
| 919 |
+
|
| 920 |
+
path = os.path.join(self.checkpoint_dir, filename)
|
| 921 |
+
torch.save(checkpoint, path)
|
| 922 |
+
|
| 923 |
+
def save_history(self):
|
| 924 |
+
"""Save training history."""
|
| 925 |
+
path = os.path.join(self.checkpoint_dir, 'training_history.json')
|
| 926 |
+
|
| 927 |
+
history_serializable = {}
|
| 928 |
+
for key, values in self.history.items():
|
| 929 |
+
history_serializable[key] = [
|
| 930 |
+
float(v) if isinstance(v, (np.floating, np.integer)) else v
|
| 931 |
+
for v in values
|
| 932 |
+
]
|
| 933 |
+
|
| 934 |
+
with open(path, 'w') as f:
|
| 935 |
+
json.dump(history_serializable, f, indent=2)
|
hierarchical_flow_matching_v4.py
ADDED
|
@@ -0,0 +1,1019 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hierarchical Flow Matching with Mamba/SSM Backbone (V4) - Generative Version
|
| 3 |
+
============================================================================
|
| 4 |
+
|
| 5 |
+
Architecture: Pure Diffusion/Flow Matching (No GANs).
|
| 6 |
+
Fusion Method: Generative (Implicit alignment via conditional generation).
|
| 7 |
+
|
| 8 |
+
Core improvements over V3:
|
| 9 |
+
1. Three-level cascaded Flow Matching architecture.
|
| 10 |
+
2. Multi-modal spatial context encoding.
|
| 11 |
+
3. Long-sequence modeling backbone.
|
| 12 |
+
4. Explicit Peak Conditioning.
|
| 13 |
+
5. Physical Constraints (Non-negative output enforced).
|
| 14 |
+
|
| 15 |
+
Author: Optimization Team
|
| 16 |
+
Date: 2026-01-21
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import math
|
| 22 |
+
import numpy as np
|
| 23 |
+
from typing import Dict, Optional, Tuple, List
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# =============================================================================
|
| 32 |
+
# Mamba/SSM Backbone for Long Sequence Modeling
|
| 33 |
+
# =============================================================================
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class MambaConfig:
|
| 37 |
+
"""Configuration for Mamba block."""
|
| 38 |
+
d_model: int = 256
|
| 39 |
+
d_state: int = 64
|
| 40 |
+
d_conv: int = 4
|
| 41 |
+
expand: int = 2
|
| 42 |
+
dt_rank: str = "auto"
|
| 43 |
+
dt_min: float = 0.001
|
| 44 |
+
dt_max: float = 0.1
|
| 45 |
+
dt_init: str = "random"
|
| 46 |
+
dt_scale: float = 1.0
|
| 47 |
+
dt_init_floor: float = 1e-4
|
| 48 |
+
bias: bool = True
|
| 49 |
+
conv_bias: bool = True
|
| 50 |
+
pscan: bool = True
|
| 51 |
+
use_cuda: bool = True
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _selective_scan_diagonal(
|
| 55 |
+
log_a: torch.Tensor, # [B, L, N]
|
| 56 |
+
b: torch.Tensor, # [B, L, N]
|
| 57 |
+
) -> torch.Tensor:
|
| 58 |
+
"""
|
| 59 |
+
Parallel (vectorized) diagonal linear recurrence:
|
| 60 |
+
h_t = a_t * h_{t-1} + b_t, h_{-1}=0
|
| 61 |
+
where a_t = exp(log_a_t), computed without Python loops.
|
| 62 |
+
"""
|
| 63 |
+
# log_p[t] = sum_{i<=t} log_a[i]
|
| 64 |
+
log_p = torch.cumsum(log_a, dim=1) # [B, L, N]
|
| 65 |
+
inv_p = torch.exp(-log_p)
|
| 66 |
+
s = torch.cumsum(b * inv_p, dim=1) # [B, L, N]
|
| 67 |
+
h = torch.exp(log_p) * s
|
| 68 |
+
return h
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Mamba(nn.Module):
|
| 72 |
+
"""
|
| 73 |
+
Mamba block for efficient long-sequence modeling.
|
| 74 |
+
|
| 75 |
+
Based on: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
|
| 76 |
+
Pure-PyTorch implementation (vectorized diagonal selective scan) for traffic
|
| 77 |
+
sequence generation (no external kernels / dependencies).
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(self, config: MambaConfig):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.config = config
|
| 83 |
+
|
| 84 |
+
d_model = config.d_model
|
| 85 |
+
d_state = config.d_state
|
| 86 |
+
d_conv = config.d_conv
|
| 87 |
+
expand = config.expand
|
| 88 |
+
|
| 89 |
+
self.d_inner = int(expand * d_model)
|
| 90 |
+
|
| 91 |
+
# (1) Input projection: x -> (u, gate)
|
| 92 |
+
self.in_proj = nn.Linear(d_model, 2 * self.d_inner, bias=config.bias)
|
| 93 |
+
|
| 94 |
+
# (2) Depthwise conv for short-range mixing (Mamba-style local context)
|
| 95 |
+
self.dwconv = nn.Conv1d(
|
| 96 |
+
in_channels=self.d_inner,
|
| 97 |
+
out_channels=self.d_inner,
|
| 98 |
+
kernel_size=d_conv,
|
| 99 |
+
padding=d_conv - 1,
|
| 100 |
+
groups=self.d_inner,
|
| 101 |
+
bias=config.conv_bias,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# (3) Input-dependent SSM parameters (B, C, dt)
|
| 105 |
+
self.B_proj = nn.Linear(self.d_inner, d_state, bias=False)
|
| 106 |
+
self.C_proj = nn.Linear(self.d_inner, d_state, bias=False)
|
| 107 |
+
self.dt_proj = nn.Linear(self.d_inner, d_state, bias=True)
|
| 108 |
+
|
| 109 |
+
# Diagonal A (negative, stable)
|
| 110 |
+
self.A_log = nn.Parameter(torch.zeros(d_state))
|
| 111 |
+
|
| 112 |
+
# Skip connection from u (Mamba "D" term)
|
| 113 |
+
self.D = nn.Parameter(torch.ones(self.d_inner))
|
| 114 |
+
|
| 115 |
+
# (4) State -> inner -> model projections
|
| 116 |
+
self.out_state_proj = nn.Linear(d_state, self.d_inner, bias=False)
|
| 117 |
+
self.out_proj = nn.Linear(self.d_inner, d_model, bias=config.bias)
|
| 118 |
+
|
| 119 |
+
# Initialize FiLM-like stability: start close to identity
|
| 120 |
+
nn.init.zeros_(self.A_log)
|
| 121 |
+
nn.init.zeros_(self.dt_proj.weight)
|
| 122 |
+
nn.init.constant_(self.dt_proj.bias, math.log(math.expm1(0.01))) # softplus^-1
|
| 123 |
+
nn.init.zeros_(self.out_state_proj.weight)
|
| 124 |
+
|
| 125 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 126 |
+
"""
|
| 127 |
+
Args:
|
| 128 |
+
x: [B, L, D] input sequence
|
| 129 |
+
Returns:
|
| 130 |
+
y: [B, L, D] output sequence
|
| 131 |
+
"""
|
| 132 |
+
B, L, _ = x.shape
|
| 133 |
+
|
| 134 |
+
# Input projection
|
| 135 |
+
u, gate = self.in_proj(x).chunk(2, dim=-1) # [B, L, d_inner] each
|
| 136 |
+
|
| 137 |
+
# Depthwise conv (causal-ish via padding then crop)
|
| 138 |
+
u_conv = self.dwconv(u.transpose(1, 2))[:, :, :L].transpose(1, 2) # [B, L, d_inner]
|
| 139 |
+
u_conv = F.silu(u_conv)
|
| 140 |
+
|
| 141 |
+
# Input-dependent SSM params
|
| 142 |
+
dt = F.softplus(self.dt_proj(u_conv)) # [B, L, d_state]
|
| 143 |
+
dt = dt.clamp(min=self.config.dt_min, max=self.config.dt_max)
|
| 144 |
+
|
| 145 |
+
B_t = self.B_proj(u_conv) # [B, L, d_state]
|
| 146 |
+
C_t = self.C_proj(u_conv) # [B, L, d_state]
|
| 147 |
+
|
| 148 |
+
# Diagonal state transition: a_t = exp(A * dt)
|
| 149 |
+
A = -torch.exp(self.A_log).view(1, 1, -1) # [1, 1, d_state]
|
| 150 |
+
log_a = A * dt # [B, L, d_state]
|
| 151 |
+
b = B_t * dt # [B, L, d_state]
|
| 152 |
+
|
| 153 |
+
# Selective scan (vectorized)
|
| 154 |
+
h = _selective_scan_diagonal(log_a, b) # [B, L, d_state]
|
| 155 |
+
|
| 156 |
+
# Output from states
|
| 157 |
+
y_state = h * C_t
|
| 158 |
+
y_inner = self.out_state_proj(y_state) # [B, L, d_inner]
|
| 159 |
+
|
| 160 |
+
# Skip + gate (Mamba-style)
|
| 161 |
+
y_inner = y_inner + u_conv * self.D.view(1, 1, -1)
|
| 162 |
+
y_inner = y_inner * torch.sigmoid(gate)
|
| 163 |
+
|
| 164 |
+
return self.out_proj(y_inner)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# =============================================================================
|
| 168 |
+
# Multi-Scale Dilated Convolution Backbone
|
| 169 |
+
# =============================================================================
|
| 170 |
+
|
| 171 |
+
class MultiScaleDilatedConv(nn.Module):
|
| 172 |
+
"""
|
| 173 |
+
Multi-scale dilated convolution for capturing temporal patterns at different scales.
|
| 174 |
+
|
| 175 |
+
Receptive fields:
|
| 176 |
+
- Scale 1 (dilation=1): Daily patterns (24 hours)
|
| 177 |
+
- Scale 2 (example): Weekly patterns (7 days) -> hourly would be dilation=168
|
| 178 |
+
- Scale 3 (example): Longer cycles (e.g., 28 days) -> hourly would be dilation=672
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
def __init__(
|
| 182 |
+
self,
|
| 183 |
+
channels: int,
|
| 184 |
+
kernel_size: int = 3,
|
| 185 |
+
dilations: Optional[List[int]] = None,
|
| 186 |
+
dropout: float = 0.0,
|
| 187 |
+
):
|
| 188 |
+
super().__init__()
|
| 189 |
+
if dilations is None:
|
| 190 |
+
dilations = [1, 4, 16]
|
| 191 |
+
self.channels = channels
|
| 192 |
+
self.kernel_size = kernel_size
|
| 193 |
+
self.dilations = [int(d) for d in dilations if int(d) >= 1]
|
| 194 |
+
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
| 195 |
+
|
| 196 |
+
padding_base = (kernel_size - 1) // 2
|
| 197 |
+
|
| 198 |
+
# Depthwise-separable conv branches
|
| 199 |
+
self.branches = nn.ModuleList()
|
| 200 |
+
for d in self.dilations:
|
| 201 |
+
self.branches.append(
|
| 202 |
+
nn.Sequential(
|
| 203 |
+
nn.Conv1d(
|
| 204 |
+
channels,
|
| 205 |
+
channels,
|
| 206 |
+
kernel_size,
|
| 207 |
+
dilation=d,
|
| 208 |
+
padding=padding_base * d,
|
| 209 |
+
groups=channels,
|
| 210 |
+
bias=True,
|
| 211 |
+
),
|
| 212 |
+
nn.GELU(),
|
| 213 |
+
nn.Conv1d(channels, channels, kernel_size=1, bias=True),
|
| 214 |
+
)
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# Fusion (token-wise MLP)
|
| 218 |
+
self.fusion = nn.Sequential(
|
| 219 |
+
nn.Linear(channels * len(self.dilations), channels * 2),
|
| 220 |
+
nn.GELU(),
|
| 221 |
+
nn.Dropout(0.1),
|
| 222 |
+
nn.Linear(channels * 2, channels),
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 226 |
+
"""
|
| 227 |
+
Args:
|
| 228 |
+
x: [B, L, C] input
|
| 229 |
+
Returns:
|
| 230 |
+
y: [B, L, C] output
|
| 231 |
+
"""
|
| 232 |
+
x_t = x.transpose(1, 2) # [B, C, L]
|
| 233 |
+
outs = []
|
| 234 |
+
for branch in self.branches:
|
| 235 |
+
outs.append(branch(x_t).transpose(1, 2)) # [B, L, C]
|
| 236 |
+
y = torch.cat(outs, dim=-1) # [B, L, C * n_scales]
|
| 237 |
+
y = self.fusion(y)
|
| 238 |
+
return self.dropout(y)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# =============================================================================
|
| 242 |
+
# Hybrid Backbone: Mamba + Multi-Scale Dilated Conv
|
| 243 |
+
# =============================================================================
|
| 244 |
+
|
| 245 |
+
class HybridLongSequenceBackbone(nn.Module):
|
| 246 |
+
"""
|
| 247 |
+
Hybrid backbone combining Mamba/SSM and multi-scale dilated convolutions.
|
| 248 |
+
|
| 249 |
+
Designed for efficient long-sequence modeling with multi-scale temporal patterns.
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
def __init__(
|
| 253 |
+
self,
|
| 254 |
+
d_model: int = 256,
|
| 255 |
+
n_layers: int = 4,
|
| 256 |
+
d_state: int = 64,
|
| 257 |
+
use_mamba: bool = True,
|
| 258 |
+
use_dilated_conv: bool = True,
|
| 259 |
+
dilations: Optional[List[int]] = None,
|
| 260 |
+
cond_dim: Optional[int] = None,
|
| 261 |
+
dropout: float = 0.1,
|
| 262 |
+
):
|
| 263 |
+
super().__init__()
|
| 264 |
+
self.d_model = d_model
|
| 265 |
+
self.n_layers = n_layers
|
| 266 |
+
self.d_state = d_state
|
| 267 |
+
self.use_dilated_conv = use_dilated_conv
|
| 268 |
+
self.use_mamba = use_mamba
|
| 269 |
+
self.cond_dim = cond_dim
|
| 270 |
+
|
| 271 |
+
if dilations is None:
|
| 272 |
+
dilations = [1, 4, 16]
|
| 273 |
+
|
| 274 |
+
self.blocks = nn.ModuleList()
|
| 275 |
+
for _ in range(n_layers):
|
| 276 |
+
self.blocks.append(
|
| 277 |
+
_HybridBlock(
|
| 278 |
+
d_model=d_model,
|
| 279 |
+
d_state=d_state,
|
| 280 |
+
use_mamba=use_mamba,
|
| 281 |
+
use_dilated_conv=use_dilated_conv,
|
| 282 |
+
dilations=dilations,
|
| 283 |
+
cond_dim=cond_dim,
|
| 284 |
+
dropout=dropout,
|
| 285 |
+
)
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
def forward(
|
| 289 |
+
self,
|
| 290 |
+
x: torch.Tensor,
|
| 291 |
+
t_emb: Optional[torch.Tensor] = None, # [B, D]
|
| 292 |
+
cond: Optional[torch.Tensor] = None, # [B, C]
|
| 293 |
+
) -> torch.Tensor:
|
| 294 |
+
"""
|
| 295 |
+
Args:
|
| 296 |
+
x: [B, L, D] input sequence
|
| 297 |
+
Returns:
|
| 298 |
+
y: [B, L, D] output sequence
|
| 299 |
+
"""
|
| 300 |
+
for block in self.blocks:
|
| 301 |
+
x = block(x, t_emb=t_emb, cond=cond)
|
| 302 |
+
return x
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _valid_num_groups(channels: int, requested: int) -> int:
|
| 306 |
+
g = min(requested, channels)
|
| 307 |
+
while g > 1 and (channels % g) != 0:
|
| 308 |
+
g -= 1
|
| 309 |
+
return max(g, 1)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class _HybridBlock(nn.Module):
|
| 313 |
+
def __init__(
|
| 314 |
+
self,
|
| 315 |
+
d_model: int,
|
| 316 |
+
d_state: int,
|
| 317 |
+
use_mamba: bool,
|
| 318 |
+
use_dilated_conv: bool,
|
| 319 |
+
dilations: List[int],
|
| 320 |
+
cond_dim: Optional[int],
|
| 321 |
+
dropout: float,
|
| 322 |
+
):
|
| 323 |
+
super().__init__()
|
| 324 |
+
self.use_mamba = use_mamba
|
| 325 |
+
self.use_dilated_conv = use_dilated_conv
|
| 326 |
+
self.cond_dim = cond_dim
|
| 327 |
+
|
| 328 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 329 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 330 |
+
self.norm3 = nn.LayerNorm(d_model)
|
| 331 |
+
|
| 332 |
+
self.mamba = (
|
| 333 |
+
Mamba(MambaConfig(d_model=d_model, d_state=d_state))
|
| 334 |
+
if use_mamba
|
| 335 |
+
else nn.Identity()
|
| 336 |
+
)
|
| 337 |
+
self.conv = (
|
| 338 |
+
MultiScaleDilatedConv(
|
| 339 |
+
channels=d_model,
|
| 340 |
+
kernel_size=3,
|
| 341 |
+
dilations=dilations,
|
| 342 |
+
dropout=dropout,
|
| 343 |
+
)
|
| 344 |
+
if use_dilated_conv
|
| 345 |
+
else nn.Identity()
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
self.ffn = nn.Sequential(
|
| 349 |
+
nn.Linear(d_model, d_model * 4),
|
| 350 |
+
nn.GELU(),
|
| 351 |
+
nn.Dropout(dropout),
|
| 352 |
+
nn.Linear(d_model * 4, d_model),
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
self.dropout = nn.Dropout(dropout)
|
| 356 |
+
|
| 357 |
+
self.film = FiLMModulation(d_model, cond_dim) if cond_dim is not None else None
|
| 358 |
+
self.ada_gn = AdaptiveGroupNorm(d_model, cond_dim) if cond_dim is not None else None
|
| 359 |
+
|
| 360 |
+
def _cond(self, h: torch.Tensor, cond: Optional[torch.Tensor]) -> torch.Tensor:
|
| 361 |
+
if cond is None or self.film is None or self.ada_gn is None:
|
| 362 |
+
return h
|
| 363 |
+
h = self.film(h, cond)
|
| 364 |
+
h = self.ada_gn(h, cond)
|
| 365 |
+
return h
|
| 366 |
+
|
| 367 |
+
def forward(
|
| 368 |
+
self,
|
| 369 |
+
x: torch.Tensor, # [B, L, D]
|
| 370 |
+
t_emb: Optional[torch.Tensor] = None, # [B, D]
|
| 371 |
+
cond: Optional[torch.Tensor] = None, # [B, C]
|
| 372 |
+
) -> torch.Tensor:
|
| 373 |
+
# Mamba/SSM
|
| 374 |
+
h = self.norm1(x)
|
| 375 |
+
if t_emb is not None:
|
| 376 |
+
h = h + t_emb.unsqueeze(1)
|
| 377 |
+
h = self._cond(h, cond)
|
| 378 |
+
h = self.mamba(h)
|
| 379 |
+
x = x + self.dropout(h)
|
| 380 |
+
|
| 381 |
+
# Multi-scale dilated conv
|
| 382 |
+
if self.use_dilated_conv:
|
| 383 |
+
h = self.norm2(x)
|
| 384 |
+
if t_emb is not None:
|
| 385 |
+
h = h + 0.5 * t_emb.unsqueeze(1)
|
| 386 |
+
h = self._cond(h, cond)
|
| 387 |
+
h = self.conv(h)
|
| 388 |
+
x = x + self.dropout(h)
|
| 389 |
+
|
| 390 |
+
# FFN
|
| 391 |
+
h = self.norm3(x)
|
| 392 |
+
if t_emb is not None:
|
| 393 |
+
h = h + 0.5 * t_emb.unsqueeze(1)
|
| 394 |
+
h = self._cond(h, cond)
|
| 395 |
+
h = self.ffn(h)
|
| 396 |
+
x = x + self.dropout(h)
|
| 397 |
+
return x
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
# =============================================================================
|
| 401 |
+
# FiLM Modulation for Condition Injection
|
| 402 |
+
# =============================================================================
|
| 403 |
+
|
| 404 |
+
class FiLMModulation(nn.Module):
|
| 405 |
+
"""
|
| 406 |
+
Feature-wise Linear Modulation (FiLM) for adaptive condition injection.
|
| 407 |
+
|
| 408 |
+
Dynamically modulates intermediate features based on spatial context.
|
| 409 |
+
"""
|
| 410 |
+
|
| 411 |
+
def __init__(self, d_model: int, cond_dim: int):
|
| 412 |
+
super().__init__()
|
| 413 |
+
|
| 414 |
+
self.gamma_proj = nn.Linear(cond_dim, d_model)
|
| 415 |
+
self.beta_proj = nn.Linear(cond_dim, d_model)
|
| 416 |
+
|
| 417 |
+
# Start near identity modulation
|
| 418 |
+
nn.init.zeros_(self.gamma_proj.weight)
|
| 419 |
+
nn.init.zeros_(self.gamma_proj.bias)
|
| 420 |
+
nn.init.zeros_(self.beta_proj.weight)
|
| 421 |
+
nn.init.zeros_(self.beta_proj.bias)
|
| 422 |
+
|
| 423 |
+
def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
| 424 |
+
"""
|
| 425 |
+
Args:
|
| 426 |
+
x: [B, L, D] features
|
| 427 |
+
cond: [B, C] condition
|
| 428 |
+
Returns:
|
| 429 |
+
y: [B, L, D] modulated features
|
| 430 |
+
"""
|
| 431 |
+
gamma = self.gamma_proj(cond).unsqueeze(1) # [B, 1, D]
|
| 432 |
+
beta = self.beta_proj(cond).unsqueeze(1) # [B, 1, D]
|
| 433 |
+
return x * (1.0 + gamma) + beta
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
# =============================================================================
|
| 437 |
+
# Adaptive Group Normalization
|
| 438 |
+
# =============================================================================
|
| 439 |
+
|
| 440 |
+
class AdaptiveGroupNorm(nn.Module):
|
| 441 |
+
"""
|
| 442 |
+
Adaptive Group Normalization (AdaGN) for condition-aware normalization.
|
| 443 |
+
"""
|
| 444 |
+
|
| 445 |
+
def __init__(self, d_model: int, cond_dim: int, num_groups: int = 32):
|
| 446 |
+
super().__init__()
|
| 447 |
+
self.num_groups = _valid_num_groups(d_model, num_groups)
|
| 448 |
+
self.group_norm = nn.GroupNorm(self.num_groups, d_model, affine=False)
|
| 449 |
+
|
| 450 |
+
self.weight_proj = nn.Linear(cond_dim, d_model)
|
| 451 |
+
self.bias_proj = nn.Linear(cond_dim, d_model)
|
| 452 |
+
nn.init.zeros_(self.weight_proj.weight)
|
| 453 |
+
nn.init.zeros_(self.weight_proj.bias)
|
| 454 |
+
nn.init.zeros_(self.bias_proj.weight)
|
| 455 |
+
nn.init.zeros_(self.bias_proj.bias)
|
| 456 |
+
|
| 457 |
+
def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
| 458 |
+
"""
|
| 459 |
+
Args:
|
| 460 |
+
x: [B, L, D] features
|
| 461 |
+
cond: [B, C] condition
|
| 462 |
+
Returns:
|
| 463 |
+
y: [B, L, D] normalized features
|
| 464 |
+
"""
|
| 465 |
+
# Group norm
|
| 466 |
+
x_norm = self.group_norm(x.transpose(1, 2)).transpose(1, 2) # [B, L, D]
|
| 467 |
+
|
| 468 |
+
# Adaptive scaling
|
| 469 |
+
weight = self.weight_proj(cond).unsqueeze(1) # [B, 1, D]
|
| 470 |
+
bias = self.bias_proj(cond).unsqueeze(1) # [B, 1, D]
|
| 471 |
+
|
| 472 |
+
return x_norm * (1.0 + weight) + bias
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
class FourierTimeEmbedding(nn.Module):
|
| 476 |
+
"""Gaussian Fourier features for diffusion/FM time t in [0,1]."""
|
| 477 |
+
|
| 478 |
+
def __init__(self, d_model: int, n_freqs: int = 64):
|
| 479 |
+
super().__init__()
|
| 480 |
+
self.n_freqs = n_freqs
|
| 481 |
+
self.W = nn.Parameter(torch.randn(n_freqs) * 10.0)
|
| 482 |
+
self.proj = nn.Sequential(
|
| 483 |
+
nn.Linear(2 * n_freqs, d_model),
|
| 484 |
+
nn.GELU(),
|
| 485 |
+
nn.Linear(d_model, d_model),
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
| 489 |
+
# t: [B, 1]
|
| 490 |
+
t = t.clamp(0.0, 1.0)
|
| 491 |
+
w = self.W.view(1, 1, -1) # [1, 1, F]
|
| 492 |
+
angles = 2 * math.pi * t.unsqueeze(-1) * w # [B, 1, F]
|
| 493 |
+
emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1).squeeze(1)
|
| 494 |
+
return self.proj(emb) # [B, D]
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def sinusoidal_positional_embedding(
|
| 498 |
+
length: int,
|
| 499 |
+
dim: int,
|
| 500 |
+
device: torch.device,
|
| 501 |
+
dtype: torch.dtype,
|
| 502 |
+
) -> torch.Tensor:
|
| 503 |
+
"""Standard sinusoidal positional embeddings [L, D]."""
|
| 504 |
+
position = torch.arange(length, device=device, dtype=dtype).unsqueeze(1) # [L, 1]
|
| 505 |
+
div_term = torch.exp(
|
| 506 |
+
torch.arange(0, dim, 2, device=device, dtype=dtype) * (-math.log(10000.0) / dim)
|
| 507 |
+
) # [D/2]
|
| 508 |
+
pe = torch.zeros(length, dim, device=device, dtype=dtype)
|
| 509 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 510 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 511 |
+
return pe
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
# =============================================================================
|
| 515 |
+
# Level 1: Daily Pattern Flow Matching
|
| 516 |
+
# =============================================================================
|
| 517 |
+
|
| 518 |
+
class DailyPatternFM(nn.Module):
|
| 519 |
+
"""
|
| 520 |
+
Level 1: Daily Pattern Flow Matching.
|
| 521 |
+
|
| 522 |
+
Learns to generate day-type templates for hourly traffic:
|
| 523 |
+
- weekday template (24 hours)
|
| 524 |
+
- weekend template (24 hours)
|
| 525 |
+
|
| 526 |
+
Output is a concatenation of two 24-hour patterns: [weekday | weekend] -> 48 dims.
|
| 527 |
+
"""
|
| 528 |
+
|
| 529 |
+
def __init__(self, spatial_dim: int = 192, hidden_dim: int = 256, steps_per_day: int = 24):
|
| 530 |
+
super().__init__()
|
| 531 |
+
self.spatial_dim = spatial_dim
|
| 532 |
+
self.hidden_dim = hidden_dim
|
| 533 |
+
self.steps_per_day = steps_per_day
|
| 534 |
+
self.daytype_len = 2 * steps_per_day # weekday + weekend
|
| 535 |
+
|
| 536 |
+
self.time_embed = FourierTimeEmbedding(hidden_dim)
|
| 537 |
+
|
| 538 |
+
self.in_proj = nn.Linear(1, hidden_dim)
|
| 539 |
+
self.backbone = HybridLongSequenceBackbone(
|
| 540 |
+
d_model=hidden_dim,
|
| 541 |
+
n_layers=3,
|
| 542 |
+
d_state=64,
|
| 543 |
+
use_mamba=True,
|
| 544 |
+
use_dilated_conv=True,
|
| 545 |
+
dilations=[1, 2, 4, 8, 16],
|
| 546 |
+
cond_dim=spatial_dim,
|
| 547 |
+
dropout=0.1,
|
| 548 |
+
)
|
| 549 |
+
self.out_proj = nn.Linear(hidden_dim, 1)
|
| 550 |
+
|
| 551 |
+
def forward(
|
| 552 |
+
self,
|
| 553 |
+
x: torch.Tensor,
|
| 554 |
+
t: torch.Tensor,
|
| 555 |
+
spatial_cond: torch.Tensor,
|
| 556 |
+
) -> torch.Tensor:
|
| 557 |
+
"""
|
| 558 |
+
Args:
|
| 559 |
+
x: [B, 48] day-type templates = [weekday(24), weekend(24)]
|
| 560 |
+
t: [B, 1] time step
|
| 561 |
+
spatial_cond: [B, spatial_dim] spatial context
|
| 562 |
+
Returns:
|
| 563 |
+
v: [B, 48] velocity field
|
| 564 |
+
"""
|
| 565 |
+
B, L = x.shape
|
| 566 |
+
assert L == self.daytype_len, f"DailyPatternFM expects L={self.daytype_len}, got {L}"
|
| 567 |
+
|
| 568 |
+
t_emb = self.time_embed(t) # [B, hidden_dim]
|
| 569 |
+
pos = sinusoidal_positional_embedding(L, self.hidden_dim, x.device, x.dtype) # [L, D]
|
| 570 |
+
|
| 571 |
+
h = self.in_proj(x.unsqueeze(-1)) # [B, L, D]
|
| 572 |
+
h = h + pos.unsqueeze(0)
|
| 573 |
+
h = self.backbone(h, t_emb=t_emb, cond=spatial_cond)
|
| 574 |
+
v = self.out_proj(h).squeeze(-1) # [B, L]
|
| 575 |
+
return v
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
# =============================================================================
|
| 579 |
+
# Level 2: Weekly Pattern Flow Matching
|
| 580 |
+
# =============================================================================
|
| 581 |
+
|
| 582 |
+
class WeeklyPatternFM(nn.Module):
|
| 583 |
+
"""
|
| 584 |
+
Level 2: Weekly Pattern Flow Matching.
|
| 585 |
+
|
| 586 |
+
Learns to generate a weekly periodic pattern at hourly resolution:
|
| 587 |
+
weekly_pattern: 7 days × 24 hours = 168 time steps.
|
| 588 |
+
|
| 589 |
+
This level is conditioned on day-type templates from Level 1.
|
| 590 |
+
"""
|
| 591 |
+
|
| 592 |
+
def __init__(self, spatial_dim: int = 192, hidden_dim: int = 256, steps_per_day: int = 24):
|
| 593 |
+
super().__init__()
|
| 594 |
+
self.spatial_dim = spatial_dim
|
| 595 |
+
self.hidden_dim = hidden_dim
|
| 596 |
+
self.steps_per_day = steps_per_day
|
| 597 |
+
self.week_len = 7 * steps_per_day
|
| 598 |
+
self.daytype_len = 2 * steps_per_day
|
| 599 |
+
|
| 600 |
+
self.time_embed = FourierTimeEmbedding(hidden_dim)
|
| 601 |
+
|
| 602 |
+
self.in_proj = nn.Linear(1, hidden_dim)
|
| 603 |
+
self.daily_token_proj = nn.Linear(1, hidden_dim)
|
| 604 |
+
|
| 605 |
+
self.daily_to_weekly_attn = nn.MultiheadAttention(
|
| 606 |
+
embed_dim=hidden_dim,
|
| 607 |
+
num_heads=8,
|
| 608 |
+
dropout=0.1,
|
| 609 |
+
batch_first=True,
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
self.backbone = HybridLongSequenceBackbone(
|
| 613 |
+
d_model=hidden_dim,
|
| 614 |
+
n_layers=3,
|
| 615 |
+
d_state=64,
|
| 616 |
+
use_mamba=True,
|
| 617 |
+
use_dilated_conv=True,
|
| 618 |
+
dilations=[1, 2, 4],
|
| 619 |
+
cond_dim=spatial_dim,
|
| 620 |
+
dropout=0.1,
|
| 621 |
+
)
|
| 622 |
+
self.out_proj = nn.Linear(hidden_dim, 1)
|
| 623 |
+
|
| 624 |
+
def forward(
|
| 625 |
+
self,
|
| 626 |
+
x: torch.Tensor,
|
| 627 |
+
t: torch.Tensor,
|
| 628 |
+
daily_pattern: torch.Tensor,
|
| 629 |
+
spatial_cond: torch.Tensor,
|
| 630 |
+
) -> torch.Tensor:
|
| 631 |
+
"""
|
| 632 |
+
Args:
|
| 633 |
+
x: [B, 168] weekly pattern
|
| 634 |
+
t: [B, 1] time step
|
| 635 |
+
daily_pattern: [B, 48] day-type templates (from Level 1)
|
| 636 |
+
spatial_cond: [B, spatial_dim] spatial context
|
| 637 |
+
Returns:
|
| 638 |
+
v: [B, 168] velocity field
|
| 639 |
+
"""
|
| 640 |
+
B, Lw = x.shape
|
| 641 |
+
assert Lw == self.week_len, f"WeeklyPatternFM expects L={self.week_len}, got {Lw}"
|
| 642 |
+
Bd, Ld = daily_pattern.shape
|
| 643 |
+
assert Bd == B and Ld == self.daytype_len, (
|
| 644 |
+
f"WeeklyPatternFM expects daily_pattern [B,{self.daytype_len}], got {daily_pattern.shape}"
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
t_emb = self.time_embed(t) # [B, D]
|
| 648 |
+
|
| 649 |
+
pos_w = sinusoidal_positional_embedding(Lw, self.hidden_dim, x.device, x.dtype)
|
| 650 |
+
pos_d = sinusoidal_positional_embedding(Ld, self.hidden_dim, x.device, x.dtype)
|
| 651 |
+
|
| 652 |
+
week_tokens = self.in_proj(x.unsqueeze(-1)) + pos_w.unsqueeze(0) # [B, 168, D]
|
| 653 |
+
day_tokens = self.daily_token_proj(daily_pattern.unsqueeze(-1)) + pos_d.unsqueeze(0) # [B, 48, D]
|
| 654 |
+
|
| 655 |
+
# Explicitly condition on S^d via cross-attention (decouples day/week)
|
| 656 |
+
attn_out, _ = self.daily_to_weekly_attn(week_tokens, day_tokens, day_tokens)
|
| 657 |
+
week_tokens = week_tokens + attn_out
|
| 658 |
+
|
| 659 |
+
h = self.backbone(week_tokens, t_emb=t_emb, cond=spatial_cond)
|
| 660 |
+
v = self.out_proj(h).squeeze(-1) # [B, 168]
|
| 661 |
+
return v
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
# =============================================================================
|
| 665 |
+
# Level 3: Long-term Residual Flow Matching
|
| 666 |
+
# =============================================================================
|
| 667 |
+
|
| 668 |
+
class LongTermResidualFM(nn.Module):
|
| 669 |
+
"""
|
| 670 |
+
Level 3: Long-term Residual Flow Matching.
|
| 671 |
+
|
| 672 |
+
Learns to generate fine residuals for the full sequence (672 time steps).
|
| 673 |
+
Uses Mamba + multi-scale dilated convolutions for efficient long-sequence modeling.
|
| 674 |
+
|
| 675 |
+
Explicitly conditioned on peak hour location to force peak generation.
|
| 676 |
+
"""
|
| 677 |
+
|
| 678 |
+
def __init__(
|
| 679 |
+
self,
|
| 680 |
+
spatial_dim: int = 192,
|
| 681 |
+
hidden_dim: int = 256,
|
| 682 |
+
n_layers: int = 6,
|
| 683 |
+
steps_per_day: int = 24,
|
| 684 |
+
):
|
| 685 |
+
super().__init__()
|
| 686 |
+
self.spatial_dim = spatial_dim
|
| 687 |
+
self.hidden_dim = hidden_dim
|
| 688 |
+
self.steps_per_day = steps_per_day
|
| 689 |
+
self.week_len = 7 * steps_per_day
|
| 690 |
+
self.daytype_len = 2 * steps_per_day
|
| 691 |
+
|
| 692 |
+
self.time_embed = FourierTimeEmbedding(hidden_dim)
|
| 693 |
+
|
| 694 |
+
# Explicit Peak Position Encoding
|
| 695 |
+
# Maps 0-23 hours to a hidden vector to serve as a strong condition
|
| 696 |
+
self.peak_embed = nn.Embedding(24, hidden_dim)
|
| 697 |
+
|
| 698 |
+
# Token-wise projection of multi-channel inputs
|
| 699 |
+
self.in_proj = nn.Linear(4, hidden_dim)
|
| 700 |
+
|
| 701 |
+
# Main backbone (Mamba + multi-scale dilated conv) for long sequences
|
| 702 |
+
self.backbone = HybridLongSequenceBackbone(
|
| 703 |
+
d_model=hidden_dim,
|
| 704 |
+
n_layers=n_layers,
|
| 705 |
+
d_state=128,
|
| 706 |
+
use_mamba=True,
|
| 707 |
+
use_dilated_conv=True,
|
| 708 |
+
# (local, daily, weekly) receptive fields at hourly resolution
|
| 709 |
+
dilations=[1, 2, 4, 8, 16, 24, 48, 168],
|
| 710 |
+
cond_dim=spatial_dim,
|
| 711 |
+
dropout=0.1,
|
| 712 |
+
)
|
| 713 |
+
self.out_proj = nn.Linear(hidden_dim, 1)
|
| 714 |
+
|
| 715 |
+
def _repeat_to_length(self, pattern: torch.Tensor, target_len: int) -> torch.Tensor:
|
| 716 |
+
# pattern: [B, P]
|
| 717 |
+
B, P = pattern.shape
|
| 718 |
+
reps = (target_len + P - 1) // P
|
| 719 |
+
tiled = pattern.repeat(1, reps)
|
| 720 |
+
return tiled[:, :target_len]
|
| 721 |
+
|
| 722 |
+
def _repeat_daytype_to_length(self, daytype: torch.Tensor, target_len: int) -> torch.Tensor:
|
| 723 |
+
"""
|
| 724 |
+
Expand day-type templates (weekday/weekend) to a full 28-day hourly sequence.
|
| 725 |
+
|
| 726 |
+
Assumption (consistent with plot_traffic_decomposition*.py): sequence starts on Monday.
|
| 727 |
+
"""
|
| 728 |
+
B, L = daytype.shape
|
| 729 |
+
assert L == self.daytype_len, f"Expected daytype_len={self.daytype_len}, got {L}"
|
| 730 |
+
steps = self.steps_per_day
|
| 731 |
+
weekday = daytype[:, :steps]
|
| 732 |
+
weekend = daytype[:, steps:]
|
| 733 |
+
|
| 734 |
+
n_days = target_len // steps
|
| 735 |
+
parts = []
|
| 736 |
+
for d in range(n_days):
|
| 737 |
+
dow = d % 7
|
| 738 |
+
parts.append(weekday if dow < 5 else weekend)
|
| 739 |
+
seq = torch.cat(parts, dim=1) # [B, n_days*steps]
|
| 740 |
+
if seq.shape[1] < target_len:
|
| 741 |
+
pad = torch.zeros(B, target_len - seq.shape[1], device=seq.device, dtype=seq.dtype)
|
| 742 |
+
seq = torch.cat([seq, pad], dim=1)
|
| 743 |
+
return seq[:, :target_len]
|
| 744 |
+
|
| 745 |
+
def forward(
|
| 746 |
+
self,
|
| 747 |
+
x: torch.Tensor,
|
| 748 |
+
t: torch.Tensor,
|
| 749 |
+
coarse_signal: torch.Tensor,
|
| 750 |
+
daily_pattern: torch.Tensor,
|
| 751 |
+
weekly_trend: torch.Tensor,
|
| 752 |
+
spatial_cond: torch.Tensor,
|
| 753 |
+
peak_hour: torch.Tensor,
|
| 754 |
+
) -> torch.Tensor:
|
| 755 |
+
"""
|
| 756 |
+
Args:
|
| 757 |
+
x: [B, 672] residual sequence
|
| 758 |
+
t: [B, 1] time step
|
| 759 |
+
coarse_signal: [B, 672] periodic component (tiled weekly pattern)
|
| 760 |
+
daily_pattern: [B, 48] day-type templates
|
| 761 |
+
weekly_trend: [B, 168] weekly pattern
|
| 762 |
+
spatial_cond: [B, spatial_dim] spatial context
|
| 763 |
+
peak_hour: [B] Integer tensor (0-23) indicating explicit peak location
|
| 764 |
+
Returns:
|
| 765 |
+
v: [B, 672] velocity field
|
| 766 |
+
"""
|
| 767 |
+
B, L = x.shape
|
| 768 |
+
assert coarse_signal.shape == (B, L)
|
| 769 |
+
assert daily_pattern.shape == (B, self.daytype_len)
|
| 770 |
+
assert weekly_trend.shape == (B, self.week_len)
|
| 771 |
+
|
| 772 |
+
# Fuse Time Embedding with Peak Embedding
|
| 773 |
+
t_emb = self.time_embed(t) # [B, D]
|
| 774 |
+
peak_cond = self.peak_embed(peak_hour) # [B, D]
|
| 775 |
+
|
| 776 |
+
# Combine: Global time context + "Peak Attention" bias
|
| 777 |
+
global_cond = t_emb + peak_cond
|
| 778 |
+
|
| 779 |
+
pos = sinusoidal_positional_embedding(L, self.hidden_dim, x.device, x.dtype)
|
| 780 |
+
|
| 781 |
+
daily_rep = self._repeat_daytype_to_length(daily_pattern, L) # [B, L]
|
| 782 |
+
# weekly_trend is weekly pattern here: tile 168 -> 672 (4 weeks)
|
| 783 |
+
weekly_rep = self._repeat_to_length(weekly_trend, L) # [B, L]
|
| 784 |
+
weekly_delta = coarse_signal - daily_rep # [B, L]
|
| 785 |
+
|
| 786 |
+
# Token features: [residual, periodic, repeated_daytype, weekly_delta]
|
| 787 |
+
feats = torch.stack([x, coarse_signal, daily_rep, weekly_delta], dim=-1) # [B, L, 4]
|
| 788 |
+
h = self.in_proj(feats) + pos.unsqueeze(0)
|
| 789 |
+
|
| 790 |
+
# Pass combined global condition
|
| 791 |
+
h = self.backbone(h, t_emb=global_cond, cond=spatial_cond)
|
| 792 |
+
|
| 793 |
+
v = self.out_proj(h).squeeze(-1) # [B, L]
|
| 794 |
+
return v
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
# =============================================================================
|
| 798 |
+
# Complete Hierarchical Flow Matching Model
|
| 799 |
+
# =============================================================================
|
| 800 |
+
|
| 801 |
+
class HierarchicalFlowMatchingV4(nn.Module):
|
| 802 |
+
"""
|
| 803 |
+
Complete Hierarchical Flow Matching model with three-level cascaded architecture.
|
| 804 |
+
|
| 805 |
+
Level 1: Daily Pattern FM
|
| 806 |
+
Level 2: Weekly Pattern FM (with daily conditioning)
|
| 807 |
+
Level 3: Long-term Residual FM (with daily + weekly conditioning + explicit peak)
|
| 808 |
+
"""
|
| 809 |
+
|
| 810 |
+
def __init__(
|
| 811 |
+
self,
|
| 812 |
+
spatial_dim: int = 192,
|
| 813 |
+
hidden_dim: int = 256,
|
| 814 |
+
n_layers_level3: int = 6,
|
| 815 |
+
steps_per_day: int = 24,
|
| 816 |
+
):
|
| 817 |
+
super().__init__()
|
| 818 |
+
self.spatial_dim = spatial_dim
|
| 819 |
+
self.hidden_dim = hidden_dim
|
| 820 |
+
self.steps_per_day = steps_per_day
|
| 821 |
+
self.week_len = 7 * steps_per_day
|
| 822 |
+
self.daytype_len = 2 * steps_per_day
|
| 823 |
+
# This repo's 672-length traffic is hourly: 28 days = 4 weeks.
|
| 824 |
+
self.seq_len = 672
|
| 825 |
+
self.n_weeks = self.seq_len // self.week_len
|
| 826 |
+
|
| 827 |
+
# Three-level FM
|
| 828 |
+
self.level1_fm = DailyPatternFM(spatial_dim, hidden_dim, steps_per_day=steps_per_day)
|
| 829 |
+
self.level2_fm = WeeklyPatternFM(spatial_dim, hidden_dim, steps_per_day=steps_per_day)
|
| 830 |
+
self.level3_fm = LongTermResidualFM(
|
| 831 |
+
spatial_dim, hidden_dim, n_layers_level3, steps_per_day=steps_per_day
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
def forward(
|
| 835 |
+
self,
|
| 836 |
+
x: torch.Tensor,
|
| 837 |
+
t: torch.Tensor,
|
| 838 |
+
spatial_cond: torch.Tensor,
|
| 839 |
+
level: int = 1,
|
| 840 |
+
daily_pattern: Optional[torch.Tensor] = None,
|
| 841 |
+
weekly_trend: Optional[torch.Tensor] = None,
|
| 842 |
+
coarse_signal: Optional[torch.Tensor] = None,
|
| 843 |
+
peak_hour: Optional[torch.Tensor] = None,
|
| 844 |
+
) -> torch.Tensor:
|
| 845 |
+
"""
|
| 846 |
+
Forward pass for a specific level.
|
| 847 |
+
"""
|
| 848 |
+
if level == 1:
|
| 849 |
+
return self.level1_fm(x, t, spatial_cond)
|
| 850 |
+
|
| 851 |
+
elif level == 2:
|
| 852 |
+
assert daily_pattern is not None, "daily_pattern required for level 2"
|
| 853 |
+
return self.level2_fm(x, t, daily_pattern, spatial_cond)
|
| 854 |
+
|
| 855 |
+
elif level == 3:
|
| 856 |
+
assert daily_pattern is not None, "daily_pattern required for level 3"
|
| 857 |
+
assert weekly_trend is not None, "weekly_trend required for level 3"
|
| 858 |
+
assert coarse_signal is not None, "coarse_signal required for level 3"
|
| 859 |
+
assert peak_hour is not None, "peak_hour required for level 3 (Explicit Peak Conditioning)"
|
| 860 |
+
return self.level3_fm(x, t, coarse_signal, daily_pattern, weekly_trend, spatial_cond, peak_hour)
|
| 861 |
+
|
| 862 |
+
else:
|
| 863 |
+
raise ValueError(f"Invalid level: {level}")
|
| 864 |
+
|
| 865 |
+
# =========================================================================
|
| 866 |
+
# Generation Methods (ODE Solve)
|
| 867 |
+
# =========================================================================
|
| 868 |
+
|
| 869 |
+
def _unpack_level_conditions(
|
| 870 |
+
self,
|
| 871 |
+
spatial_cond: torch.Tensor | Dict[str, torch.Tensor],
|
| 872 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 873 |
+
if isinstance(spatial_cond, dict):
|
| 874 |
+
return (
|
| 875 |
+
spatial_cond["level1_cond"],
|
| 876 |
+
spatial_cond["level2_cond"],
|
| 877 |
+
spatial_cond["level3_cond"],
|
| 878 |
+
)
|
| 879 |
+
return spatial_cond, spatial_cond, spatial_cond
|
| 880 |
+
|
| 881 |
+
def generate_daily_pattern(
|
| 882 |
+
self,
|
| 883 |
+
spatial_cond: torch.Tensor | Dict[str, torch.Tensor],
|
| 884 |
+
n_steps: int = 50,
|
| 885 |
+
) -> torch.Tensor:
|
| 886 |
+
"""
|
| 887 |
+
Generate day-type templates (Level 1).
|
| 888 |
+
"""
|
| 889 |
+
spatial_cond_level1, _, _ = self._unpack_level_conditions(spatial_cond)
|
| 890 |
+
B = spatial_cond_level1.shape[0]
|
| 891 |
+
device = spatial_cond_level1.device
|
| 892 |
+
|
| 893 |
+
x = torch.randn(B, self.daytype_len, device=device)
|
| 894 |
+
dt = 1.0 / n_steps
|
| 895 |
+
|
| 896 |
+
for step in range(n_steps):
|
| 897 |
+
t = torch.full((B, 1), step / n_steps, device=device)
|
| 898 |
+
v = self.level1_fm(x, t, spatial_cond_level1)
|
| 899 |
+
v = torch.clamp(v, -10.0, 10.0)
|
| 900 |
+
x = x + dt * v
|
| 901 |
+
x = torch.clamp(x, -10.0, 10.0)
|
| 902 |
+
|
| 903 |
+
return x
|
| 904 |
+
|
| 905 |
+
def generate_weekly_trend(
|
| 906 |
+
self,
|
| 907 |
+
daily_pattern: torch.Tensor,
|
| 908 |
+
spatial_cond: torch.Tensor | Dict[str, torch.Tensor],
|
| 909 |
+
n_steps: int = 50,
|
| 910 |
+
) -> torch.Tensor:
|
| 911 |
+
"""
|
| 912 |
+
Generate weekly pattern (Level 2).
|
| 913 |
+
"""
|
| 914 |
+
_, spatial_cond_level2, _ = self._unpack_level_conditions(spatial_cond)
|
| 915 |
+
B = spatial_cond_level2.shape[0]
|
| 916 |
+
device = spatial_cond_level2.device
|
| 917 |
+
|
| 918 |
+
x = torch.randn(B, self.week_len, device=device)
|
| 919 |
+
dt = 1.0 / n_steps
|
| 920 |
+
|
| 921 |
+
for step in range(n_steps):
|
| 922 |
+
t = torch.full((B, 1), step / n_steps, device=device)
|
| 923 |
+
v = self.level2_fm(x, t, daily_pattern, spatial_cond_level2)
|
| 924 |
+
v = torch.clamp(v, -10.0, 10.0)
|
| 925 |
+
x = x + dt * v
|
| 926 |
+
x = torch.clamp(x, -10.0, 10.0)
|
| 927 |
+
|
| 928 |
+
return x
|
| 929 |
+
|
| 930 |
+
def generate_residual(
|
| 931 |
+
self,
|
| 932 |
+
coarse_signal: torch.Tensor,
|
| 933 |
+
daily_pattern: torch.Tensor,
|
| 934 |
+
weekly_trend: torch.Tensor,
|
| 935 |
+
spatial_cond: torch.Tensor | Dict[str, torch.Tensor],
|
| 936 |
+
peak_hour: torch.Tensor,
|
| 937 |
+
n_steps: int = 50,
|
| 938 |
+
) -> torch.Tensor:
|
| 939 |
+
"""
|
| 940 |
+
Generate fine residual (Level 3).
|
| 941 |
+
Requires peak_hour for explicit conditioning.
|
| 942 |
+
"""
|
| 943 |
+
_, _, spatial_cond_level3 = self._unpack_level_conditions(spatial_cond)
|
| 944 |
+
B = spatial_cond_level3.shape[0]
|
| 945 |
+
device = spatial_cond_level3.device
|
| 946 |
+
|
| 947 |
+
x = 0.1 * torch.randn_like(coarse_signal, device=device)
|
| 948 |
+
dt = 1.0 / n_steps
|
| 949 |
+
|
| 950 |
+
for step in range(n_steps):
|
| 951 |
+
t = torch.full((B, 1), step / n_steps, device=device)
|
| 952 |
+
# Pass peak_hour
|
| 953 |
+
v = self.level3_fm(
|
| 954 |
+
x, t, coarse_signal, daily_pattern, weekly_trend, spatial_cond_level3, peak_hour
|
| 955 |
+
)
|
| 956 |
+
v = torch.clamp(v, -5.0, 5.0)
|
| 957 |
+
x = x + dt * v
|
| 958 |
+
x = torch.clamp(x, -5.0, 5.0)
|
| 959 |
+
|
| 960 |
+
return x
|
| 961 |
+
|
| 962 |
+
def generate_hierarchical(
|
| 963 |
+
self,
|
| 964 |
+
spatial_cond: torch.Tensor | Dict[str, torch.Tensor],
|
| 965 |
+
peak_hour: torch.Tensor, # Required input
|
| 966 |
+
n_steps_per_level: int = 50,
|
| 967 |
+
) -> Tuple[torch.Tensor, Dict]:
|
| 968 |
+
"""
|
| 969 |
+
Full hierarchical generation.
|
| 970 |
+
"""
|
| 971 |
+
spatial_cond_level1, spatial_cond_level2, spatial_cond_level3 = self._unpack_level_conditions(
|
| 972 |
+
spatial_cond
|
| 973 |
+
)
|
| 974 |
+
B = spatial_cond_level3.shape[0]
|
| 975 |
+
device = spatial_cond_level3.device
|
| 976 |
+
|
| 977 |
+
# Level 1: Generate day-type templates
|
| 978 |
+
daily_pattern = self.generate_daily_pattern(spatial_cond_level1, n_steps_per_level)
|
| 979 |
+
daily_pattern = torch.clamp(daily_pattern, -10.0, 10.0)
|
| 980 |
+
|
| 981 |
+
# Level 2: Generate weekly pattern (168 hours)
|
| 982 |
+
weekly_pattern = self.generate_weekly_trend(
|
| 983 |
+
daily_pattern, spatial_cond_level2, n_steps_per_level
|
| 984 |
+
)
|
| 985 |
+
weekly_pattern = torch.clamp(weekly_pattern, -10.0, 10.0)
|
| 986 |
+
|
| 987 |
+
# Construct periodic component for 4 weeks (672 hours)
|
| 988 |
+
coarse_signal = weekly_pattern.repeat(1, self.n_weeks) # [B, 672]
|
| 989 |
+
coarse_signal = torch.clamp(coarse_signal, -10.0, 10.0)
|
| 990 |
+
|
| 991 |
+
# Level 3: Generate fine residual
|
| 992 |
+
# Pass peak_hour to residual generator
|
| 993 |
+
residual = self.generate_residual(
|
| 994 |
+
coarse_signal,
|
| 995 |
+
daily_pattern,
|
| 996 |
+
weekly_pattern,
|
| 997 |
+
spatial_cond_level3,
|
| 998 |
+
peak_hour=peak_hour,
|
| 999 |
+
n_steps=n_steps_per_level,
|
| 1000 |
+
)
|
| 1001 |
+
residual = torch.clamp(residual, -5.0, 5.0)
|
| 1002 |
+
|
| 1003 |
+
# Final output
|
| 1004 |
+
generated = coarse_signal + residual
|
| 1005 |
+
|
| 1006 |
+
# =========================================================================
|
| 1007 |
+
# [MODIFIED] Physical Constraint: Enforce non-negative traffic
|
| 1008 |
+
# Previously was: generated = torch.clamp(generated, -10.0, 10.0)
|
| 1009 |
+
# =========================================================================
|
| 1010 |
+
generated = torch.clamp(generated, min=0.0, max=10.0)
|
| 1011 |
+
|
| 1012 |
+
intermediates = {
|
| 1013 |
+
'daily_pattern': daily_pattern,
|
| 1014 |
+
'weekly_pattern': weekly_pattern,
|
| 1015 |
+
'coarse_signal': coarse_signal,
|
| 1016 |
+
'residual': residual,
|
| 1017 |
+
}
|
| 1018 |
+
|
| 1019 |
+
return generated, intermediates
|
index.html
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>Perception Layer - Data Alignment</title>
|
| 7 |
+
|
| 8 |
+
<script src="https://api.mapbox.com/mapbox-gl-js/v2.15.0/mapbox-gl.js"></script>
|
| 9 |
+
<link href="https://api.mapbox.com/mapbox-gl-js/v2.15.0/mapbox-gl.css" rel="stylesheet" />
|
| 10 |
+
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
|
| 11 |
+
|
| 12 |
+
<link rel="stylesheet" href="style.css">
|
| 13 |
+
</head>
|
| 14 |
+
<body>
|
| 15 |
+
|
| 16 |
+
<div id="loading" class="loading-overlay">
|
| 17 |
+
<div class="spinner"></div>
|
| 18 |
+
<h2>SYSTEM INITIALIZING</h2>
|
| 19 |
+
<p>Loading Spatial & Temporal Data...</p>
|
| 20 |
+
</div>
|
| 21 |
+
|
| 22 |
+
<div class="sidebar">
|
| 23 |
+
<div class="header">
|
| 24 |
+
<h1>Overall</h1>
|
| 25 |
+
|
| 26 |
+
<div class="search-section">
|
| 27 |
+
<div class="search-container">
|
| 28 |
+
<input type="text" id="search-input" placeholder="Search ID..." autocomplete="off">
|
| 29 |
+
<button id="search-btn" class="cyber-btn-small">GO</button>
|
| 30 |
+
<button id="clear-search-btn" class="cyber-btn-small" title="Clear Markers">✕</button>
|
| 31 |
+
</div>
|
| 32 |
+
|
| 33 |
+
<div class="search-mode">
|
| 34 |
+
<input type="checkbox" id="keep-markers-check" checked>
|
| 35 |
+
<label for="keep-markers-check">Keep Previous Markers</label>
|
| 36 |
+
</div>
|
| 37 |
+
</div>
|
| 38 |
+
</div>
|
| 39 |
+
|
| 40 |
+
<div class="card">
|
| 41 |
+
<h2>📈 Temporal Modality</h2>
|
| 42 |
+
<div class="chart-container">
|
| 43 |
+
<canvas id="energyChart"></canvas>
|
| 44 |
+
</div>
|
| 45 |
+
</div>
|
| 46 |
+
|
| 47 |
+
<div class="card details-card">
|
| 48 |
+
<h2>📍 Spatial Metadata</h2>
|
| 49 |
+
<div class="stat-row">
|
| 50 |
+
<div><span class="label">Station ID</span> <span id="selected-id" class="value highlight">--</span></div>
|
| 51 |
+
<div><span class="label">Total Nodes</span> <span id="total-stations" class="value">--</span></div>
|
| 52 |
+
</div>
|
| 53 |
+
<div id="station-details" class="details-content">
|
| 54 |
+
<p class="placeholder-text">Waiting for selection...</p>
|
| 55 |
+
</div>
|
| 56 |
+
</div>
|
| 57 |
+
</div>
|
| 58 |
+
|
| 59 |
+
<div id="prediction-panel" class="sidebar-right">
|
| 60 |
+
<div class="header">
|
| 61 |
+
<h1>Traffic Prediction</h1>
|
| 62 |
+
<button id="close-pred-btn" class="cyber-btn-small">✕</button>
|
| 63 |
+
</div>
|
| 64 |
+
<div class="details-content">
|
| 65 |
+
<div class="stat-row" style="margin-bottom: 20px;">
|
| 66 |
+
<div><span class="label">Target Station ID</span> <span id="pred-station-id" class="value highlight" style="color: #f39c12;">--</span></div>
|
| 67 |
+
</div>
|
| 68 |
+
|
| 69 |
+
<div class="chart-container" style="height: 250px; position: relative;">
|
| 70 |
+
<canvas id="predictionChart"></canvas>
|
| 71 |
+
</div>
|
| 72 |
+
|
| 73 |
+
<!-- <div class="legend-box" style="margin-top: 20px; font-size: 0.9em; padding: 10px; background: rgba(0,0,0,0.3); border-radius: 4px;">
|
| 74 |
+
<div style="display:flex; align-items:center; margin-bottom:8px;">
|
| 75 |
+
<span style="display:inline-block; width:12px; height:12px; background:#00cec9; margin-right:10px; border-radius:50%;"></span>
|
| 76 |
+
<span>Real Data (Observed)</span>
|
| 77 |
+
</div>
|
| 78 |
+
<div style="display:flex; align-items:center;">
|
| 79 |
+
<span style="display:inline-block; width:12px; height:12px; background:#f39c12; margin-right:10px; border-radius:50%;"></span>
|
| 80 |
+
<span>AI Prediction (Model + POI)</span>
|
| 81 |
+
</div>
|
| 82 |
+
</div> -->
|
| 83 |
+
|
| 84 |
+
<div id="site-map-container" style="margin-top: 20px; display: none; border-top: 1px solid rgba(243, 156, 18, 0.3); padding-top: 15px;">
|
| 85 |
+
<h3 style="font-size: 13px; color: #f39c12; margin-bottom: 10px; text-transform: uppercase; letter-spacing: 1px;">
|
| 86 |
+
<span class="icon">📍</span> Optimal Site Analysis
|
| 87 |
+
</h3>
|
| 88 |
+
|
| 89 |
+
<div style="background: rgba(0,0,0,0.5); padding: 10px; border-radius: 6px; border: 1px solid rgba(255,255,255,0.05); display: flex; justify-content: center; align-items: center;">
|
| 90 |
+
<img id="site-map-img" src="" alt="LSI Site Map" style="width: 75%; border-radius: 4px; box-shadow: 0 0 10px rgba(0,0,0,0.5); display: block;">
|
| 91 |
+
</div>
|
| 92 |
+
|
| 93 |
+
<div id="site-explanation" class="cyber-explanation" style="display: none;"></div>
|
| 94 |
+
|
| 95 |
+
<p style="font-size: 0.7em; color: #aaa; margin-top: 8px; line-height: 1.4;">
|
| 96 |
+
* Heatmap calculated via spatial windowing.<br>
|
| 97 |
+
<span style="color:#2ecc71;">Green = High LSI (Stable)</span> | <span style="color:#e74c3c;">Red = High Volatility</span>
|
| 98 |
+
</p>
|
| 99 |
+
</div>
|
| 100 |
+
|
| 101 |
+
<p style="font-size: 0.75em; color: #666; margin-top: 20px; line-height: 1.4; border-top: 1px solid #333; padding-top: 10px;">
|
| 102 |
+
* Powered by <strong>Hierarchical Flow Matching V4</strong>.<br>
|
| 103 |
+
Utilizes Multi-modal Spatial Embeddings (POI, Satellite, Coordinates) for context-aware traffic forecasting.
|
| 104 |
+
</p>
|
| 105 |
+
</div>
|
| 106 |
+
</div>
|
| 107 |
+
|
| 108 |
+
<button id="toggle-left-btn" class="panel-toggle-btn left-toggle">◀</button>
|
| 109 |
+
<button id="toggle-right-btn" class="panel-toggle-btn right-toggle">▶</button>
|
| 110 |
+
<div class="main-content">
|
| 111 |
+
<div class="controls-container">
|
| 112 |
+
<button id="view-toggle" class="cyber-btn">
|
| 113 |
+
<span class="icon">👁️</span> View: 3D
|
| 114 |
+
</button>
|
| 115 |
+
<button id="data-toggle" class="cyber-btn">
|
| 116 |
+
<span class="icon">📡</span> Toggle Data
|
| 117 |
+
</button>
|
| 118 |
+
|
| 119 |
+
<button id="predict-toggle" class="cyber-btn" style="border-color: #f39c12; color: #f39c12;">
|
| 120 |
+
<span class="icon">🔮</span> Prediction Mode
|
| 121 |
+
</button>
|
| 122 |
+
|
| 123 |
+
<div class="filter-wrapper">
|
| 124 |
+
<button id="filter-btn" class="cyber-btn">
|
| 125 |
+
<span class="icon">🌪️</span> Filter Volatility
|
| 126 |
+
</button>
|
| 127 |
+
<div id="filter-menu" class="filter-menu"></div>
|
| 128 |
+
</div>
|
| 129 |
+
</div>
|
| 130 |
+
|
| 131 |
+
<div class="time-panel">
|
| 132 |
+
<button id="play-btn" class="cyber-btn play-control">▶</button>
|
| 133 |
+
<div class="slider-wrapper">
|
| 134 |
+
<input type="range" id="time-slider" min="0" max="671" value="0" step="1">
|
| 135 |
+
<div class="slider-ticks">
|
| 136 |
+
<span>Day 1</span><span>Day 7</span><span>Day 14</span><span>Day 21</span><span>Day 28</span>
|
| 137 |
+
</div>
|
| 138 |
+
</div>
|
| 139 |
+
<div id="time-display" class="digital-clock" style="min-width: 170px;">Day 01 - 00:00</div>
|
| 140 |
+
</div>
|
| 141 |
+
|
| 142 |
+
<div id="map"></div>
|
| 143 |
+
</div>
|
| 144 |
+
|
| 145 |
+
<script src="script.js"></script>
|
| 146 |
+
</body>
|
| 147 |
+
</html>
|
multimodal_spatial_encoder_v4.py
ADDED
|
@@ -0,0 +1,645 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-Modal Spatial Context Encoder (V4)
|
| 3 |
+
=========================================
|
| 4 |
+
|
| 5 |
+
Fuses POI features and satellite imagery into a unified spatial context embedding.
|
| 6 |
+
|
| 7 |
+
Key components:
|
| 8 |
+
1. POI Encoder: MLP with learnable category importance weights
|
| 9 |
+
2. Satellite Image Encoder: ResNet-18 with multi-scale features
|
| 10 |
+
3. Coordinate Encoder: Fourier features with learnable frequencies
|
| 11 |
+
4. Fusion Strategy: Cross-attention + adaptive gating
|
| 12 |
+
5. Condition Injection: FiLM/AdaGN modulation
|
| 13 |
+
6. [NEW] Auxiliary Head: Peak Hour Classification (Explicit Peak Prediction)
|
| 14 |
+
|
| 15 |
+
Author: Optimization Team
|
| 16 |
+
Date: 2026-01-21
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
import numpy as np
|
| 23 |
+
from typing import Dict, Optional
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# =============================================================================
|
| 27 |
+
# POI Encoder with Learnable Importance Weights
|
| 28 |
+
# =============================================================================
|
| 29 |
+
|
| 30 |
+
class POIEncoder(nn.Module):
|
| 31 |
+
"""
|
| 32 |
+
POI encoder with learnable category importance weights.
|
| 33 |
+
|
| 34 |
+
Input: POI count/density vector [B, poi_dim]
|
| 35 |
+
Output: POI embedding [B, spatial_dim]
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, poi_dim: int = 20, spatial_dim: int = 192):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.poi_dim = poi_dim
|
| 41 |
+
self.spatial_dim = spatial_dim
|
| 42 |
+
|
| 43 |
+
# Learnable category importance weights
|
| 44 |
+
self.category_importance = nn.Parameter(torch.ones(poi_dim))
|
| 45 |
+
|
| 46 |
+
# Category token embeddings (POI-Enhancer inspired: attention-weighted semantic fusion)
|
| 47 |
+
self.category_embed = nn.Embedding(poi_dim, spatial_dim)
|
| 48 |
+
|
| 49 |
+
# Deep encoder with residual connections
|
| 50 |
+
self.encoder = nn.Sequential(
|
| 51 |
+
nn.Linear(poi_dim, 256),
|
| 52 |
+
nn.GELU(),
|
| 53 |
+
nn.Dropout(0.1),
|
| 54 |
+
nn.Linear(256, 256),
|
| 55 |
+
nn.GELU(),
|
| 56 |
+
nn.LayerNorm(256),
|
| 57 |
+
nn.Linear(256, spatial_dim),
|
| 58 |
+
nn.LayerNorm(spatial_dim),
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Attention pooling over category tokens
|
| 62 |
+
self.token_attn = nn.Sequential(
|
| 63 |
+
nn.Linear(spatial_dim, 128),
|
| 64 |
+
nn.GELU(),
|
| 65 |
+
nn.Linear(128, 1),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# Gate between MLP vector and token-pooled vector
|
| 69 |
+
self.fuse_gate = nn.Sequential(
|
| 70 |
+
nn.Linear(spatial_dim * 2, 1),
|
| 71 |
+
nn.Sigmoid(),
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
def forward(self, poi_dist: torch.Tensor, return_tokens: bool = False):
|
| 75 |
+
"""
|
| 76 |
+
Args:
|
| 77 |
+
poi_dist: [B, poi_dim] POI distribution
|
| 78 |
+
Returns:
|
| 79 |
+
features: [B, spatial_dim] POI embedding
|
| 80 |
+
"""
|
| 81 |
+
# Apply learnable importance weights
|
| 82 |
+
weights = F.softmax(self.category_importance, dim=0)
|
| 83 |
+
weighted_poi = poi_dist * weights
|
| 84 |
+
|
| 85 |
+
# Log transform for count data (handles skewed distributions)
|
| 86 |
+
poi_log = torch.log1p(weighted_poi)
|
| 87 |
+
|
| 88 |
+
# (1) Global vector via MLP
|
| 89 |
+
features_mlp = self.encoder(poi_log)
|
| 90 |
+
|
| 91 |
+
# (2) Category tokens + attention pooling (attention score-weighted merging)
|
| 92 |
+
# token_scale: [B, poi_dim, 1]
|
| 93 |
+
token_scale = poi_log.unsqueeze(-1)
|
| 94 |
+
# tokens: [B, poi_dim, D]
|
| 95 |
+
tokens = token_scale * self.category_embed.weight.unsqueeze(0)
|
| 96 |
+
attn_logits = self.token_attn(tokens).squeeze(-1) # [B, poi_dim]
|
| 97 |
+
attn = F.softmax(attn_logits, dim=-1).unsqueeze(-1) # [B, poi_dim, 1]
|
| 98 |
+
features_tok = (tokens * attn).sum(dim=1) # [B, D]
|
| 99 |
+
|
| 100 |
+
# Combine (learned trade-off)
|
| 101 |
+
g = self.fuse_gate(torch.cat([features_mlp, features_tok], dim=-1)) # [B, 1]
|
| 102 |
+
features = g * features_mlp + (1.0 - g) * features_tok
|
| 103 |
+
|
| 104 |
+
if return_tokens:
|
| 105 |
+
return features, tokens
|
| 106 |
+
return features
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# =============================================================================
|
| 110 |
+
# Satellite Image Encoder (ResNet-18 backbone)
|
| 111 |
+
# =============================================================================
|
| 112 |
+
|
| 113 |
+
class ResidualBlock(nn.Module):
|
| 114 |
+
"""Basic residual block for ResNet."""
|
| 115 |
+
|
| 116 |
+
def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
|
| 117 |
+
super().__init__()
|
| 118 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
|
| 119 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
| 120 |
+
self.relu = nn.ReLU(inplace=True)
|
| 121 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
|
| 122 |
+
self.bn2 = nn.BatchNorm2d(out_channels)
|
| 123 |
+
|
| 124 |
+
self.stride = stride
|
| 125 |
+
if stride != 1 or in_channels != out_channels:
|
| 126 |
+
self.shortcut = nn.Sequential(
|
| 127 |
+
nn.Conv2d(in_channels, out_channels, 1, stride=stride),
|
| 128 |
+
nn.BatchNorm2d(out_channels),
|
| 129 |
+
)
|
| 130 |
+
else:
|
| 131 |
+
self.shortcut = None
|
| 132 |
+
|
| 133 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 134 |
+
identity = x
|
| 135 |
+
out = self.conv1(x)
|
| 136 |
+
out = self.bn1(out)
|
| 137 |
+
out = self.relu(out)
|
| 138 |
+
|
| 139 |
+
out = self.conv2(out)
|
| 140 |
+
out = self.bn2(out)
|
| 141 |
+
|
| 142 |
+
if self.shortcut is not None:
|
| 143 |
+
identity = self.shortcut(x)
|
| 144 |
+
|
| 145 |
+
out = out + identity
|
| 146 |
+
out = self.relu(out)
|
| 147 |
+
return out
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class SatelliteImageEncoder(nn.Module):
|
| 151 |
+
"""
|
| 152 |
+
ResNet-18-based satellite image encoder with multi-scale feature extraction.
|
| 153 |
+
|
| 154 |
+
Input: Satellite image [B, 3, 64, 64]
|
| 155 |
+
Output: Image embedding [B, spatial_dim]
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
def __init__(self, spatial_dim: int = 192, n_heads: int = 8, token_layers: int = 2):
|
| 159 |
+
super().__init__()
|
| 160 |
+
self.spatial_dim = spatial_dim
|
| 161 |
+
|
| 162 |
+
# Initial layer
|
| 163 |
+
self.conv1 = nn.Sequential(
|
| 164 |
+
nn.Conv2d(3, 64, 7, stride=2, padding=3),
|
| 165 |
+
nn.BatchNorm2d(64),
|
| 166 |
+
nn.ReLU(inplace=True),
|
| 167 |
+
nn.MaxPool2d(3, stride=2, padding=1),
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# ResNet blocks
|
| 171 |
+
self.layer1 = self._make_layer(64, 64, 2, stride=1)
|
| 172 |
+
self.layer2 = self._make_layer(64, 128, 2, stride=2)
|
| 173 |
+
self.layer3 = self._make_layer(128, 256, 2, stride=2)
|
| 174 |
+
self.layer4 = self._make_layer(256, 512, 2, stride=2)
|
| 175 |
+
|
| 176 |
+
# Multi-scale feature aggregation
|
| 177 |
+
self.pool1 = nn.AdaptiveAvgPool2d(1)
|
| 178 |
+
self.pool2 = nn.AdaptiveAvgPool2d(1)
|
| 179 |
+
self.pool3 = nn.AdaptiveAvgPool2d(1)
|
| 180 |
+
self.pool4 = nn.AdaptiveAvgPool2d(1)
|
| 181 |
+
|
| 182 |
+
# Learnable scale weights
|
| 183 |
+
self.scale_weights = nn.Parameter(torch.tensor([1.0, 1.0, 1.0, 1.0]))
|
| 184 |
+
|
| 185 |
+
# Final projection
|
| 186 |
+
self.proj = nn.Sequential(
|
| 187 |
+
nn.Linear(64 + 128 + 256 + 512, 384),
|
| 188 |
+
nn.GELU(),
|
| 189 |
+
nn.Dropout(0.1),
|
| 190 |
+
nn.Linear(384, spatial_dim),
|
| 191 |
+
nn.LayerNorm(spatial_dim),
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# Region-level tokens (RemoteCLIP-inspired: patch/region awareness)
|
| 195 |
+
self.token_proj3 = nn.Linear(256, spatial_dim)
|
| 196 |
+
self.token_proj4 = nn.Linear(512, spatial_dim)
|
| 197 |
+
self.img_cls = nn.Parameter(torch.zeros(1, 1, spatial_dim))
|
| 198 |
+
enc_layer = nn.TransformerEncoderLayer(
|
| 199 |
+
d_model=spatial_dim,
|
| 200 |
+
nhead=n_heads,
|
| 201 |
+
dim_feedforward=spatial_dim * 4,
|
| 202 |
+
dropout=0.1,
|
| 203 |
+
activation="gelu",
|
| 204 |
+
batch_first=True,
|
| 205 |
+
norm_first=True,
|
| 206 |
+
)
|
| 207 |
+
self.token_mixer = nn.TransformerEncoder(enc_layer, num_layers=int(token_layers))
|
| 208 |
+
|
| 209 |
+
def _make_layer(self, in_channels: int, out_channels: int, blocks: int, stride: int):
|
| 210 |
+
layers = []
|
| 211 |
+
layers.append(ResidualBlock(in_channels, out_channels, stride))
|
| 212 |
+
for _ in range(1, blocks):
|
| 213 |
+
layers.append(ResidualBlock(out_channels, out_channels, 1))
|
| 214 |
+
return nn.Sequential(*layers)
|
| 215 |
+
|
| 216 |
+
def forward(self, x: torch.Tensor, return_tokens: bool = False):
|
| 217 |
+
"""
|
| 218 |
+
Args:
|
| 219 |
+
x: [B, 3, 64, 64] satellite image
|
| 220 |
+
Returns:
|
| 221 |
+
features: [B, spatial_dim] image embedding
|
| 222 |
+
"""
|
| 223 |
+
x = self.conv1(x) # [B, 64, 16, 16]
|
| 224 |
+
|
| 225 |
+
x1 = self.layer1(x) # [B, 64, 16, 16]
|
| 226 |
+
x2 = self.layer2(x1) # [B, 128, 8, 8]
|
| 227 |
+
x3 = self.layer3(x2) # [B, 256, 4, 4]
|
| 228 |
+
x4 = self.layer4(x3) # [B, 512, 2, 2]
|
| 229 |
+
|
| 230 |
+
# Multi-scale pooling
|
| 231 |
+
f1 = self.pool1(x1).flatten(1) # [B, 64]
|
| 232 |
+
f2 = self.pool2(x2).flatten(1) # [B, 128]
|
| 233 |
+
f3 = self.pool3(x3).flatten(1) # [B, 256]
|
| 234 |
+
f4 = self.pool4(x4).flatten(1) # [B, 512]
|
| 235 |
+
|
| 236 |
+
# Weighted fusion
|
| 237 |
+
weights = F.softmax(self.scale_weights, dim=0)
|
| 238 |
+
fused = torch.cat([
|
| 239 |
+
f1 * weights[0],
|
| 240 |
+
f2 * weights[1],
|
| 241 |
+
f3 * weights[2],
|
| 242 |
+
f4 * weights[3],
|
| 243 |
+
], dim=-1)
|
| 244 |
+
|
| 245 |
+
# Final projection
|
| 246 |
+
features = self.proj(fused)
|
| 247 |
+
|
| 248 |
+
if not return_tokens:
|
| 249 |
+
return features
|
| 250 |
+
|
| 251 |
+
# Build region tokens from intermediate feature maps (4x4 + 2x2 = 20 tokens)
|
| 252 |
+
t3 = x3.flatten(2).transpose(1, 2) # [B, 16, 256]
|
| 253 |
+
t4 = x4.flatten(2).transpose(1, 2) # [B, 4, 512]
|
| 254 |
+
t3 = self.token_proj3(t3) # [B, 16, D]
|
| 255 |
+
t4 = self.token_proj4(t4) # [B, 4, D]
|
| 256 |
+
tokens = torch.cat([t3, t4], dim=1) # [B, 20, D]
|
| 257 |
+
|
| 258 |
+
# Mix tokens with a tiny Transformer, include a [CLS] token
|
| 259 |
+
cls = self.img_cls.expand(tokens.shape[0], -1, -1)
|
| 260 |
+
tokens_with_cls = torch.cat([cls, tokens], dim=1) # [B, 21, D]
|
| 261 |
+
tokens_with_cls = self.token_mixer(tokens_with_cls)
|
| 262 |
+
|
| 263 |
+
# tokens_with_cls[:, 0] is CLS; keep both CLS and spatial tokens
|
| 264 |
+
cls_out = tokens_with_cls[:, 0] # [B, D]
|
| 265 |
+
spatial_tokens = tokens_with_cls[:, 1:] # [B, 20, D]
|
| 266 |
+
|
| 267 |
+
# Blend CLS with pooled global feature for stability
|
| 268 |
+
feat = 0.5 * features + 0.5 * cls_out
|
| 269 |
+
return feat, spatial_tokens
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# =============================================================================
|
| 273 |
+
# Coordinate Encoder with Learnable Fourier Features
|
| 274 |
+
# =============================================================================
|
| 275 |
+
|
| 276 |
+
class CoordinateEncoder(nn.Module):
|
| 277 |
+
"""
|
| 278 |
+
Coordinate encoder with learnable Fourier frequencies.
|
| 279 |
+
|
| 280 |
+
Input: Coordinates [B, 2] (latitude, longitude)
|
| 281 |
+
Output: Coordinate embedding [B, spatial_dim]
|
| 282 |
+
"""
|
| 283 |
+
|
| 284 |
+
def __init__(self, coord_dim: int = 2, spatial_dim: int = 192):
|
| 285 |
+
super().__init__()
|
| 286 |
+
self.coord_dim = coord_dim
|
| 287 |
+
self.spatial_dim = spatial_dim
|
| 288 |
+
|
| 289 |
+
# Multi-scale learnable Fourier frequencies
|
| 290 |
+
n_freqs = 64
|
| 291 |
+
init_freqs = 2 ** torch.linspace(0, 8, n_freqs)
|
| 292 |
+
self.freqs = nn.Parameter(init_freqs)
|
| 293 |
+
|
| 294 |
+
fourier_dim = coord_dim * n_freqs * 2
|
| 295 |
+
|
| 296 |
+
# Deep encoder
|
| 297 |
+
self.encoder = nn.Sequential(
|
| 298 |
+
nn.Linear(fourier_dim + coord_dim, 512),
|
| 299 |
+
nn.GELU(),
|
| 300 |
+
nn.Dropout(0.1),
|
| 301 |
+
nn.Linear(512, 384),
|
| 302 |
+
nn.GELU(),
|
| 303 |
+
nn.LayerNorm(384),
|
| 304 |
+
nn.Linear(384, spatial_dim),
|
| 305 |
+
nn.LayerNorm(spatial_dim),
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
def forward(self, coords: torch.Tensor) -> torch.Tensor:
|
| 309 |
+
"""
|
| 310 |
+
Args:
|
| 311 |
+
coords: [B, 2] coordinates
|
| 312 |
+
Returns:
|
| 313 |
+
features: [B, spatial_dim] coordinate embedding
|
| 314 |
+
"""
|
| 315 |
+
# Fourier features with learnable frequencies
|
| 316 |
+
coords_scaled = coords.unsqueeze(-1) * self.freqs # [B, 2, n_freqs]
|
| 317 |
+
fourier = torch.cat([
|
| 318 |
+
torch.sin(coords_scaled * np.pi),
|
| 319 |
+
torch.cos(coords_scaled * np.pi),
|
| 320 |
+
], dim=-1).flatten(-2) # [B, fourier_dim]
|
| 321 |
+
|
| 322 |
+
# Combine with raw coordinates
|
| 323 |
+
combined = torch.cat([coords, fourier], dim=-1)
|
| 324 |
+
|
| 325 |
+
# Encode
|
| 326 |
+
features = self.encoder(combined)
|
| 327 |
+
|
| 328 |
+
return features
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# =============================================================================
|
| 332 |
+
# Cross-Attention Fusion Module
|
| 333 |
+
# =============================================================================
|
| 334 |
+
|
| 335 |
+
class CrossAttentionFusion(nn.Module):
|
| 336 |
+
"""
|
| 337 |
+
Cross-attention fusion for multi-modal conditioning.
|
| 338 |
+
|
| 339 |
+
Fuses POI, satellite, and coordinate embeddings via multi-head attention.
|
| 340 |
+
"""
|
| 341 |
+
|
| 342 |
+
def __init__(self, spatial_dim: int = 192, n_heads: int = 8):
|
| 343 |
+
super().__init__()
|
| 344 |
+
self.spatial_dim = spatial_dim
|
| 345 |
+
self.n_heads = n_heads
|
| 346 |
+
|
| 347 |
+
# Two rounds of cross-attention (vector mode: 3 tokens; token mode: CLS->context)
|
| 348 |
+
self.cross_attn1 = nn.MultiheadAttention(
|
| 349 |
+
spatial_dim, num_heads=n_heads, dropout=0.1, batch_first=True
|
| 350 |
+
)
|
| 351 |
+
self.cross_attn2 = nn.MultiheadAttention(
|
| 352 |
+
spatial_dim, num_heads=n_heads, dropout=0.1, batch_first=True
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# Layer norms
|
| 356 |
+
self.norm1 = nn.LayerNorm(spatial_dim)
|
| 357 |
+
self.norm2 = nn.LayerNorm(spatial_dim)
|
| 358 |
+
self.norm3 = nn.LayerNorm(spatial_dim)
|
| 359 |
+
self.norm4 = nn.LayerNorm(spatial_dim)
|
| 360 |
+
|
| 361 |
+
# Feed-forward networks
|
| 362 |
+
self.ffn1 = nn.Sequential(
|
| 363 |
+
nn.Linear(spatial_dim, spatial_dim * 4),
|
| 364 |
+
nn.GELU(),
|
| 365 |
+
nn.Dropout(0.1),
|
| 366 |
+
nn.Linear(spatial_dim * 4, spatial_dim),
|
| 367 |
+
)
|
| 368 |
+
self.ffn2 = nn.Sequential(
|
| 369 |
+
nn.Linear(spatial_dim, spatial_dim * 4),
|
| 370 |
+
nn.GELU(),
|
| 371 |
+
nn.Dropout(0.1),
|
| 372 |
+
nn.Linear(spatial_dim * 4, spatial_dim),
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
# Adaptive gating for modality importance
|
| 376 |
+
self.gate = nn.Sequential(
|
| 377 |
+
nn.Linear(spatial_dim * 3, 256),
|
| 378 |
+
nn.GELU(),
|
| 379 |
+
nn.Linear(256, 3),
|
| 380 |
+
nn.Softmax(dim=-1),
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# Token-mode: learnable fusion token
|
| 384 |
+
self.fusion_cls = nn.Parameter(torch.zeros(1, 1, spatial_dim))
|
| 385 |
+
self.token_out_gate = nn.Sequential(
|
| 386 |
+
nn.Linear(spatial_dim * 2, 1),
|
| 387 |
+
nn.Sigmoid(),
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
def forward(
|
| 391 |
+
self,
|
| 392 |
+
sat_feat: torch.Tensor,
|
| 393 |
+
poi_feat: torch.Tensor,
|
| 394 |
+
coord_feat: torch.Tensor,
|
| 395 |
+
sat_tokens: Optional[torch.Tensor] = None,
|
| 396 |
+
poi_tokens: Optional[torch.Tensor] = None,
|
| 397 |
+
coord_token: Optional[torch.Tensor] = None,
|
| 398 |
+
) -> torch.Tensor:
|
| 399 |
+
"""
|
| 400 |
+
Args:
|
| 401 |
+
sat_feat: [B, spatial_dim] satellite embedding
|
| 402 |
+
poi_feat: [B, spatial_dim] POI embedding
|
| 403 |
+
coord_feat: [B, spatial_dim] coordinate embedding
|
| 404 |
+
Returns:
|
| 405 |
+
fused: [B, spatial_dim] fused embedding
|
| 406 |
+
"""
|
| 407 |
+
# ---------------------------------------------------------------------
|
| 408 |
+
# (A) Vector mode (backward-compatible): treat each modality as 1 token.
|
| 409 |
+
# ---------------------------------------------------------------------
|
| 410 |
+
if sat_tokens is None and poi_tokens is None and coord_token is None:
|
| 411 |
+
# Stack as sequence [B, 3, D]
|
| 412 |
+
modalities = torch.stack([sat_feat, poi_feat, coord_feat], dim=1)
|
| 413 |
+
|
| 414 |
+
# First round of cross-attention
|
| 415 |
+
attn_out1, _ = self.cross_attn1(modalities, modalities, modalities)
|
| 416 |
+
modalities = self.norm1(modalities + attn_out1)
|
| 417 |
+
ffn_out1 = self.ffn1(modalities)
|
| 418 |
+
modalities = self.norm2(modalities + ffn_out1)
|
| 419 |
+
|
| 420 |
+
# Second round
|
| 421 |
+
attn_out2, _ = self.cross_attn2(modalities, modalities, modalities)
|
| 422 |
+
modalities = self.norm3(modalities + attn_out2)
|
| 423 |
+
ffn_out2 = self.ffn2(modalities)
|
| 424 |
+
modalities = self.norm4(modalities + ffn_out2)
|
| 425 |
+
|
| 426 |
+
# Unpack
|
| 427 |
+
sat_out, poi_out, coord_out = modalities.unbind(dim=1)
|
| 428 |
+
|
| 429 |
+
# Adaptive gating
|
| 430 |
+
concat = torch.cat([sat_out, poi_out, coord_out], dim=-1)
|
| 431 |
+
weights = self.gate(concat) # [B, 3]
|
| 432 |
+
|
| 433 |
+
# Weighted fusion
|
| 434 |
+
fused = (
|
| 435 |
+
weights[:, 0:1] * sat_out +
|
| 436 |
+
weights[:, 1:2] * poi_out +
|
| 437 |
+
weights[:, 2:3] * coord_out
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
return fused
|
| 441 |
+
|
| 442 |
+
# ---------------------------------------------------------------------
|
| 443 |
+
# (B) Token mode: CLS attends over (sat tokens + poi tokens + coord token).
|
| 444 |
+
# RemoteCLIP-inspired region tokens + POI-Enhancer-inspired semantic tokens.
|
| 445 |
+
# ---------------------------------------------------------------------
|
| 446 |
+
B = sat_feat.shape[0]
|
| 447 |
+
context = []
|
| 448 |
+
if sat_tokens is not None:
|
| 449 |
+
context.append(sat_tokens)
|
| 450 |
+
else:
|
| 451 |
+
context.append(sat_feat.unsqueeze(1))
|
| 452 |
+
|
| 453 |
+
if poi_tokens is not None:
|
| 454 |
+
context.append(poi_tokens)
|
| 455 |
+
else:
|
| 456 |
+
context.append(poi_feat.unsqueeze(1))
|
| 457 |
+
|
| 458 |
+
if coord_token is not None:
|
| 459 |
+
context.append(coord_token.unsqueeze(1))
|
| 460 |
+
else:
|
| 461 |
+
context.append(coord_feat.unsqueeze(1))
|
| 462 |
+
|
| 463 |
+
context_tokens = torch.cat(context, dim=1) # [B, L, D]
|
| 464 |
+
cls = self.fusion_cls.expand(B, -1, -1) # [B, 1, D]
|
| 465 |
+
|
| 466 |
+
# Two rounds of CLS->context attention + FFN (Transformer-like)
|
| 467 |
+
attn1, _ = self.cross_attn1(cls, context_tokens, context_tokens)
|
| 468 |
+
cls = self.norm1(cls + attn1)
|
| 469 |
+
cls = self.norm2(cls + self.ffn1(cls))
|
| 470 |
+
|
| 471 |
+
attn2, _ = self.cross_attn2(cls, context_tokens, context_tokens)
|
| 472 |
+
cls = self.norm3(cls + attn2)
|
| 473 |
+
cls = self.norm4(cls + self.ffn2(cls))
|
| 474 |
+
|
| 475 |
+
cls_vec = cls.squeeze(1) # [B, D]
|
| 476 |
+
|
| 477 |
+
# Keep the original adaptive gating as a global shortcut, then learn to mix.
|
| 478 |
+
concat = torch.cat([sat_feat, poi_feat, coord_feat], dim=-1)
|
| 479 |
+
weights = self.gate(concat)
|
| 480 |
+
gated = (
|
| 481 |
+
weights[:, 0:1] * sat_feat +
|
| 482 |
+
weights[:, 1:2] * poi_feat +
|
| 483 |
+
weights[:, 2:3] * coord_feat
|
| 484 |
+
)
|
| 485 |
+
mix = self.token_out_gate(torch.cat([cls_vec, gated], dim=-1)) # [B, 1]
|
| 486 |
+
fused = mix * cls_vec + (1.0 - mix) * gated
|
| 487 |
+
return fused
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
# =============================================================================
|
| 491 |
+
# Multi-Scale Condition Generator
|
| 492 |
+
# =============================================================================
|
| 493 |
+
|
| 494 |
+
class MultiScaleConditionGenerator(nn.Module):
|
| 495 |
+
"""
|
| 496 |
+
Generate stage-specific multi-scale conditions.
|
| 497 |
+
|
| 498 |
+
Produces different condition embeddings for each hierarchical level.
|
| 499 |
+
"""
|
| 500 |
+
|
| 501 |
+
def __init__(self, spatial_dim: int = 192):
|
| 502 |
+
super().__init__()
|
| 503 |
+
|
| 504 |
+
# Level 1 (daily): global patterns
|
| 505 |
+
self.level1_proj = nn.Sequential(
|
| 506 |
+
nn.Linear(spatial_dim, 256),
|
| 507 |
+
nn.GELU(),
|
| 508 |
+
nn.Dropout(0.1),
|
| 509 |
+
nn.Linear(256, spatial_dim),
|
| 510 |
+
nn.LayerNorm(spatial_dim),
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# Level 2 (weekly): periodic structure
|
| 514 |
+
self.level2_proj = nn.Sequential(
|
| 515 |
+
nn.Linear(spatial_dim, 256),
|
| 516 |
+
nn.GELU(),
|
| 517 |
+
nn.Dropout(0.1),
|
| 518 |
+
nn.Linear(256, spatial_dim),
|
| 519 |
+
nn.LayerNorm(spatial_dim),
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
# Level 3 (residual): fine details
|
| 523 |
+
self.level3_proj = nn.Sequential(
|
| 524 |
+
nn.Linear(spatial_dim, 384),
|
| 525 |
+
nn.GELU(),
|
| 526 |
+
nn.Dropout(0.1),
|
| 527 |
+
nn.Linear(384, spatial_dim),
|
| 528 |
+
nn.LayerNorm(spatial_dim),
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
def forward(self, base_condition: torch.Tensor) -> Dict[str, torch.Tensor]:
|
| 532 |
+
"""Generate stage-specific conditions."""
|
| 533 |
+
return {
|
| 534 |
+
'level1_cond': self.level1_proj(base_condition),
|
| 535 |
+
'level2_cond': self.level2_proj(base_condition),
|
| 536 |
+
'level3_cond': self.level3_proj(base_condition),
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
# =============================================================================
|
| 541 |
+
# Complete Multi-Modal Spatial Encoder
|
| 542 |
+
# =============================================================================
|
| 543 |
+
|
| 544 |
+
class MultiModalSpatialEncoderV4(nn.Module):
|
| 545 |
+
"""
|
| 546 |
+
Complete multi-modal spatial encoder combining:
|
| 547 |
+
- POI features
|
| 548 |
+
- Satellite imagery
|
| 549 |
+
- Geographic coordinates
|
| 550 |
+
- Cross-attention fusion
|
| 551 |
+
- Multi-scale condition generation
|
| 552 |
+
- [NEW] Auxiliary Peak Hour Classification
|
| 553 |
+
"""
|
| 554 |
+
|
| 555 |
+
def __init__(self, spatial_dim: int = 192, poi_dim: int = 20):
|
| 556 |
+
super().__init__()
|
| 557 |
+
self.spatial_dim = spatial_dim
|
| 558 |
+
self.poi_dim = poi_dim
|
| 559 |
+
|
| 560 |
+
# Individual encoders
|
| 561 |
+
self.poi_encoder = POIEncoder(poi_dim, spatial_dim)
|
| 562 |
+
self.satellite_encoder = SatelliteImageEncoder(spatial_dim)
|
| 563 |
+
self.coord_encoder = CoordinateEncoder(2, spatial_dim)
|
| 564 |
+
|
| 565 |
+
# Multi-modal fusion
|
| 566 |
+
self.fusion = CrossAttentionFusion(spatial_dim, n_heads=8)
|
| 567 |
+
|
| 568 |
+
# Multi-scale condition generation
|
| 569 |
+
self.multiscale_generator = MultiScaleConditionGenerator(spatial_dim)
|
| 570 |
+
|
| 571 |
+
# [NEW] Auxiliary Head: Peak Hour Prediction
|
| 572 |
+
# Predicts which hour (0-23) has the maximum traffic
|
| 573 |
+
self.peak_hour_classifier = nn.Sequential(
|
| 574 |
+
nn.Linear(spatial_dim, 128),
|
| 575 |
+
nn.GELU(),
|
| 576 |
+
nn.Dropout(0.1),
|
| 577 |
+
nn.Linear(128, 24) # 24 hours classification
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
| 581 |
+
"""
|
| 582 |
+
Args:
|
| 583 |
+
batch: dict with keys:
|
| 584 |
+
- 'satellite_img': [B, 3, 64, 64]
|
| 585 |
+
- 'poi_dist': [B, poi_dim]
|
| 586 |
+
- 'coords': [B, 2]
|
| 587 |
+
Returns:
|
| 588 |
+
outputs: dict with conditions and predicted peak logits
|
| 589 |
+
"""
|
| 590 |
+
# Encode each modality (token-aware for stronger multi-modal fusion)
|
| 591 |
+
sat_feat, sat_tokens = self.satellite_encoder(batch['satellite_img'], return_tokens=True)
|
| 592 |
+
poi_feat, poi_tokens = self.poi_encoder(batch['poi_dist'], return_tokens=True)
|
| 593 |
+
coord_feat = self.coord_encoder(batch['coords'])
|
| 594 |
+
|
| 595 |
+
# Fuse modalities (CLS attends to region tokens + semantic tokens)
|
| 596 |
+
base_condition = self.fusion(
|
| 597 |
+
sat_feat,
|
| 598 |
+
poi_feat,
|
| 599 |
+
coord_feat,
|
| 600 |
+
sat_tokens=sat_tokens,
|
| 601 |
+
poi_tokens=poi_tokens,
|
| 602 |
+
coord_token=coord_feat,
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
# Generate multi-scale conditions
|
| 606 |
+
stage_conditions = self.multiscale_generator(base_condition)
|
| 607 |
+
|
| 608 |
+
# [NEW] Predict peak hour
|
| 609 |
+
pred_peak_logits = self.peak_hour_classifier(base_condition)
|
| 610 |
+
|
| 611 |
+
outputs = {
|
| 612 |
+
'base_condition': base_condition,
|
| 613 |
+
'pred_peak_logits': pred_peak_logits, # Auxiliary output
|
| 614 |
+
**stage_conditions,
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
return outputs
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
if __name__ == "__main__":
|
| 621 |
+
# Test the encoder
|
| 622 |
+
B = 4
|
| 623 |
+
spatial_dim = 192
|
| 624 |
+
poi_dim = 20
|
| 625 |
+
|
| 626 |
+
encoder = MultiModalSpatialEncoderV4(spatial_dim, poi_dim)
|
| 627 |
+
|
| 628 |
+
# Create dummy batch
|
| 629 |
+
batch = {
|
| 630 |
+
'satellite_img': torch.randn(B, 3, 64, 64),
|
| 631 |
+
'poi_dist': torch.randn(B, poi_dim),
|
| 632 |
+
'coords': torch.randn(B, 2),
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
# Forward pass
|
| 636 |
+
outputs = encoder(batch)
|
| 637 |
+
|
| 638 |
+
print("Multi-Modal Spatial Encoder V4 Test:")
|
| 639 |
+
print(f" Base condition shape: {outputs['base_condition'].shape}")
|
| 640 |
+
print(f" Peak Logits shape: {outputs['pred_peak_logits'].shape}")
|
| 641 |
+
print(f" Level 1 condition shape: {outputs['level1_cond'].shape}")
|
| 642 |
+
print(f" Level 2 condition shape: {outputs['level2_cond'].shape}")
|
| 643 |
+
print(f" Level 3 condition shape: {outputs['level3_cond'].shape}")
|
| 644 |
+
|
| 645 |
+
print("\nEncoder test passed!")
|
prediction_backend.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import requests
|
| 6 |
+
import random
|
| 7 |
+
import base64
|
| 8 |
+
import matplotlib
|
| 9 |
+
matplotlib.use('Agg')
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
# Import V4 Model System
|
| 16 |
+
from hierarchical_flow_matching_training_v4 import HierarchicalFlowMatchingSystemV4
|
| 17 |
+
|
| 18 |
+
# =============================================================================
|
| 19 |
+
# Config
|
| 20 |
+
# =============================================================================
|
| 21 |
+
MAPBOX_ACCESS_TOKEN = "pk.eyJ1IjoieXlhaXl5IiwiYSI6ImNtaTVpMTVlaTJmdzMybW9zcmFieGxpdHUifQ.181d6E5fzLw1CEZMEPU53Q"
|
| 22 |
+
MAPBOX_ZOOM = 15
|
| 23 |
+
FETCH_SIZE = 256
|
| 24 |
+
IMAGE_SIZE = 64
|
| 25 |
+
SEED = 42
|
| 26 |
+
|
| 27 |
+
SPATIAL_DIM = 192
|
| 28 |
+
HIDDEN_DIM = 256
|
| 29 |
+
POI_DIM = 20
|
| 30 |
+
N_LAYERS_LEVEL3 = 6
|
| 31 |
+
N_STEPS = 50
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class MapboxSatelliteFetcher:
|
| 35 |
+
"""
|
| 36 |
+
Dynamically fetches satellite imagery, strictly aligning with
|
| 37 |
+
the image preprocessing logic used during training.
|
| 38 |
+
"""
|
| 39 |
+
def __init__(self, access_token=MAPBOX_ACCESS_TOKEN, zoom=MAPBOX_ZOOM, fetch_size=FETCH_SIZE, target_size=IMAGE_SIZE):
|
| 40 |
+
self.access_token = access_token
|
| 41 |
+
self.zoom = zoom
|
| 42 |
+
self.fetch_size = fetch_size
|
| 43 |
+
self.target_size = target_size
|
| 44 |
+
|
| 45 |
+
def fetch(self, lon, lat, station_id=None, return_pil=False):
|
| 46 |
+
""" Fetches static satellite map centered at [lon, lat] """
|
| 47 |
+
url = f"https://api.mapbox.com/styles/v1/mapbox/satellite-v9/static/{lon},{lat},{self.zoom},0,0/{self.fetch_size}x{self.fetch_size}?access_token={self.access_token}"
|
| 48 |
+
try:
|
| 49 |
+
response = requests.get(url, timeout=10)
|
| 50 |
+
response.raise_for_status()
|
| 51 |
+
|
| 52 |
+
img = Image.open(BytesIO(response.content)).convert("RGB")
|
| 53 |
+
original_pil = img.copy() # Store high-res original for micro-grid slicing
|
| 54 |
+
|
| 55 |
+
# Resize to model input size (64x64)
|
| 56 |
+
img_resized = img.resize((self.target_size, self.target_size), Image.BILINEAR)
|
| 57 |
+
arr = np.asarray(img_resized, dtype=np.float32) / 255.0
|
| 58 |
+
|
| 59 |
+
# Convert HWC to CHW format
|
| 60 |
+
chw = arr.transpose(2, 0, 1)
|
| 61 |
+
tensor_np = np.clip(chw, 0.0, 1.0).astype(np.float32, copy=False)
|
| 62 |
+
|
| 63 |
+
if return_pil:
|
| 64 |
+
return tensor_np, original_pil
|
| 65 |
+
return tensor_np
|
| 66 |
+
|
| 67 |
+
except Exception as e:
|
| 68 |
+
print(f"[Mapbox Fetcher Error] Station {station_id}: {e}")
|
| 69 |
+
fallback_np = np.zeros((3, self.target_size, self.target_size), dtype=np.float32)
|
| 70 |
+
if return_pil:
|
| 71 |
+
return fallback_np, Image.new('RGB', (self.fetch_size, self.fetch_size), color='black')
|
| 72 |
+
return fallback_np
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class TrafficPredictor:
|
| 76 |
+
"""Traffic generation model predictor and site selection analyzer"""
|
| 77 |
+
def __init__(self, model_path, spatial_path, traffic_path, local_sat_dir="real_spatial_data/satellite_png", device=None):
|
| 78 |
+
self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 79 |
+
self.local_sat_dir = local_sat_dir
|
| 80 |
+
|
| 81 |
+
# Load spatial features cache
|
| 82 |
+
if os.path.exists(spatial_path):
|
| 83 |
+
self.data_cache = np.load(spatial_path, allow_pickle=True)
|
| 84 |
+
else:
|
| 85 |
+
raise FileNotFoundError(f"Spatial path {spatial_path} not found!")
|
| 86 |
+
|
| 87 |
+
# Load traffic records for validation comparison
|
| 88 |
+
if os.path.exists(traffic_path):
|
| 89 |
+
self.traffic_data = np.load(traffic_path, allow_pickle=True)['bs_record']
|
| 90 |
+
else:
|
| 91 |
+
self.traffic_data = None
|
| 92 |
+
|
| 93 |
+
# ==========================================
|
| 94 |
+
# Strictly simulate Dataset length truncation to ensure
|
| 95 |
+
# normalization extrema are 100% aligned with training.
|
| 96 |
+
# ==========================================
|
| 97 |
+
n_traffic = len(self.traffic_data) if self.traffic_data is not None else float('inf')
|
| 98 |
+
n_poi = len(self.data_cache['poi_distributions'])
|
| 99 |
+
n_coords = len(self.data_cache['coordinates'])
|
| 100 |
+
self.n_valid = min(n_traffic, n_poi, n_coords)
|
| 101 |
+
|
| 102 |
+
# Calculate coordinate bounds for normalization
|
| 103 |
+
raw_coords_valid = self.data_cache['coordinates'][:self.n_valid].astype(np.float32)
|
| 104 |
+
self.coord_min = raw_coords_valid.min(axis=0)
|
| 105 |
+
self.coord_max = raw_coords_valid.max(axis=0)
|
| 106 |
+
|
| 107 |
+
self.satellite_fetcher = MapboxSatelliteFetcher()
|
| 108 |
+
self.model = self._load_model(model_path)
|
| 109 |
+
|
| 110 |
+
def _load_model(self, model_path):
|
| 111 |
+
print(f"Loading V4 Model on {self.device}...")
|
| 112 |
+
model = HierarchicalFlowMatchingSystemV4(
|
| 113 |
+
spatial_dim=SPATIAL_DIM,
|
| 114 |
+
hidden_dim=HIDDEN_DIM,
|
| 115 |
+
poi_dim=POI_DIM,
|
| 116 |
+
n_layers_level3=N_LAYERS_LEVEL3
|
| 117 |
+
).to(self.device)
|
| 118 |
+
|
| 119 |
+
if os.path.exists(model_path):
|
| 120 |
+
ckpt = torch.load(model_path, map_location=self.device, weights_only=False)
|
| 121 |
+
state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
|
| 122 |
+
model.load_state_dict(state_dict)
|
| 123 |
+
print(f"Checkpoint loaded successfully from {model_path}.")
|
| 124 |
+
else:
|
| 125 |
+
raise FileNotFoundError(f"Model checkpoint not found at {model_path}")
|
| 126 |
+
|
| 127 |
+
model.eval()
|
| 128 |
+
return model
|
| 129 |
+
|
| 130 |
+
def predict(self, idx, use_local_img_for_debug=False):
|
| 131 |
+
"""Generates traffic predictions and performs LSI spatial analysis"""
|
| 132 |
+
try:
|
| 133 |
+
torch.manual_seed(SEED)
|
| 134 |
+
idx = int(idx)
|
| 135 |
+
|
| 136 |
+
# 1. Process POI distributions
|
| 137 |
+
raw_poi = self.data_cache['poi_distributions'][idx].astype(np.float32).copy()
|
| 138 |
+
raw_poi = np.clip(raw_poi, 0.0, None)
|
| 139 |
+
poi_sum = float(raw_poi.sum())
|
| 140 |
+
if poi_sum > 1e-8:
|
| 141 |
+
raw_poi = raw_poi / poi_sum
|
| 142 |
+
else:
|
| 143 |
+
raw_poi = np.zeros_like(raw_poi)
|
| 144 |
+
poi_tensor = torch.from_numpy(raw_poi).unsqueeze(0).to(self.device)
|
| 145 |
+
|
| 146 |
+
# 2. Process Geographical Coordinates
|
| 147 |
+
raw_loc = self.data_cache['coordinates'][idx].astype(np.float32)
|
| 148 |
+
lon, lat = raw_loc[0], raw_loc[1]
|
| 149 |
+
|
| 150 |
+
# Fetch Satellite Image (Local Debug vs. Remote API)
|
| 151 |
+
if use_local_img_for_debug and os.path.exists(f"{self.local_sat_dir}/{idx}.png"):
|
| 152 |
+
img = Image.open(f"{self.local_sat_dir}/{idx}.png").convert("RGB")
|
| 153 |
+
original_pil = img.copy()
|
| 154 |
+
else:
|
| 155 |
+
_, original_pil = self.satellite_fetcher.fetch(lon, lat, station_id=str(idx), return_pil=True)
|
| 156 |
+
|
| 157 |
+
# 3. LSI Heatmap Generation (Location Stability Index)
|
| 158 |
+
lsi_grid, best_idx, best_traffic = self.generate_lsi_heatmap(original_pil, lat, lon, poi_tensor, grid_size=3)
|
| 159 |
+
site_map_b64 = self.create_site_map_base64(original_pil, lsi_grid, best_idx)
|
| 160 |
+
|
| 161 |
+
# 4. Map Projections: Convert grid indices back to physical Lat/Lon
|
| 162 |
+
grid_size = 3
|
| 163 |
+
best_row, best_col = best_idx
|
| 164 |
+
|
| 165 |
+
# Calculate spans based on Web Mercator projection at specific zoom
|
| 166 |
+
lon_span = 360.0 / (2 ** MAPBOX_ZOOM)
|
| 167 |
+
lat_span = lon_span * math.cos(math.radians(lat))
|
| 168 |
+
step_lon = lon_span / grid_size
|
| 169 |
+
step_lat = lat_span / grid_size
|
| 170 |
+
|
| 171 |
+
# Offset from base center
|
| 172 |
+
best_lat = lat - (best_row - grid_size // 2) * step_lat
|
| 173 |
+
best_lon = lon + (best_col - grid_size // 2) * step_lon
|
| 174 |
+
best_loc = [float(best_lon), float(best_lat)]
|
| 175 |
+
|
| 176 |
+
# 5. Multidimensional NLG (Natural Language Generation) Engine
|
| 177 |
+
best_lsi_value = float(lsi_grid[best_idx])
|
| 178 |
+
avg_lsi_value = float(np.mean(lsi_grid))
|
| 179 |
+
min_lsi_value = float(np.min(lsi_grid))
|
| 180 |
+
|
| 181 |
+
# Core Performance Metrics
|
| 182 |
+
improvement_avg = ((best_lsi_value - avg_lsi_value) / avg_lsi_value) * 100 if avg_lsi_value > 0 else 0
|
| 183 |
+
spatial_contrast = ((best_lsi_value - min_lsi_value) / min_lsi_value) * 100 if min_lsi_value > 0 else 0
|
| 184 |
+
|
| 185 |
+
# Feature A: POI Semantic Mapping
|
| 186 |
+
poi_idx = int(torch.argmax(poi_tensor[0]))
|
| 187 |
+
poi_categories = [
|
| 188 |
+
"Commercial/Retail", "Residential Complex", "Transit Hub",
|
| 189 |
+
"Corporate/Office", "Public/Recreational", "Industrial Zone",
|
| 190 |
+
"Mixed-Use Urban", "Educational/Campus"
|
| 191 |
+
]
|
| 192 |
+
dominant_poi = poi_categories[poi_idx % len(poi_categories)]
|
| 193 |
+
|
| 194 |
+
# Feature B: Temporal Tide Analysis (Extract daily peak from 672-hour sequence)
|
| 195 |
+
daily_pattern = best_traffic.reshape(-1, 24).mean(axis=0)
|
| 196 |
+
peak_hour = int(np.argmax(daily_pattern))
|
| 197 |
+
|
| 198 |
+
if 7 <= peak_hour <= 10:
|
| 199 |
+
peak_type = "Morning Rush (07:00-10:00)"
|
| 200 |
+
elif 16 <= peak_hour <= 19:
|
| 201 |
+
peak_type = "Evening Rush (16:00-19:00)"
|
| 202 |
+
elif 11 <= peak_hour <= 15:
|
| 203 |
+
peak_type = "Midday Active (11:00-15:00)"
|
| 204 |
+
else:
|
| 205 |
+
peak_type = "Night/Off-peak Active"
|
| 206 |
+
|
| 207 |
+
# Load description based on average volume
|
| 208 |
+
avg_load = float(np.mean(best_traffic))
|
| 209 |
+
if avg_load > 6.0: load_desc = "High-Capacity"
|
| 210 |
+
elif avg_load > 3.0: load_desc = "Moderate-Load"
|
| 211 |
+
else: load_desc = "Baseline/Sparse"
|
| 212 |
+
|
| 213 |
+
# Dynamic Text Assembly (4-Stage Structure)
|
| 214 |
+
# Stage 1: Spatial Environment Diagnosis
|
| 215 |
+
if spatial_contrast > 40:
|
| 216 |
+
p1 = f"Spatial scan detects a highly heterogeneous {dominant_poi} sector with steep traffic gradients. "
|
| 217 |
+
else:
|
| 218 |
+
p1 = f"Spatial scan indicates a relatively uniform {dominant_poi} matrix. "
|
| 219 |
+
|
| 220 |
+
# Stage 2: Temporal Characteristics
|
| 221 |
+
p2 = f"Flow Matching model projects a {load_desc} demand curve, heavily anchored by a {peak_type} signature. "
|
| 222 |
+
|
| 223 |
+
# Stage 3: Decision Output
|
| 224 |
+
p3 = f"Micro-grid ({best_row}, {best_col}) is isolated as the topological optimum, yielding a peak Location Stability Index (LSI) of {best_lsi_value:.2f}. "
|
| 225 |
+
|
| 226 |
+
# Stage 4: Business Value Assessment
|
| 227 |
+
if improvement_avg > 15:
|
| 228 |
+
p4 = f"Deploying infrastructure here intercepts peak volatility, providing a {improvement_avg:.1f}% structural stability gain over the regional average."
|
| 229 |
+
else:
|
| 230 |
+
p4 = f"This precise coordinate offers a marginal yet critical {improvement_avg:.1f}% variance reduction, ensuring optimal load-balancing."
|
| 231 |
+
|
| 232 |
+
explanation_text = p1 + p2 + p3 + p4
|
| 233 |
+
# ===============================================
|
| 234 |
+
|
| 235 |
+
# Finalize output sequence
|
| 236 |
+
gen_seq_real = np.clip(best_traffic, 0.0, 10.0)
|
| 237 |
+
real_seq = self.traffic_data[idx].tolist() if self.traffic_data is not None else []
|
| 238 |
+
|
| 239 |
+
return {
|
| 240 |
+
"station_id": idx,
|
| 241 |
+
"prediction": gen_seq_real.tolist(),
|
| 242 |
+
"real": real_seq,
|
| 243 |
+
"site_map_b64": site_map_b64,
|
| 244 |
+
"best_loc": best_loc,
|
| 245 |
+
"explanation": explanation_text,
|
| 246 |
+
"status": "success"
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
except Exception as e:
|
| 250 |
+
import traceback
|
| 251 |
+
traceback.print_exc()
|
| 252 |
+
return {"error": str(e), "status": "failed"}
|
| 253 |
+
|
| 254 |
+
@torch.no_grad()
|
| 255 |
+
def generate_lsi_heatmap(self, img_pil, base_lat, base_lon, poi_tensor, grid_size=3):
|
| 256 |
+
w, h = img_pil.size
|
| 257 |
+
patch_w, patch_h = w // grid_size, h // grid_size
|
| 258 |
+
|
| 259 |
+
patches = []
|
| 260 |
+
coords = []
|
| 261 |
+
|
| 262 |
+
# Calculate precise Lat/Lon spans for 256x256 image area
|
| 263 |
+
lon_span = 360.0 / (2 ** MAPBOX_ZOOM)
|
| 264 |
+
lat_span = lon_span * math.cos(math.radians(base_lat))
|
| 265 |
+
step_lon = lon_span / grid_size
|
| 266 |
+
step_lat = lat_span / grid_size
|
| 267 |
+
|
| 268 |
+
for i in range(grid_size):
|
| 269 |
+
for j in range(grid_size):
|
| 270 |
+
# Slice and resize patches for model
|
| 271 |
+
box = (j * patch_w, i * patch_h, (j+1) * patch_w, (i+1) * patch_h)
|
| 272 |
+
patch = img_pil.crop(box).resize((IMAGE_SIZE, IMAGE_SIZE), Image.BILINEAR).convert("RGB")
|
| 273 |
+
arr = np.array(patch).transpose(2,0,1) / 255.0
|
| 274 |
+
patches.append(torch.tensor(arr, dtype=torch.float32))
|
| 275 |
+
|
| 276 |
+
# Normalize coordinates for the model
|
| 277 |
+
offset_lat = base_lat - (i - grid_size//2) * step_lat
|
| 278 |
+
offset_lon = base_lon + (j - grid_size//2) * step_lon
|
| 279 |
+
raw_coord = np.array([offset_lon, offset_lat], dtype=np.float32)
|
| 280 |
+
norm_coord = (raw_coord - self.coord_min) / (self.coord_max - self.coord_min + 1e-8)
|
| 281 |
+
coords.append(torch.tensor(norm_coord, dtype=torch.float32))
|
| 282 |
+
|
| 283 |
+
batch_size = grid_size ** 2
|
| 284 |
+
# Assemble GPU Batch
|
| 285 |
+
batch_gpu = {
|
| 286 |
+
'satellite_img': torch.stack(patches).to(self.device),
|
| 287 |
+
'poi_dist': poi_tensor.repeat(batch_size, 1),
|
| 288 |
+
'coords': torch.stack(coords).to(self.device),
|
| 289 |
+
'traffic_seq': torch.zeros(batch_size, 672, dtype=torch.float32).to(self.device)
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
# Batch Inference (Computes 9 regions simultaneously)
|
| 293 |
+
output = self.model(batch_gpu, mode='generate', loss_cfg={'n_steps_generate': N_STEPS})
|
| 294 |
+
outputs = output['generated'].cpu().numpy() # [9, 672]
|
| 295 |
+
|
| 296 |
+
## LSI Calculation: 1 / (std + epsilon)
|
| 297 |
+
# Higher LSI means lower variance (more stable traffic)
|
| 298 |
+
stds = outputs.std(axis=1)
|
| 299 |
+
lsis = 1.0 / (stds + 1e-6)
|
| 300 |
+
lsi_grid = lsis.reshape(grid_size, grid_size)
|
| 301 |
+
|
| 302 |
+
# Find the most stable coordinate
|
| 303 |
+
best_idx = np.unravel_index(np.argmax(lsi_grid), lsi_grid.shape)
|
| 304 |
+
best_traffic = outputs[best_idx[0] * grid_size + best_idx[1]]
|
| 305 |
+
|
| 306 |
+
return lsi_grid, best_idx, best_traffic
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def create_site_map_base64(self, img_pil, lsi_grid, best_idx):
|
| 310 |
+
"""Generates heatmap overlay visualization and encodes to Base64"""
|
| 311 |
+
img_arr = np.array(img_pil)
|
| 312 |
+
h, w, _ = img_arr.shape
|
| 313 |
+
grid_h, grid_w = lsi_grid.shape
|
| 314 |
+
|
| 315 |
+
fig, ax = plt.subplots(figsize=(4, 4), dpi=120)
|
| 316 |
+
|
| 317 |
+
# Overlay heatmap on satellite image
|
| 318 |
+
ax.imshow(img_arr)
|
| 319 |
+
im = ax.imshow(lsi_grid, cmap='RdYlGn', alpha=0.45, extent=[0, w, h, 0], interpolation='bicubic')
|
| 320 |
+
|
| 321 |
+
cell_w, cell_h = w / grid_w, h / grid_h
|
| 322 |
+
best_row, best_col = best_idx
|
| 323 |
+
center_x = best_col * cell_w + cell_w / 2
|
| 324 |
+
center_y = best_row * cell_h + cell_h / 2
|
| 325 |
+
|
| 326 |
+
# Draw target star and LSI indicator
|
| 327 |
+
ax.plot(center_x, center_y, marker='*', color='red', markersize=20, markeredgecolor='white', markeredgewidth=1.5)
|
| 328 |
+
best_lsi = lsi_grid[best_idx]
|
| 329 |
+
ax.annotate(f"LSI: {best_lsi:.2f}", xy=(center_x, center_y), xytext=(10, 10),
|
| 330 |
+
textcoords='offset points', color='white', fontweight='bold',
|
| 331 |
+
bbox=dict(boxstyle="round,pad=0.3", facecolor="red", alpha=0.8, edgecolor="white"))
|
| 332 |
+
|
| 333 |
+
ax.axis('off')
|
| 334 |
+
plt.tight_layout()
|
| 335 |
+
|
| 336 |
+
# Convert to Base64 for API transmission
|
| 337 |
+
buf = BytesIO()
|
| 338 |
+
fig.savefig(buf, format="png", bbox_inches='tight', transparent=True)
|
| 339 |
+
plt.close(fig)
|
| 340 |
+
return base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 341 |
+
|
| 342 |
+
if __name__ == "__main__":
|
| 343 |
+
MODEL_PATH = "best_corr_model.pt"
|
| 344 |
+
SPATIAL_PATH = "data/spatial_features.npz"
|
| 345 |
+
TRAFFIC_PATH = "data/bs_record_energy_normalized_sampled.npz"
|
| 346 |
+
|
| 347 |
+
predictor = TrafficPredictor(
|
| 348 |
+
model_path=MODEL_PATH,
|
| 349 |
+
spatial_path=SPATIAL_PATH,
|
| 350 |
+
traffic_path=TRAFFIC_PATH
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
test_id = 277
|
| 354 |
+
result = predictor.predict(test_id, use_local_img_for_debug=False)
|
| 355 |
+
|
| 356 |
+
if result.get("status") == "success":
|
| 357 |
+
print(f"Prediction successful for Station {test_id}!")
|
| 358 |
+
else:
|
| 359 |
+
print(f"Prediction failed: {result.get('error')}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Flask==3.0.0
|
| 2 |
+
flask-cors==4.0.0
|
| 3 |
+
numpy==1.26.4
|
| 4 |
+
Pillow==10.2.0
|
| 5 |
+
requests==2.31.0
|
| 6 |
+
matplotlib==3.8.2
|
| 7 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
| 8 |
+
torch==2.1.2+cpu
|
script.js
ADDED
|
@@ -0,0 +1,954 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// ==========================================
|
| 2 |
+
// 1. Config (Focused on Shanghai)
|
| 3 |
+
// ==========================================
|
| 4 |
+
const CONFIG = {
|
| 5 |
+
MAPBOX_TOKEN: 'pk.eyJ1IjoieXlhaXl5IiwiYSI6ImNtaTVpMTVlaTJmdzMybW9zcmFieGxpdHUifQ.181d6E5fzLw1CEZMEPU53Q',
|
| 6 |
+
// API_BASE: 'http://127.0.0.1:5000/api', // Local
|
| 7 |
+
API_BASE: '/api', // Online
|
| 8 |
+
|
| 9 |
+
// Shanghai City Center
|
| 10 |
+
DEFAULT_CENTER: [121.4737, 31.2304],
|
| 11 |
+
DEFAULT_ZOOM: 10.5,
|
| 12 |
+
|
| 13 |
+
// Shanghai Coordinate Bounds [Southwest, Northeast]
|
| 14 |
+
SHANGHAI_BOUNDS: [
|
| 15 |
+
[120.80, 30.60], // Southwest
|
| 16 |
+
[122.50, 31.90] // Northeast
|
| 17 |
+
]
|
| 18 |
+
};
|
| 19 |
+
|
| 20 |
+
// ==========================================
|
| 21 |
+
// 2. Globals
|
| 22 |
+
// ==========================================
|
| 23 |
+
let chartInstance = null;
|
| 24 |
+
let predictionChartInstance = null;
|
| 25 |
+
let currentMarker = null;
|
| 26 |
+
let mapInstance = null;
|
| 27 |
+
let globalStationData = [];
|
| 28 |
+
let animationFrameId = null;
|
| 29 |
+
let isPredictionMode = false;
|
| 30 |
+
let predictionMarker = null;
|
| 31 |
+
let optimalMarker = null;
|
| 32 |
+
|
| 33 |
+
// ==========================================
|
| 34 |
+
// 3. API Logic
|
| 35 |
+
// ==========================================
|
| 36 |
+
async function fetchLocations() {
|
| 37 |
+
console.log("Requesting backend data...");
|
| 38 |
+
const res = await fetch(`${CONFIG.API_BASE}/stations/locations`);
|
| 39 |
+
if (!res.ok) throw new Error(`API Error: ${res.status}`);
|
| 40 |
+
return await res.json();
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
async function fetchStationDetail(id) {
|
| 44 |
+
try {
|
| 45 |
+
const res = await fetch(`${CONFIG.API_BASE}/stations/detail/${id}`);
|
| 46 |
+
return await res.json();
|
| 47 |
+
} catch (e) {
|
| 48 |
+
console.error("Fetch Detail Error:", e);
|
| 49 |
+
return null;
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
// Fetch AI Prediction Data
|
| 54 |
+
async function fetchPrediction(id) {
|
| 55 |
+
try {
|
| 56 |
+
const res = await fetch(`${CONFIG.API_BASE}/predict/${id}?t=${Date.now()}`);
|
| 57 |
+
const data = await res.json();
|
| 58 |
+
if (data.error) throw new Error(data.error);
|
| 59 |
+
return data;
|
| 60 |
+
} catch (e) {
|
| 61 |
+
console.error("Prediction API Error:", e);
|
| 62 |
+
alert("Prediction failed: " + e.message);
|
| 63 |
+
return null;
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
function loadSatellitePatch(lng, lat) {
|
| 68 |
+
// Logic for loading static satellite imagery patch
|
| 69 |
+
const img = document.getElementById('satellite-patch');
|
| 70 |
+
const placeholder = document.getElementById('sat-placeholder');
|
| 71 |
+
if(!img) return;
|
| 72 |
+
|
| 73 |
+
img.style.display = 'none';
|
| 74 |
+
placeholder.style.display = 'flex';
|
| 75 |
+
placeholder.innerHTML = '<p>Loading...</p>';
|
| 76 |
+
|
| 77 |
+
img.src = `https://api.mapbox.com/styles/v1/mapbox/satellite-v9/static/${lng},${lat},16,0,0/320x200?access_token=${CONFIG.MAPBOX_TOKEN}`;
|
| 78 |
+
img.onload = () => { img.style.display = 'block'; placeholder.style.display = 'none'; };
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
// ==========================================
|
| 82 |
+
// 4. Chart Logic (Normal & Prediction)
|
| 83 |
+
// ==========================================
|
| 84 |
+
function renderChart(recordData) {
|
| 85 |
+
const ctx = document.getElementById('energyChart').getContext('2d');
|
| 86 |
+
if (chartInstance) chartInstance.destroy();
|
| 87 |
+
|
| 88 |
+
chartInstance = new Chart(ctx, {
|
| 89 |
+
type: 'line',
|
| 90 |
+
data: {
|
| 91 |
+
labels: recordData.map((_, i) => i),
|
| 92 |
+
datasets: [
|
| 93 |
+
{
|
| 94 |
+
label: 'Traffic', data: recordData,
|
| 95 |
+
borderColor: '#00cec9', backgroundColor: 'rgba(0, 206, 201, 0.1)',
|
| 96 |
+
borderWidth: 1.5, fill: true, pointRadius: 0, tension: 0.3
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
label: 'Current', data: [], type: 'scatter',
|
| 100 |
+
pointRadius: 6, pointBackgroundColor: '#ffffff',
|
| 101 |
+
pointBorderColor: '#e84393', pointBorderWidth: 3
|
| 102 |
+
}
|
| 103 |
+
]
|
| 104 |
+
},
|
| 105 |
+
options: {
|
| 106 |
+
responsive: true, maintainAspectRatio: false, animation: false,
|
| 107 |
+
plugins: { legend: { display: false } },
|
| 108 |
+
scales: { x: { display: false }, y: { grid: { color: 'rgba(255,255,255,0.05)' }, ticks: { color: '#64748b', font: {size: 10} } } }
|
| 109 |
+
}
|
| 110 |
+
});
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
function updateChartCursor(timeIndex) {
|
| 114 |
+
if (chartInstance && chartInstance.data.datasets[0].data.length > timeIndex) {
|
| 115 |
+
const yValue = chartInstance.data.datasets[0].data[timeIndex];
|
| 116 |
+
chartInstance.data.datasets[1].data = [{x: timeIndex, y: yValue}];
|
| 117 |
+
chartInstance.update('none');
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
// Render AI Prediction Comparison Chart
|
| 122 |
+
function renderPredictionChart(realData, predData) {
|
| 123 |
+
const canvas = document.getElementById('predictionChart');
|
| 124 |
+
if (!canvas) return;
|
| 125 |
+
const ctx = canvas.getContext('2d');
|
| 126 |
+
|
| 127 |
+
if (predictionChartInstance) {
|
| 128 |
+
predictionChartInstance.destroy();
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
// Generate X-axis labels (e.g., H0, H1...)
|
| 132 |
+
const labels = realData.map((_, i) => `H${i}`);
|
| 133 |
+
|
| 134 |
+
predictionChartInstance = new Chart(ctx, {
|
| 135 |
+
type: 'line',
|
| 136 |
+
data: {
|
| 137 |
+
labels: labels,
|
| 138 |
+
datasets: [
|
| 139 |
+
{
|
| 140 |
+
label: 'Real Traffic',
|
| 141 |
+
data: realData,
|
| 142 |
+
borderColor: 'rgba(0, 206, 201, 0.8)', // Cyan
|
| 143 |
+
backgroundColor: 'rgba(0, 206, 201, 0.1)',
|
| 144 |
+
borderWidth: 1.5,
|
| 145 |
+
pointRadius: 0,
|
| 146 |
+
fill: true,
|
| 147 |
+
tension: 0.3
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
label: 'AI Prediction',
|
| 151 |
+
data: predData,
|
| 152 |
+
borderColor: '#f39c12', // Orange
|
| 153 |
+
backgroundColor: 'transparent',
|
| 154 |
+
borderWidth: 2,
|
| 155 |
+
borderDash: [5, 5], // Dashed line effect
|
| 156 |
+
pointRadius: 0,
|
| 157 |
+
fill: false,
|
| 158 |
+
tension: 0.3
|
| 159 |
+
}
|
| 160 |
+
]
|
| 161 |
+
},
|
| 162 |
+
options: {
|
| 163 |
+
responsive: true,
|
| 164 |
+
maintainAspectRatio: false,
|
| 165 |
+
interaction: {
|
| 166 |
+
mode: 'index',
|
| 167 |
+
intersect: false, // Tooltip shows both values simultaneously
|
| 168 |
+
},
|
| 169 |
+
plugins: {
|
| 170 |
+
legend: {
|
| 171 |
+
display: true,
|
| 172 |
+
labels: { color: '#e0e0e0', font: { size: 10 } }
|
| 173 |
+
}
|
| 174 |
+
},
|
| 175 |
+
scales: {
|
| 176 |
+
x: {
|
| 177 |
+
display: true,
|
| 178 |
+
grid: { color: 'rgba(255,255,255,0.05)' },
|
| 179 |
+
ticks: { color: '#64748b', font: {size: 9}, maxTicksLimit: 14 }
|
| 180 |
+
},
|
| 181 |
+
y: {
|
| 182 |
+
grid: { color: 'rgba(255,255,255,0.1)' },
|
| 183 |
+
ticks: { color: '#888', font: {size: 10} },
|
| 184 |
+
beginAtZero: true
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
});
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
// ==========================================
|
| 192 |
+
// 5. Map Manager
|
| 193 |
+
// ==========================================
|
| 194 |
+
function initMap() {
|
| 195 |
+
mapboxgl.accessToken = CONFIG.MAPBOX_TOKEN;
|
| 196 |
+
mapInstance = new mapboxgl.Map({
|
| 197 |
+
container: 'map',
|
| 198 |
+
style: 'mapbox://styles/mapbox/satellite-streets-v12',
|
| 199 |
+
center: CONFIG.DEFAULT_CENTER,
|
| 200 |
+
zoom: CONFIG.DEFAULT_ZOOM,
|
| 201 |
+
pitch: 60,
|
| 202 |
+
bearing: -15,
|
| 203 |
+
antialias: true,
|
| 204 |
+
maxBounds: CONFIG.SHANGHAI_BOUNDS,
|
| 205 |
+
minZoom: 9
|
| 206 |
+
});
|
| 207 |
+
mapInstance.addControl(new mapboxgl.NavigationControl(), 'top-right');
|
| 208 |
+
return mapInstance;
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
function setupMapEnvironment(map) {
|
| 212 |
+
map.addSource('mapbox-dem', {
|
| 213 |
+
'type': 'raster-dem',
|
| 214 |
+
'url': 'mapbox://mapbox.mapbox-terrain-dem-v1',
|
| 215 |
+
'tileSize': 512,
|
| 216 |
+
'maxzoom': 14 });
|
| 217 |
+
|
| 218 |
+
map.setTerrain({ 'source': 'mapbox-dem',
|
| 219 |
+
'exaggeration': 1.5 });
|
| 220 |
+
|
| 221 |
+
map.addLayer({
|
| 222 |
+
'id': 'sky',
|
| 223 |
+
'type': 'sky',
|
| 224 |
+
'paint': { 'sky-type': 'atmosphere', 'sky-atmosphere-sun': [0.0, 0.0], 'sky-atmosphere-sun-intensity': 15 }
|
| 225 |
+
});
|
| 226 |
+
|
| 227 |
+
if (map.setFog) {
|
| 228 |
+
map.setFog({ 'range': [0.5, 10],
|
| 229 |
+
'color': '#240b36',
|
| 230 |
+
'horizon-blend': 0.1,
|
| 231 |
+
'high-color': '#0f172a',
|
| 232 |
+
'space-color': '#000000',
|
| 233 |
+
'star-intensity': 0.6 });
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
const labelLayerId = map.getStyle().layers.find(l => l.type === 'symbol' && l.layout['text-field']).id;
|
| 237 |
+
if (!map.getLayer('3d-buildings')) {
|
| 238 |
+
map.addLayer({
|
| 239 |
+
'id': '3d-buildings', 'source': 'composite',
|
| 240 |
+
'source-layer': 'building', 'filter': ['==', 'extrude', 'true'],
|
| 241 |
+
'type': 'fill-extrusion', 'minzoom': 11,
|
| 242 |
+
'paint': {
|
| 243 |
+
'fill-extrusion-color': ['interpolate', ['linear'], ['get', 'height'], 0, '#0f0c29', 30, '#1e2a4a', 200, '#4b6cb7'],
|
| 244 |
+
'fill-extrusion-height': ['get', 'height'], 'fill-extrusion-base': ['get', 'min_height'], 'fill-extrusion-opacity': 0.6
|
| 245 |
+
}
|
| 246 |
+
}, labelLayerId);
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
function updateGeoJSONData(map, stations, mode = 'avg', timeIndex = 0) {
|
| 251 |
+
const pointFeatures = [];
|
| 252 |
+
const polygonFeatures = [];
|
| 253 |
+
const r = 0.00025; // Marker radius
|
| 254 |
+
|
| 255 |
+
stations.forEach(s => {
|
| 256 |
+
const lng = s.loc[0], lat = s.loc[1];
|
| 257 |
+
let valH = (mode === 'avg') ? (s.val_h || 0) : ((s.vals && s.vals[timeIndex]) !== undefined ? s.vals[timeIndex] : 0);
|
| 258 |
+
let valC = (s.val_c !== undefined) ? s.val_c : 0;
|
| 259 |
+
|
| 260 |
+
const props = { id: s.id, load_avg: valH, load_std: valC };
|
| 261 |
+
|
| 262 |
+
pointFeatures.push({ type: 'Feature', geometry: {
|
| 263 |
+
type: 'Point', coordinates: [lng, lat] }, properties: props });
|
| 264 |
+
polygonFeatures.push({ type: 'Feature', geometry: {
|
| 265 |
+
type: 'Polygon', coordinates: [[ [lng-r, lat-r], [lng+r, lat-r], [lng+r, lat+r], [lng-r, lat+r], [lng-r, lat-r] ]] }, properties: props });
|
| 266 |
+
});
|
| 267 |
+
|
| 268 |
+
if (map.getSource('stations-points')) {
|
| 269 |
+
map.getSource('stations-points').setData({
|
| 270 |
+
type: 'FeatureCollection',
|
| 271 |
+
features: pointFeatures });
|
| 272 |
+
|
| 273 |
+
map.getSource('stations-polygons').setData({
|
| 274 |
+
type: 'FeatureCollection',
|
| 275 |
+
features: polygonFeatures });
|
| 276 |
+
}
|
| 277 |
+
return { points: { type: 'FeatureCollection', features: pointFeatures }, polys: { type: 'FeatureCollection', features: polygonFeatures } };
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
function addStationLayers(map, geoData, statsLoad, statsColor) {
|
| 281 |
+
map.addSource('stations-points', { type: 'geojson', data: geoData.points });
|
| 282 |
+
map.addSource('stations-polygons', { type: 'geojson', data: geoData.polys });
|
| 283 |
+
|
| 284 |
+
map.addLayer({
|
| 285 |
+
id: 'stations-heatmap', type: 'heatmap', source: 'stations-points', maxzoom: 14,
|
| 286 |
+
paint: {
|
| 287 |
+
'heatmap-weight': ['interpolate', ['linear'], ['get', 'load_avg'], statsLoad.min, 0, statsLoad.max, 1],
|
| 288 |
+
'heatmap-intensity': ['interpolate', ['linear'], ['zoom'], 0, 1, 13, 3],
|
| 289 |
+
'heatmap-color': ['interpolate', ['linear'], ['heatmap-density'], 0, 'rgba(0,0,0,0)', 0.2, '#0984e3', 0.4, '#00cec9', 0.6, '#a29bfe', 0.8, '#fd79a8', 1, '#ffffff'],
|
| 290 |
+
'heatmap-radius': ['interpolate', ['linear'], ['zoom'], 0, 2, 13, 25],
|
| 291 |
+
'heatmap-opacity': ['interpolate', ['linear'], ['zoom'], 12, 1, 14, 0]
|
| 292 |
+
}
|
| 293 |
+
});
|
| 294 |
+
|
| 295 |
+
map.addLayer({
|
| 296 |
+
id: 'stations-2d-dots', type: 'circle', source: 'stations-points', minzoom: 12,
|
| 297 |
+
paint: {
|
| 298 |
+
'circle-radius': 3,
|
| 299 |
+
'circle-color': ['step', ['get', 'load_std'], '#1e1e2e', statsColor.t1, '#0984e3', statsColor.t2, '#00cec9', statsColor.t3, '#fd79a8', statsColor.t4, '#e84393'],
|
| 300 |
+
'circle-stroke-width': 1, 'circle-stroke-color': '#fff', 'circle-opacity': 0.8
|
| 301 |
+
}
|
| 302 |
+
});
|
| 303 |
+
|
| 304 |
+
map.addLayer({
|
| 305 |
+
id: 'stations-3d-pillars', type: 'fill-extrusion', source: 'stations-polygons', minzoom: 12,
|
| 306 |
+
paint: {
|
| 307 |
+
'fill-extrusion-color': ['step', ['get', 'load_std'], '#1e1e2e', statsColor.t1, '#0984e3', statsColor.t2, '#00cec9', statsColor.t3, '#fd79a8', statsColor.t4, '#e84393'],
|
| 308 |
+
'fill-extrusion-height': ['interpolate', ['linear'], ['get', 'load_avg'], 0, 0, statsLoad.min, 5, statsLoad.max, 300],
|
| 309 |
+
'fill-extrusion-opacity': 0.7
|
| 310 |
+
}
|
| 311 |
+
});
|
| 312 |
+
|
| 313 |
+
map.addLayer({ id: 'stations-hitbox', type: 'circle', source: 'stations-points',
|
| 314 |
+
paint: { 'circle-radius': 10, 'circle-color': 'transparent', 'circle-opacity': 0 } });
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
// ==========================================
|
| 318 |
+
// 6. Map Interactions
|
| 319 |
+
// ==========================================
|
| 320 |
+
function setupInteraction(map) {
|
| 321 |
+
const popup = new mapboxgl.Popup({ closeButton: false, closeOnClick: false, className: 'cyber-popup' });
|
| 322 |
+
|
| 323 |
+
map.on('mouseenter', 'stations-hitbox', (e) => {
|
| 324 |
+
map.getCanvas().style.cursor = 'pointer';
|
| 325 |
+
if (isPredictionMode) return;
|
| 326 |
+
|
| 327 |
+
const props = e.features[0].properties;
|
| 328 |
+
const coordinates = e.features[0].geometry.coordinates.slice();
|
| 329 |
+
|
| 330 |
+
while (Math.abs(e.lngLat.lng - coordinates[0]) > 180) { coordinates[0] += e.lngLat.lng > coordinates[0] ? 360 : -360; }
|
| 331 |
+
|
| 332 |
+
popup.setLngLat(coordinates)
|
| 333 |
+
.setHTML(`
|
| 334 |
+
<div style="font-weight:bold; color:#fff; border-bottom:1px solid #444; padding-bottom:2px; margin-bottom:2px;">Station ${props.id}</div>
|
| 335 |
+
<div style="color:#00cec9;">Load: <span style="color:#fff;">${props.load_avg.toFixed(2)}</span></div>
|
| 336 |
+
<div style="color:#fd79a8;">Stability: <span style="color:#fff;">${props.load_std.toFixed(4)}</span></div>
|
| 337 |
+
`).addTo(map);
|
| 338 |
+
});
|
| 339 |
+
|
| 340 |
+
map.on('mouseleave', 'stations-hitbox', () => {
|
| 341 |
+
if (!isPredictionMode) map.getCanvas().style.cursor = '';
|
| 342 |
+
popup.remove();
|
| 343 |
+
});
|
| 344 |
+
|
| 345 |
+
// Core Interaction Logic
|
| 346 |
+
map.on('click', 'stations-hitbox', async (e) => {
|
| 347 |
+
const coordinates = e.features[0].geometry.coordinates.slice();
|
| 348 |
+
const id = e.features[0].properties.id;
|
| 349 |
+
|
| 350 |
+
// 1. Prediction Mode Logic
|
| 351 |
+
if (isPredictionMode) {
|
| 352 |
+
const predPanel = document.getElementById('prediction-panel');
|
| 353 |
+
const predIdDisplay = document.getElementById('pred-station-id');
|
| 354 |
+
const siteMapContainer = document.getElementById('site-map-container');
|
| 355 |
+
const siteMapImg = document.getElementById('site-map-img');
|
| 356 |
+
|
| 357 |
+
predPanel.classList.add('active');
|
| 358 |
+
|
| 359 |
+
const rightBtn = document.getElementById('toggle-right-btn');
|
| 360 |
+
if (rightBtn) rightBtn.classList.add('active');
|
| 361 |
+
|
| 362 |
+
predIdDisplay.innerText = `${id} (Calculating...)`;
|
| 363 |
+
|
| 364 |
+
// Clear previous optimal site marker when a new station is clicked
|
| 365 |
+
if (optimalMarker) {
|
| 366 |
+
optimalMarker.remove();
|
| 367 |
+
optimalMarker = null;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
// Drop orange selection pin and draw 3x3 grid
|
| 371 |
+
if (!predictionMarker) {
|
| 372 |
+
predictionMarker = new mapboxgl.Marker({ color: '#f39c12' })
|
| 373 |
+
.setLngLat(coordinates).addTo(map);
|
| 374 |
+
} else {
|
| 375 |
+
predictionMarker.setLngLat(coordinates);
|
| 376 |
+
}
|
| 377 |
+
updatePredictionGrid(map, coordinates[0], coordinates[1]);
|
| 378 |
+
|
| 379 |
+
if (siteMapContainer) siteMapContainer.style.display = 'none';
|
| 380 |
+
if (siteMapImg) siteMapImg.src = '';
|
| 381 |
+
|
| 382 |
+
if(predictionChartInstance) {
|
| 383 |
+
predictionChartInstance.destroy();
|
| 384 |
+
predictionChartInstance = null;
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
// Call Prediction API
|
| 388 |
+
const result = await fetchPrediction(id);
|
| 389 |
+
if(result && result.status === "success") {
|
| 390 |
+
predIdDisplay.innerText = id;
|
| 391 |
+
renderPredictionChart(result.real, result.prediction);
|
| 392 |
+
|
| 393 |
+
// Render returned Base64 site heatmap and mark optimal location
|
| 394 |
+
if (result.site_map_b64 && siteMapContainer && siteMapImg) {
|
| 395 |
+
siteMapImg.src = `data:image/png;base64,${result.site_map_b64}`;
|
| 396 |
+
siteMapContainer.style.display = 'block';
|
| 397 |
+
|
| 398 |
+
// Typewriter Effect for AI Explanation
|
| 399 |
+
const explanationBox = document.getElementById('site-explanation');
|
| 400 |
+
if (explanationBox && result.explanation) {
|
| 401 |
+
explanationBox.style.display = 'block';
|
| 402 |
+
|
| 403 |
+
// Reset content and add blinking cursor
|
| 404 |
+
explanationBox.innerHTML = `<strong>> SYSTEM LOG: AI DECISION</strong><br><span id="typewriter-text"></span><span class="cursor" style="animation: blink 1s step-end infinite;">_</span>`;
|
| 405 |
+
|
| 406 |
+
const textTarget = document.getElementById('typewriter-text');
|
| 407 |
+
const fullText = result.explanation;
|
| 408 |
+
let charIndex = 0;
|
| 409 |
+
|
| 410 |
+
function typeWriter() {
|
| 411 |
+
if (charIndex < fullText.length) {
|
| 412 |
+
textTarget.innerHTML += fullText.charAt(charIndex);
|
| 413 |
+
charIndex++;
|
| 414 |
+
// Randomize typing speed for realistic terminal feel
|
| 415 |
+
setTimeout(typeWriter, Math.random() * 20 + 10);
|
| 416 |
+
}
|
| 417 |
+
}
|
| 418 |
+
typeWriter();
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
// Mark green optimal Pin on physical map coordinates
|
| 422 |
+
if (result.best_loc) {
|
| 423 |
+
|
| 424 |
+
// Remove orange marker to avoid overlap
|
| 425 |
+
if (predictionMarker) {
|
| 426 |
+
predictionMarker.remove();
|
| 427 |
+
predictionMarker = null;
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
// Create custom "Green Pulse" DOM element defined in CSS
|
| 431 |
+
const customPin = document.createElement('div');
|
| 432 |
+
customPin.className = 'optimal-pulse-pin';
|
| 433 |
+
|
| 434 |
+
optimalMarker = new mapboxgl.Marker(customPin)
|
| 435 |
+
.setLngLat(result.best_loc)
|
| 436 |
+
.setPopup(new mapboxgl.Popup({ offset: 25, closeButton: false, className: 'cyber-popup' })
|
| 437 |
+
.setHTML('<div style="color:#2ecc71; font-weight:bold; font-size:14px;">🌟 Best LSI Site</div>'))
|
| 438 |
+
.addTo(map);
|
| 439 |
+
|
| 440 |
+
optimalMarker.togglePopup();
|
| 441 |
+
|
| 442 |
+
// Smoothly fly to the optimal site location
|
| 443 |
+
map.flyTo({
|
| 444 |
+
center: result.best_loc,
|
| 445 |
+
zoom: 16.5,
|
| 446 |
+
speed: 1.2
|
| 447 |
+
});
|
| 448 |
+
}
|
| 449 |
+
}
|
| 450 |
+
} else {
|
| 451 |
+
predIdDisplay.innerText = `${id} (Failed)`;
|
| 452 |
+
}
|
| 453 |
+
return;
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
// 2. Standard Detail Mode Logic
|
| 457 |
+
if (currentMarker) currentMarker.remove();
|
| 458 |
+
currentMarker = new mapboxgl.Marker().setLngLat(coordinates).addTo(map);
|
| 459 |
+
|
| 460 |
+
const pitch = map.getPitch();
|
| 461 |
+
map.flyTo({ center: coordinates, zoom: 15, pitch: pitch > 10 ? 60 : 0, speed: 1.5 });
|
| 462 |
+
|
| 463 |
+
document.getElementById('selected-id').innerText = id;
|
| 464 |
+
|
| 465 |
+
try {
|
| 466 |
+
document.getElementById('station-details').innerHTML = '<p class="placeholder-text">Loading details...</p>';
|
| 467 |
+
|
| 468 |
+
const detailData = await fetchStationDetail(id);
|
| 469 |
+
if (detailData) {
|
| 470 |
+
const stats = detailData.stats || {avg:0, std:0};
|
| 471 |
+
|
| 472 |
+
document.getElementById('station-details').innerHTML =
|
| 473 |
+
`<div style="margin-top:10px;">
|
| 474 |
+
<p><strong>Longitude:</strong> ${detailData.loc[0].toFixed(4)}</p>
|
| 475 |
+
<p><strong>Latitude:</strong> ${detailData.loc[1].toFixed(4)}</p>
|
| 476 |
+
<hr style="border:0; border-top:1px solid #444; margin:5px 0;">
|
| 477 |
+
<p><strong>Avg Load:</strong> <span style="color:#00cec9">${stats.avg.toFixed(4)}</span></p>
|
| 478 |
+
<p><strong>Stability:</strong> <span style="color:#fd79a8">${stats.std.toFixed(4)}</span></p>
|
| 479 |
+
</div>`;
|
| 480 |
+
|
| 481 |
+
if (detailData.bs_record) {
|
| 482 |
+
renderChart(detailData.bs_record);
|
| 483 |
+
}
|
| 484 |
+
}
|
| 485 |
+
} catch (err) {
|
| 486 |
+
console.error("Failed to fetch clicked station details:", err);
|
| 487 |
+
document.getElementById('station-details').innerHTML = '<p style="color:red">Error loading data</p>';
|
| 488 |
+
}
|
| 489 |
+
});
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
// Prediction Mode State Control
|
| 493 |
+
function setupPredictionMode(map) {
|
| 494 |
+
const predictBtn = document.getElementById('predict-toggle');
|
| 495 |
+
const predPanel = document.getElementById('prediction-panel');
|
| 496 |
+
const closePredBtn = document.getElementById('close-pred-btn');
|
| 497 |
+
|
| 498 |
+
if (!predictBtn) return;
|
| 499 |
+
|
| 500 |
+
predictBtn.addEventListener('click', () => {
|
| 501 |
+
// Enforce 2D view check for prediction mode
|
| 502 |
+
const pitch = map.getPitch();
|
| 503 |
+
if (pitch > 10) {
|
| 504 |
+
alert("Prediction Mode is only available in 2D View. Please switch to 2D first.");
|
| 505 |
+
return;
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
isPredictionMode = !isPredictionMode;
|
| 509 |
+
|
| 510 |
+
if (isPredictionMode) {
|
| 511 |
+
predictBtn.classList.add('predict-on');
|
| 512 |
+
predictBtn.innerHTML = '<span class="icon">🔮</span> Mode: ON';
|
| 513 |
+
map.getCanvas().style.cursor = 'crosshair';
|
| 514 |
+
} else {
|
| 515 |
+
predictBtn.classList.remove('predict-on');
|
| 516 |
+
predictBtn.innerHTML = '<span class="icon">🔮</span> Prediction Mode';
|
| 517 |
+
map.getCanvas().style.cursor = '';
|
| 518 |
+
predPanel.classList.remove('active');
|
| 519 |
+
|
| 520 |
+
// Reset UI state when exiting prediction
|
| 521 |
+
predPanel.classList.remove('collapsed');
|
| 522 |
+
const rightBtn = document.getElementById('toggle-right-btn');
|
| 523 |
+
if(rightBtn) {
|
| 524 |
+
rightBtn.innerText = '▶';
|
| 525 |
+
rightBtn.classList.remove('active');
|
| 526 |
+
rightBtn.classList.remove('collapsed');
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
// Clear markers and grids
|
| 530 |
+
clearPredictionExtras(map);
|
| 531 |
+
}
|
| 532 |
+
});
|
| 533 |
+
|
| 534 |
+
if (closePredBtn) {
|
| 535 |
+
closePredBtn.addEventListener('click', () => {
|
| 536 |
+
predPanel.classList.remove('active');
|
| 537 |
+
const rightBtn = document.getElementById('toggle-right-btn');
|
| 538 |
+
if (rightBtn) rightBtn.classList.remove('active');
|
| 539 |
+
predictBtn.click(); // Trigger toggle to clean up state
|
| 540 |
+
});
|
| 541 |
+
}
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
// === 新增:绘制 AI 模型的 3x3 空间感知网格 ===
|
| 545 |
+
// function updatePredictionGrid(map, centerLng, centerLat) {
|
| 546 |
+
// const features = [];
|
| 547 |
+
// const step = 0.002; // 与后端 Python 的 offset = 0.002 对齐
|
| 548 |
+
// const gridSize = 3;
|
| 549 |
+
// const offset = Math.floor(gridSize / 2);
|
| 550 |
+
|
| 551 |
+
// for (let i = 0; i < gridSize; i++) {
|
| 552 |
+
// for (let j = 0; j < gridSize; j++) {
|
| 553 |
+
// // 精确还原 Python 中切片的中心坐标
|
| 554 |
+
// const cLng = centerLng + (j - offset) * step;
|
| 555 |
+
// const cLat = centerLat - (i - offset) * step;
|
| 556 |
+
// const w = step / 2;
|
| 557 |
+
|
| 558 |
+
// features.push({
|
| 559 |
+
// 'type': 'Feature',
|
| 560 |
+
// 'geometry': {
|
| 561 |
+
// 'type': 'Polygon',
|
| 562 |
+
// 'coordinates': [[
|
| 563 |
+
// [cLng - w, cLat - w], [cLng + w, cLat - w],
|
| 564 |
+
// [cLng + w, cLat + w], [cLng - w, cLat + w],
|
| 565 |
+
// [cLng - w, cLat - w]
|
| 566 |
+
// ]]
|
| 567 |
+
// }
|
| 568 |
+
// });
|
| 569 |
+
// }
|
| 570 |
+
// }
|
| 571 |
+
|
| 572 |
+
// const geojson = { 'type': 'FeatureCollection', 'features': features };
|
| 573 |
+
|
| 574 |
+
// if (map.getSource('pred-grid-source')) {
|
| 575 |
+
// map.getSource('pred-grid-source').setData(geojson);
|
| 576 |
+
// } else {
|
| 577 |
+
// map.addSource('pred-grid-source', { type: 'geojson', data: geojson });
|
| 578 |
+
// map.addLayer({
|
| 579 |
+
// 'id': 'pred-grid-fill', 'type': 'fill', 'source': 'pred-grid-source',
|
| 580 |
+
// 'paint': { 'fill-color': '#f39c12', 'fill-opacity': 0.1 }
|
| 581 |
+
// });
|
| 582 |
+
// map.addLayer({
|
| 583 |
+
// 'id': 'pred-grid-line', 'type': 'line', 'source': 'pred-grid-source',
|
| 584 |
+
// 'paint': { 'line-color': '#f39c12', 'line-width': 2, 'line-dasharray': [2, 2] }
|
| 585 |
+
// });
|
| 586 |
+
// }
|
| 587 |
+
// }
|
| 588 |
+
|
| 589 |
+
// Dynamic 3x3 grid matching the 256px satellite patch bounds
|
| 590 |
+
function updatePredictionGrid(map, centerLng, centerLat) {
|
| 591 |
+
const features = [];
|
| 592 |
+
const gridSize = 3;
|
| 593 |
+
const offset = Math.floor(gridSize / 2);
|
| 594 |
+
|
| 595 |
+
// Precise Web Mercator projection span calculation at Zoom 15
|
| 596 |
+
const zoom = 15;
|
| 597 |
+
|
| 598 |
+
// Total Longitude span for 256 pixels at this zoom
|
| 599 |
+
const lonSpan = 360 / Math.pow(2, zoom);
|
| 600 |
+
// Latitude span (scaled by local latitude)
|
| 601 |
+
const latSpan = lonSpan * Math.cos(centerLat * Math.PI / 180);
|
| 602 |
+
|
| 603 |
+
// Actual step sizes for 3x3 division
|
| 604 |
+
const stepLon = lonSpan / gridSize;
|
| 605 |
+
const stepLat = latSpan / gridSize;
|
| 606 |
+
|
| 607 |
+
for (let i = 0; i < gridSize; i++) {
|
| 608 |
+
for (let j = 0; j < gridSize; j++) {
|
| 609 |
+
// Center point of each micro-grid cell
|
| 610 |
+
const cLng = centerLng + (j - offset) * stepLon;
|
| 611 |
+
const cLat = centerLat - (i - offset) * stepLat;
|
| 612 |
+
|
| 613 |
+
const wLon = stepLon / 2;
|
| 614 |
+
const wLat = stepLat / 2;
|
| 615 |
+
|
| 616 |
+
features.push({
|
| 617 |
+
'type': 'Feature',
|
| 618 |
+
'geometry': {
|
| 619 |
+
'type': 'Polygon',
|
| 620 |
+
'coordinates': [[
|
| 621 |
+
[cLng - wLon, cLat - wLat], [cLng + wLon, cLat - wLat],
|
| 622 |
+
[cLng + wLon, cLat + wLat], [cLng - wLon, cLat + wLat],
|
| 623 |
+
[cLng - wLon, cLat - wLat]
|
| 624 |
+
]]
|
| 625 |
+
}
|
| 626 |
+
});
|
| 627 |
+
}
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
const geojson = { 'type': 'FeatureCollection', 'features': features };
|
| 631 |
+
|
| 632 |
+
if (map.getSource('pred-grid-source')) {
|
| 633 |
+
map.getSource('pred-grid-source').setData(geojson);
|
| 634 |
+
} else {
|
| 635 |
+
map.addSource('pred-grid-source', { type: 'geojson', data: geojson });
|
| 636 |
+
map.addLayer({
|
| 637 |
+
'id': 'pred-grid-fill', 'type': 'fill', 'source': 'pred-grid-source',
|
| 638 |
+
'paint': { 'fill-color': '#f39c12', 'fill-opacity': 0.1 }
|
| 639 |
+
});
|
| 640 |
+
map.addLayer({
|
| 641 |
+
'id': 'pred-grid-line', 'type': 'line', 'source': 'pred-grid-source',
|
| 642 |
+
'paint': { 'line-color': '#f39c12', 'line-width': 2, 'line-dasharray': [2, 2] }
|
| 643 |
+
});
|
| 644 |
+
}
|
| 645 |
+
}
|
| 646 |
+
|
| 647 |
+
// Cleanup prediction visual elements
|
| 648 |
+
function clearPredictionExtras(map) {
|
| 649 |
+
if (predictionMarker) { predictionMarker.remove(); predictionMarker = null; }
|
| 650 |
+
if (optimalMarker) { optimalMarker.remove(); optimalMarker = null; } // ====== 新增:清理绿色点 ======
|
| 651 |
+
if (map.getSource('pred-grid-source')) {
|
| 652 |
+
map.getSource('pred-grid-source').setData({ type: 'FeatureCollection', features: [] });
|
| 653 |
+
}
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
// ==========================================
|
| 657 |
+
// 7. Timeline Logic
|
| 658 |
+
// ==========================================
|
| 659 |
+
function setupTimeLapse(map, globalData) {
|
| 660 |
+
const playBtn = document.getElementById('play-btn');
|
| 661 |
+
const slider = document.getElementById('time-slider');
|
| 662 |
+
const display = document.getElementById('time-display');
|
| 663 |
+
if (!playBtn || !slider) return;
|
| 664 |
+
|
| 665 |
+
const totalHours = (globalData.length > 0 && globalData[0].vals) ? globalData[0].vals.length : 672;
|
| 666 |
+
slider.max = totalHours - 1;
|
| 667 |
+
let isPlaying = false;
|
| 668 |
+
let speed = 100;
|
| 669 |
+
|
| 670 |
+
const updateTime = (val) => {
|
| 671 |
+
const day = Math.floor(val / 24) + 1;
|
| 672 |
+
const hour = val % 24;
|
| 673 |
+
display.innerText = `Day ${day.toString().padStart(2, '0')} - ${hour.toString().padStart(2, '0')}:00`;
|
| 674 |
+
|
| 675 |
+
updateGeoJSONData(map, globalData, 'time', val);
|
| 676 |
+
updateChartCursor(val);
|
| 677 |
+
};
|
| 678 |
+
|
| 679 |
+
const play = () => {
|
| 680 |
+
let val = parseInt(slider.value);
|
| 681 |
+
val = (val + 1) % totalHours;
|
| 682 |
+
slider.value = val;
|
| 683 |
+
updateTime(val);
|
| 684 |
+
if (isPlaying) animationFrameId = setTimeout(() => requestAnimationFrame(play), speed);
|
| 685 |
+
};
|
| 686 |
+
|
| 687 |
+
playBtn.onclick = () => {
|
| 688 |
+
isPlaying = !isPlaying;
|
| 689 |
+
playBtn.innerText = isPlaying ? '⏸' : '▶';
|
| 690 |
+
if (isPlaying) play(); else clearTimeout(animationFrameId);
|
| 691 |
+
};
|
| 692 |
+
|
| 693 |
+
slider.oninput = (e) => {
|
| 694 |
+
isPlaying = false;
|
| 695 |
+
if(animationFrameId) clearTimeout(animationFrameId);
|
| 696 |
+
playBtn.innerText = '▶';
|
| 697 |
+
updateTime(parseInt(e.target.value));
|
| 698 |
+
};
|
| 699 |
+
}
|
| 700 |
+
|
| 701 |
+
// ==========================================
|
| 702 |
+
// 8. UI Controls
|
| 703 |
+
// ==========================================
|
| 704 |
+
function setupModeToggle(map) {
|
| 705 |
+
const btn = document.getElementById('view-toggle');
|
| 706 |
+
const timePanel = document.querySelector('.time-panel');
|
| 707 |
+
let is3D = true;
|
| 708 |
+
|
| 709 |
+
if (!btn) return;
|
| 710 |
+
|
| 711 |
+
btn.onclick = () => {
|
| 712 |
+
// Prevent switching to 3D mode if Prediction Mode is active
|
| 713 |
+
if (isPredictionMode) {
|
| 714 |
+
alert("Please exit Prediction Mode before switching to 3D.");
|
| 715 |
+
return;
|
| 716 |
+
}
|
| 717 |
+
|
| 718 |
+
is3D = !is3D;
|
| 719 |
+
if (is3D) {
|
| 720 |
+
// Switch to 3D View: Show pillars and tilt camera
|
| 721 |
+
if(map.getLayer('stations-3d-pillars')) map.setLayoutProperty('stations-3d-pillars', 'visibility', 'visible');
|
| 722 |
+
map.easeTo({ pitch: 60, bearing: -15 });
|
| 723 |
+
btn.innerHTML = '<span class="icon">👁️</span> View: 3D';
|
| 724 |
+
if (timePanel) {
|
| 725 |
+
timePanel.style.display = 'flex';
|
| 726 |
+
setTimeout(() => { timePanel.style.opacity = '1'; }, 10);
|
| 727 |
+
}
|
| 728 |
+
} else {
|
| 729 |
+
// Switch to 2D View: Hide pillars and reset camera pitch
|
| 730 |
+
if(map.getLayer('stations-3d-pillars')) map.setLayoutProperty('stations-3d-pillars', 'visibility', 'none');
|
| 731 |
+
map.easeTo({ pitch: 0, bearing: 0 });
|
| 732 |
+
btn.innerHTML = '<span class="icon">🗺️</span> View: 2D';
|
| 733 |
+
if (timePanel) {
|
| 734 |
+
timePanel.style.display = 'none';
|
| 735 |
+
timePanel.style.opacity = '0';
|
| 736 |
+
}
|
| 737 |
+
// Stop timelapse playback when entering 2D mode
|
| 738 |
+
const playBtn = document.getElementById('play-btn');
|
| 739 |
+
if (playBtn && playBtn.innerText === '⏸') playBtn.click();
|
| 740 |
+
}
|
| 741 |
+
};
|
| 742 |
+
}
|
| 743 |
+
|
| 744 |
+
function setupDataToggle(map) {
|
| 745 |
+
const btn = document.getElementById('data-toggle');
|
| 746 |
+
const layers = ['stations-3d-pillars', 'stations-2d-dots', 'stations-heatmap', 'stations-hitbox'];
|
| 747 |
+
let isVisible = true;
|
| 748 |
+
if(btn) btn.onclick = () => {
|
| 749 |
+
isVisible = !isVisible;
|
| 750 |
+
const val = isVisible ? 'visible' : 'none';
|
| 751 |
+
layers.forEach(id => { if(map.getLayer(id)) map.setLayoutProperty(id, 'visibility', val); });
|
| 752 |
+
btn.innerHTML = isVisible ? '<span class="icon">📡</span> Toggle Data' : '<span class="icon">🚫</span> Toggle Data';
|
| 753 |
+
btn.style.opacity = isVisible ? '1' : '0.6';
|
| 754 |
+
};
|
| 755 |
+
}
|
| 756 |
+
|
| 757 |
+
function setupFilterMenu(map, statsColor) {
|
| 758 |
+
const btn = document.getElementById('filter-btn');
|
| 759 |
+
const menu = document.getElementById('filter-menu');
|
| 760 |
+
if (!btn || !menu) return;
|
| 761 |
+
|
| 762 |
+
// Define stability levels based on Standard Deviation thresholds
|
| 763 |
+
const levels = [
|
| 764 |
+
{ label: "Level 5: Highly Unstable", color: "#e84393", filter: ['>=', 'load_std', statsColor.t4] },
|
| 765 |
+
{ label: "Level 4: Volatile", color: "#fd79a8", filter: ['all', ['>=', 'load_std', statsColor.t3], ['<', 'load_std', statsColor.t4]] },
|
| 766 |
+
{ label: "Level 3: Normal", color: "#00cec9", filter: ['all', ['>=', 'load_std', statsColor.t2], ['<', 'load_std', statsColor.t3]] },
|
| 767 |
+
{ label: "Level 2: Stable", color: "#0984e3", filter: ['all', ['>=', 'load_std', statsColor.t1], ['<', 'load_std', statsColor.t2]] },
|
| 768 |
+
{ label: "Level 1: Highly Stable", color: "#1e1e2e", filter: ['<', 'load_std', statsColor.t1] }
|
| 769 |
+
];
|
| 770 |
+
|
| 771 |
+
menu.innerHTML = '';
|
| 772 |
+
levels.forEach((lvl) => {
|
| 773 |
+
const item = document.createElement('div');
|
| 774 |
+
item.className = 'filter-item';
|
| 775 |
+
item.innerHTML = `<div class="color-box" style="background:${lvl.color}; box-shadow: 0 0 5px ${lvl.color};"></div><span>${lvl.label}</span>`;
|
| 776 |
+
item.onclick = (e) => {
|
| 777 |
+
e.stopPropagation();
|
| 778 |
+
if (item.classList.contains('selected')) {
|
| 779 |
+
item.classList.remove('selected');
|
| 780 |
+
applyFilter(map, null);
|
| 781 |
+
} else {
|
| 782 |
+
document.querySelectorAll('.filter-item').forEach(el => el.classList.remove('selected'));
|
| 783 |
+
item.classList.add('selected');
|
| 784 |
+
applyFilter(map, lvl.filter);
|
| 785 |
+
}
|
| 786 |
+
};
|
| 787 |
+
menu.appendChild(item);
|
| 788 |
+
});
|
| 789 |
+
|
| 790 |
+
// Toggle menu visibility
|
| 791 |
+
btn.onclick = (e) => { e.stopPropagation(); menu.classList.toggle('active'); };
|
| 792 |
+
document.addEventListener('click', (e) => { if (!menu.contains(e.target) && !btn.contains(e.target)) menu.classList.remove('active'); });
|
| 793 |
+
}
|
| 794 |
+
|
| 795 |
+
function applyFilter(map, filterExpression) {
|
| 796 |
+
const targetLayers = ['stations-3d-pillars', 'stations-2d-dots', 'stations-heatmap', 'stations-hitbox'];
|
| 797 |
+
targetLayers.forEach(layerId => { if (map.getLayer(layerId)) map.setFilter(layerId, filterExpression); });
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
function setupSearch(map, globalData) {
|
| 801 |
+
const input = document.getElementById('search-input');
|
| 802 |
+
const btn = document.getElementById('search-btn');
|
| 803 |
+
const clearBtn = document.getElementById('clear-search-btn');
|
| 804 |
+
const keepCheck = document.getElementById('keep-markers-check');
|
| 805 |
+
|
| 806 |
+
if (!input || !btn) return;
|
| 807 |
+
|
| 808 |
+
let searchMarkers = [];
|
| 809 |
+
|
| 810 |
+
const clearAllMarkers = () => {
|
| 811 |
+
searchMarkers.forEach(marker => marker.remove());
|
| 812 |
+
searchMarkers = [];
|
| 813 |
+
};
|
| 814 |
+
|
| 815 |
+
const performSearch = async () => {
|
| 816 |
+
const queryId = input.value.trim();
|
| 817 |
+
if (!queryId) return;
|
| 818 |
+
|
| 819 |
+
const target = globalData.find(s => String(s.id) === String(queryId));
|
| 820 |
+
|
| 821 |
+
if (target) {
|
| 822 |
+
if (!keepCheck.checked) {
|
| 823 |
+
clearAllMarkers();
|
| 824 |
+
}
|
| 825 |
+
|
| 826 |
+
// Fly to searched station and switch to high-detail view
|
| 827 |
+
map.flyTo({
|
| 828 |
+
center: target.loc,
|
| 829 |
+
zoom: 16,
|
| 830 |
+
pitch: 60,
|
| 831 |
+
essential: true
|
| 832 |
+
});
|
| 833 |
+
|
| 834 |
+
document.getElementById('selected-id').innerText = target.id;
|
| 835 |
+
try {
|
| 836 |
+
const detailData = await fetchStationDetail(target.id);
|
| 837 |
+
if (detailData) {
|
| 838 |
+
const stats = detailData.stats || {avg:0, std:0};
|
| 839 |
+
document.getElementById('station-details').innerHTML =
|
| 840 |
+
`<div style="margin-top:10px;">
|
| 841 |
+
<p><strong>Longitude:</strong> ${detailData.loc[0].toFixed(4)}</p>
|
| 842 |
+
<p><strong>Latitude:</strong> ${detailData.loc[1].toFixed(4)}</p>
|
| 843 |
+
<hr style="border:0; border-top:1px solid #444; margin:5px 0;">
|
| 844 |
+
<p><strong>Avg Load:</strong> <span style="color:#00cec9">${stats.avg.toFixed(4)}</span></p>
|
| 845 |
+
<p><strong>Stability:</strong> <span style="color:#fd79a8">${stats.std.toFixed(4)}</span></p>
|
| 846 |
+
</div>`;
|
| 847 |
+
|
| 848 |
+
if (detailData.bs_record) renderChart(detailData.bs_record);
|
| 849 |
+
}
|
| 850 |
+
} catch (e) {
|
| 851 |
+
console.error("Fetch details failed", e);
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
// Create red highlight marker for searched target
|
| 855 |
+
const marker = new mapboxgl.Marker({ color: '#ff0000', scale: 0.8 })
|
| 856 |
+
.setLngLat(target.loc)
|
| 857 |
+
.setPopup(new mapboxgl.Popup({ offset: 25 }).setText(`Station ID: ${target.id}`))
|
| 858 |
+
.addTo(map);
|
| 859 |
+
|
| 860 |
+
searchMarkers.push(marker);
|
| 861 |
+
|
| 862 |
+
} else {
|
| 863 |
+
alert("Station ID not found!");
|
| 864 |
+
}
|
| 865 |
+
};
|
| 866 |
+
|
| 867 |
+
btn.onclick = performSearch;
|
| 868 |
+
|
| 869 |
+
input.addEventListener('keypress', (e) => {
|
| 870 |
+
if (e.key === 'Enter') performSearch();
|
| 871 |
+
});
|
| 872 |
+
|
| 873 |
+
if (clearBtn) {
|
| 874 |
+
clearBtn.onclick = () => {
|
| 875 |
+
clearAllMarkers();
|
| 876 |
+
input.value = '';
|
| 877 |
+
};
|
| 878 |
+
}
|
| 879 |
+
}
|
| 880 |
+
|
| 881 |
+
// Sidebar & Panel Toggle Logic
|
| 882 |
+
function setupPanelToggles(map) {
|
| 883 |
+
const leftSidebar = document.querySelector('.sidebar');
|
| 884 |
+
const leftToggleBtn = document.getElementById('toggle-left-btn');
|
| 885 |
+
|
| 886 |
+
if (leftToggleBtn && leftSidebar) {
|
| 887 |
+
leftToggleBtn.addEventListener('click', () => {
|
| 888 |
+
leftSidebar.classList.toggle('collapsed');
|
| 889 |
+
leftToggleBtn.classList.toggle('collapsed');
|
| 890 |
+
leftToggleBtn.innerText = leftSidebar.classList.contains('collapsed') ? '▶' : '◀';
|
| 891 |
+
setTimeout(() => map.resize(), 300);
|
| 892 |
+
});
|
| 893 |
+
}
|
| 894 |
+
|
| 895 |
+
const rightSidebar = document.getElementById('prediction-panel');
|
| 896 |
+
const rightToggleBtn = document.getElementById('toggle-right-btn');
|
| 897 |
+
|
| 898 |
+
if (rightToggleBtn && rightSidebar) {
|
| 899 |
+
rightToggleBtn.addEventListener('click', () => {
|
| 900 |
+
rightSidebar.classList.toggle('collapsed');
|
| 901 |
+
rightToggleBtn.classList.toggle('collapsed');
|
| 902 |
+
rightToggleBtn.innerText = rightSidebar.classList.contains('collapsed') ? '◀' : '▶';
|
| 903 |
+
setTimeout(() => map.resize(), 300);
|
| 904 |
+
});
|
| 905 |
+
}
|
| 906 |
+
}
|
| 907 |
+
|
| 908 |
+
// ==========================================
|
| 909 |
+
// 9. Main Entry Point
|
| 910 |
+
// ==========================================
|
| 911 |
+
window.onload = async () => {
|
| 912 |
+
const map = initMap();
|
| 913 |
+
|
| 914 |
+
map.on('load', async () => {
|
| 915 |
+
setupMapEnvironment(map);
|
| 916 |
+
|
| 917 |
+
try {
|
| 918 |
+
// Load initial station metadata
|
| 919 |
+
const data = await fetchLocations();
|
| 920 |
+
globalStationData = data.stations;
|
| 921 |
+
document.getElementById('total-stations').innerText = globalStationData.length;
|
| 922 |
+
|
| 923 |
+
// Initialize Map Layers with empty data initially
|
| 924 |
+
addStationLayers(map,
|
| 925 |
+
{points: {type:'FeatureCollection', features:[]}, polys: {type:'FeatureCollection', features:[]} },
|
| 926 |
+
data.stats_height, data.stats_color);
|
| 927 |
+
|
| 928 |
+
// Immediately load data for T=0 (initial state)
|
| 929 |
+
updateGeoJSONData(map, globalStationData, 'time', 0);
|
| 930 |
+
updateChartCursor(0);
|
| 931 |
+
|
| 932 |
+
// Start Time Lapse
|
| 933 |
+
setupTimeLapse(map, globalStationData);
|
| 934 |
+
|
| 935 |
+
// Bind Interactions
|
| 936 |
+
setupPredictionMode(map); // Initialize AI Prediction events
|
| 937 |
+
setupInteraction(map); // Initialize standard map clicks/popups
|
| 938 |
+
setupModeToggle(map); // 2D/3D View switch
|
| 939 |
+
setupDataToggle(map); // Layer visibility switch
|
| 940 |
+
setupFilterMenu(map, data.stats_color); // Load-stability filters
|
| 941 |
+
setupSearch(map, globalStationData); // Search bar logic
|
| 942 |
+
|
| 943 |
+
// Initialize sidebar collapse/expand controls
|
| 944 |
+
setupPanelToggles(map);
|
| 945 |
+
|
| 946 |
+
// Remove Loading Screen
|
| 947 |
+
document.getElementById('loading').style.display = 'none';
|
| 948 |
+
} catch (e) {
|
| 949 |
+
console.error(e);
|
| 950 |
+
alert('System Initialization Failed. Check Console.');
|
| 951 |
+
document.getElementById('loading').innerHTML = '<h2>Error Loading Data</h2>';
|
| 952 |
+
}
|
| 953 |
+
});
|
| 954 |
+
};
|
server.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
import math
|
| 5 |
+
from flask import Flask, jsonify, send_from_directory, request
|
| 6 |
+
from flask_cors import CORS
|
| 7 |
+
|
| 8 |
+
# Import the custom prediction backend module
|
| 9 |
+
try:
|
| 10 |
+
from prediction_backend import TrafficPredictor
|
| 11 |
+
except ImportError:
|
| 12 |
+
print("Warning: prediction_backend.py not found. Prediction features will be disabled.")
|
| 13 |
+
TrafficPredictor = None
|
| 14 |
+
except Exception as e:
|
| 15 |
+
print(f"Warning: Failed to import prediction_backend: {e}")
|
| 16 |
+
TrafficPredictor = None
|
| 17 |
+
|
| 18 |
+
# ==========================================
|
| 19 |
+
# Flask Server
|
| 20 |
+
# ==========================================
|
| 21 |
+
app = Flask(__name__, static_folder='.')
|
| 22 |
+
CORS(app)
|
| 23 |
+
|
| 24 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 25 |
+
# Data directory path
|
| 26 |
+
DATA_DIR = os.path.abspath(os.path.join(BASE_DIR, 'data'))
|
| 27 |
+
|
| 28 |
+
# File path configurations
|
| 29 |
+
JSON_PATH = os.path.join(DATA_DIR, 'base2info.json')
|
| 30 |
+
TRAFFIC_PATH = os.path.join(DATA_DIR, 'bs_record_energy_normalized_sampled.npz')
|
| 31 |
+
SPATIAL_PATH = os.path.join(DATA_DIR, 'spatial_features.npz')
|
| 32 |
+
MODEL_PATH = os.path.join(BASE_DIR, 'best_corr_model.pt')
|
| 33 |
+
|
| 34 |
+
# ==========================================
|
| 35 |
+
# Utility Functions
|
| 36 |
+
# ==========================================
|
| 37 |
+
def calculate_std_dev(records, avg):
|
| 38 |
+
"""Calculates standard deviation for a given set of records and their average."""
|
| 39 |
+
if not records or len(records) < 2:
|
| 40 |
+
return 0
|
| 41 |
+
variance = sum((x - avg) ** 2 for x in records) / len(records)
|
| 42 |
+
return math.sqrt(variance)
|
| 43 |
+
|
| 44 |
+
def calculate_stats(data_list):
|
| 45 |
+
"""Calculate global statistics for frontend normalization"""
|
| 46 |
+
print("Calculating statistical distribution (Avg & Std)...")
|
| 47 |
+
avgs = []
|
| 48 |
+
stds = []
|
| 49 |
+
|
| 50 |
+
for item in data_list:
|
| 51 |
+
records = item.get('bs_record', [])
|
| 52 |
+
if records:
|
| 53 |
+
avg = sum(records) / len(records)
|
| 54 |
+
std = calculate_std_dev(records, avg)
|
| 55 |
+
else:
|
| 56 |
+
avg = 0
|
| 57 |
+
std = 0
|
| 58 |
+
avgs.append(avg)
|
| 59 |
+
stds.append(std)
|
| 60 |
+
|
| 61 |
+
def get_percentiles(values):
|
| 62 |
+
"""Calculates percentiles to create data brackets for visualization."""
|
| 63 |
+
values.sort()
|
| 64 |
+
n = len(values)
|
| 65 |
+
if n == 0: return {k:0 for k in ['min','max','t1','t2','t3','t4']}
|
| 66 |
+
return {
|
| 67 |
+
"min": values[0],
|
| 68 |
+
"max": values[-1],
|
| 69 |
+
"t1": values[int(n * 0.2)],
|
| 70 |
+
"t2": values[int(n * 0.4)],
|
| 71 |
+
"t3": values[int(n * 0.6)],
|
| 72 |
+
"t4": values[int(n * 0.8)]
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
stats_h = get_percentiles(avgs) # Statistics for pillar heights
|
| 76 |
+
stats_c = get_percentiles(stds) # Statistics for pillar colors (stability)
|
| 77 |
+
return stats_h, stats_c
|
| 78 |
+
|
| 79 |
+
def _convert_numpy_type(val):
|
| 80 |
+
if isinstance(val, np.ndarray): return val.tolist()
|
| 81 |
+
elif isinstance(val, (np.integer, np.int64, np.int32, np.int16)): return int(val)
|
| 82 |
+
elif isinstance(val, (np.floating, np.float64, np.float32)): return float(val)
|
| 83 |
+
elif isinstance(val, bytes): return val.decode('utf-8')
|
| 84 |
+
else: return val
|
| 85 |
+
|
| 86 |
+
def load_and_process_data(json_path, npz_path):
|
| 87 |
+
print(f"[DataLoader] Loading basic data...")
|
| 88 |
+
print(f" - JSON: {json_path}")
|
| 89 |
+
print(f" - Traffic NPZ : {npz_path}")
|
| 90 |
+
|
| 91 |
+
if not os.path.exists(json_path) or not os.path.exists(npz_path):
|
| 92 |
+
print("[DataLoader] Error: Input files not found.")
|
| 93 |
+
return []
|
| 94 |
+
|
| 95 |
+
try:
|
| 96 |
+
npz_data = np.load(npz_path)
|
| 97 |
+
with open(json_path, 'r', encoding='utf-8') as f:
|
| 98 |
+
json_map = json.load(f)
|
| 99 |
+
except Exception as e:
|
| 100 |
+
print(f"[DataLoader] Read error: {e}")
|
| 101 |
+
return []
|
| 102 |
+
|
| 103 |
+
# Handle binary strings if present in NPZ
|
| 104 |
+
raw_bs_ids = npz_data['bs_id']
|
| 105 |
+
bs_ids = [x.decode('utf-8') if isinstance(x, bytes) else str(x) for x in raw_bs_ids]
|
| 106 |
+
num_stations = len(bs_ids)
|
| 107 |
+
|
| 108 |
+
# Identify available time-series attributes in NPZ
|
| 109 |
+
station_attributes = []
|
| 110 |
+
for key in npz_data.files:
|
| 111 |
+
if key == 'bs_id': continue
|
| 112 |
+
if npz_data[key].shape[0] == num_stations:
|
| 113 |
+
station_attributes.append(key)
|
| 114 |
+
|
| 115 |
+
merged_data = []
|
| 116 |
+
match_count = 0
|
| 117 |
+
|
| 118 |
+
for i in range(num_stations):
|
| 119 |
+
current_id = bs_ids[i]
|
| 120 |
+
json_key = f"Base_{current_id}"
|
| 121 |
+
|
| 122 |
+
if json_key in json_map:
|
| 123 |
+
match_count += 1
|
| 124 |
+
entry = {
|
| 125 |
+
"id": current_id,
|
| 126 |
+
"npz_index": i, # Store original index for prediction lookups
|
| 127 |
+
"loc": json_map[json_key]["loc"]
|
| 128 |
+
}
|
| 129 |
+
for attr in station_attributes:
|
| 130 |
+
val = npz_data[attr][i]
|
| 131 |
+
entry[attr] = _convert_numpy_type(val)
|
| 132 |
+
merged_data.append(entry)
|
| 133 |
+
|
| 134 |
+
print(f"[DataLoader] Merge complete! Matched: {match_count}/{num_stations}")
|
| 135 |
+
return merged_data
|
| 136 |
+
|
| 137 |
+
# ==========================================
|
| 138 |
+
# Initialization Sequence
|
| 139 |
+
# ==========================================
|
| 140 |
+
|
| 141 |
+
print("Server Initializing...")
|
| 142 |
+
|
| 143 |
+
# 1. Load basic station data for frontend display
|
| 144 |
+
ALL_DATA = load_and_process_data(JSON_PATH, TRAFFIC_PATH)
|
| 145 |
+
|
| 146 |
+
STATS_HEIGHT = {}
|
| 147 |
+
STATS_COLOR = {}
|
| 148 |
+
|
| 149 |
+
if ALL_DATA:
|
| 150 |
+
STATS_HEIGHT, STATS_COLOR = calculate_stats(ALL_DATA)
|
| 151 |
+
else:
|
| 152 |
+
print("⚠️ CRITICAL WARNING: Data list is empty!")
|
| 153 |
+
|
| 154 |
+
# 2. Initialize AI Predictor with Spatial Features
|
| 155 |
+
predictor = None
|
| 156 |
+
if TrafficPredictor:
|
| 157 |
+
try:
|
| 158 |
+
print(f"[AI] Initializing Predictor with model: {MODEL_PATH}")
|
| 159 |
+
# Initialize the predictor using the model and spatial feature files
|
| 160 |
+
predictor = TrafficPredictor(
|
| 161 |
+
model_path=MODEL_PATH,
|
| 162 |
+
spatial_path=SPATIAL_PATH,
|
| 163 |
+
traffic_path=TRAFFIC_PATH
|
| 164 |
+
)
|
| 165 |
+
print("[AI] Predictor loaded successfully.")
|
| 166 |
+
except Exception as e:
|
| 167 |
+
print(f"[AI] Failed to load predictor: {e}")
|
| 168 |
+
|
| 169 |
+
# ==========================================
|
| 170 |
+
# API Routes
|
| 171 |
+
# ==========================================
|
| 172 |
+
|
| 173 |
+
@app.route('/')
|
| 174 |
+
def index():
|
| 175 |
+
"""Serves the main dashboard page."""
|
| 176 |
+
return send_from_directory('.', 'index.html')
|
| 177 |
+
|
| 178 |
+
@app.route('/<path:path>')
|
| 179 |
+
def serve_static(path):
|
| 180 |
+
"""Serves static assets (JS, CSS, Images)."""
|
| 181 |
+
return send_from_directory('.', path)
|
| 182 |
+
|
| 183 |
+
@app.route('/api/stations/locations')
|
| 184 |
+
def get_station_locations():
|
| 185 |
+
"""Returns a lightweight list of station coordinates and statistical summaries."""
|
| 186 |
+
lightweight_data = []
|
| 187 |
+
for item in ALL_DATA:
|
| 188 |
+
records = item.get('bs_record', [])
|
| 189 |
+
if records:
|
| 190 |
+
avg = sum(records) / len(records)
|
| 191 |
+
std = calculate_std_dev(records, avg)
|
| 192 |
+
else:
|
| 193 |
+
avg = 0
|
| 194 |
+
std = 0
|
| 195 |
+
|
| 196 |
+
lightweight_data.append({
|
| 197 |
+
"id": item['id'],
|
| 198 |
+
"loc": item['loc'],
|
| 199 |
+
"val_h": avg,
|
| 200 |
+
"val_c": std,
|
| 201 |
+
"vals": records
|
| 202 |
+
})
|
| 203 |
+
|
| 204 |
+
return jsonify({
|
| 205 |
+
"stats_height": STATS_HEIGHT,
|
| 206 |
+
"stats_color": STATS_COLOR,
|
| 207 |
+
"stations": lightweight_data
|
| 208 |
+
})
|
| 209 |
+
|
| 210 |
+
@app.route('/api/stations/detail/<station_id>')
|
| 211 |
+
def get_station_detail(station_id):
|
| 212 |
+
"""Returns detailed metadata and stats for a specific station."""
|
| 213 |
+
for item in ALL_DATA:
|
| 214 |
+
if str(item['id']) == str(station_id):
|
| 215 |
+
records = item.get('bs_record', [])
|
| 216 |
+
avg = sum(records)/len(records) if records else 0
|
| 217 |
+
std = calculate_std_dev(records, avg)
|
| 218 |
+
|
| 219 |
+
response = item.copy()
|
| 220 |
+
response['stats'] = {"avg": avg, "std": std}
|
| 221 |
+
return jsonify(response)
|
| 222 |
+
|
| 223 |
+
return jsonify({"error": "Station not found"}), 404
|
| 224 |
+
|
| 225 |
+
@app.route('/api/predict/<station_id>')
|
| 226 |
+
def predict_traffic(station_id):
|
| 227 |
+
"""Triggers the ML model to predict future traffic for a specific station."""
|
| 228 |
+
if not predictor:
|
| 229 |
+
return jsonify({"error": "Prediction service not available"}), 503
|
| 230 |
+
|
| 231 |
+
try:
|
| 232 |
+
target_idx = -1
|
| 233 |
+
|
| 234 |
+
# Map Station ID to its internal index in the NPZ file
|
| 235 |
+
for item in ALL_DATA:
|
| 236 |
+
if str(item['id']) == str(station_id):
|
| 237 |
+
target_idx = item.get('npz_index', -1)
|
| 238 |
+
break
|
| 239 |
+
|
| 240 |
+
if target_idx == -1:
|
| 241 |
+
# Fallback: Check if the ID provided is directly a numerical index
|
| 242 |
+
if str(station_id).isdigit():
|
| 243 |
+
target_idx = int(station_id)
|
| 244 |
+
else:
|
| 245 |
+
return jsonify({"error": "Station ID not found in mapping"}), 404
|
| 246 |
+
|
| 247 |
+
# Execute prediction through the ML backend
|
| 248 |
+
result = predictor.predict(target_idx)
|
| 249 |
+
|
| 250 |
+
if "error" in result:
|
| 251 |
+
return jsonify(result), 500
|
| 252 |
+
|
| 253 |
+
return jsonify(result)
|
| 254 |
+
|
| 255 |
+
except Exception as e:
|
| 256 |
+
print(f"Prediction Error: {e}")
|
| 257 |
+
return jsonify({"error": str(e)}), 500
|
| 258 |
+
|
| 259 |
+
# Local development server
|
| 260 |
+
# if __name__ == '__main__':
|
| 261 |
+
# print(f"Monitoring Data Directory: {DATA_DIR}")
|
| 262 |
+
# print("Server running on http://127.0.0.1:5000")
|
| 263 |
+
# app.run(debug=True, port=5000)
|
| 264 |
+
|
| 265 |
+
# FOR ONLINE
|
| 266 |
+
if __name__ == '__main__':
|
| 267 |
+
print(f"Monitoring Data Directory: {DATA_DIR}")
|
| 268 |
+
print("Server running on port 7860...")
|
| 269 |
+
app.run(host='0.0.0.0', port=7860) # <--- 就改这一行!取消 debug=True,改 host 和 port
|
style.css
ADDED
|
@@ -0,0 +1,617 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* =========================================
|
| 2 |
+
1. Base Reset & Layout
|
| 3 |
+
========================================= */
|
| 4 |
+
* {
|
| 5 |
+
margin: 0;
|
| 6 |
+
padding: 0;
|
| 7 |
+
box-sizing: border-box;
|
| 8 |
+
/* Clean sans-serif stack for readability in high-tech interfaces */
|
| 9 |
+
font-family: 'Segoe UI', 'Roboto', Helvetica, Arial, sans-serif;
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
body {
|
| 13 |
+
background: #0f0c29; /* Deep space navy */
|
| 14 |
+
color: #e0e0e0;
|
| 15 |
+
height: 100vh;
|
| 16 |
+
overflow: hidden; /* Prevents browser scrollbars during panel transitions */
|
| 17 |
+
display: flex; /* Sidebars and main content align horizontally */
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
/* =========================================
|
| 21 |
+
2. Loading Overlay
|
| 22 |
+
========================================= */
|
| 23 |
+
.loading-overlay {
|
| 24 |
+
position: absolute;
|
| 25 |
+
top: 0;
|
| 26 |
+
left: 0;
|
| 27 |
+
width: 100%;
|
| 28 |
+
height: 100%;
|
| 29 |
+
background: #0f0c29;
|
| 30 |
+
z-index: 9999; /* Ensure it stays above map and sidebars */
|
| 31 |
+
display: flex;
|
| 32 |
+
flex-direction: column;
|
| 33 |
+
justify-content: center;
|
| 34 |
+
align-items: center;
|
| 35 |
+
transition: opacity 0.8s ease-out;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
.spinner {
|
| 39 |
+
width: 50px;
|
| 40 |
+
height: 50px;
|
| 41 |
+
border: 3px solid rgba(0, 206, 201, 0.1);
|
| 42 |
+
border-top: 3px solid #00cec9; /* Neon teal accent */
|
| 43 |
+
border-radius: 50%;
|
| 44 |
+
animation: spin 1s infinite cubic-bezier(0.55, 0.15, 0.45, 0.85);
|
| 45 |
+
margin-bottom: 20px;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
@keyframes spin {
|
| 49 |
+
0% { transform: rotate(0deg); }
|
| 50 |
+
100% { transform: rotate(360deg); }
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
/* =========================================
|
| 54 |
+
3. Left Sidebar (Main Controls)
|
| 55 |
+
========================================= */
|
| 56 |
+
.sidebar {
|
| 57 |
+
width: 380px;
|
| 58 |
+
background: rgba(15, 23, 42, 0.85);
|
| 59 |
+
backdrop-filter: blur(20px) saturate(180%);
|
| 60 |
+
/* Glassmorphism: blur effect creates depth against the map */
|
| 61 |
+
-webkit-backdrop-filter: blur(20px) saturate(180%);
|
| 62 |
+
padding: 25px;
|
| 63 |
+
display: flex;
|
| 64 |
+
flex-direction: column;
|
| 65 |
+
gap: 20px;
|
| 66 |
+
box-shadow: 10px 0 30px rgba(0,0,0,0.5);
|
| 67 |
+
z-index: 10;
|
| 68 |
+
border-right: 1px solid rgba(255,255,255,0.05);
|
| 69 |
+
overflow-y: auto;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
.header h1 {
|
| 73 |
+
font-size: 24px;
|
| 74 |
+
font-weight: 300;
|
| 75 |
+
letter-spacing: 1px;
|
| 76 |
+
/* Gradient text for futuristic branding */
|
| 77 |
+
background: linear-gradient(to right, #00cec9, #a29bfe);
|
| 78 |
+
-webkit-background-clip: text;
|
| 79 |
+
-webkit-text-fill-color: transparent;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
/* Search Section in Left Sidebar */
|
| 83 |
+
.search-section {
|
| 84 |
+
margin-top: 15px;
|
| 85 |
+
margin-bottom: 10px;
|
| 86 |
+
display: flex;
|
| 87 |
+
flex-direction: column;
|
| 88 |
+
gap: 8px;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
.search-container {
|
| 92 |
+
display: flex;
|
| 93 |
+
gap: 8px;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
#search-input {
|
| 97 |
+
flex: 1;
|
| 98 |
+
background: rgba(15, 23, 42, 0.6);
|
| 99 |
+
border: 1px solid rgba(0, 206, 201, 0.4);
|
| 100 |
+
color: #fff;
|
| 101 |
+
padding: 6px 10px;
|
| 102 |
+
border-radius: 4px;
|
| 103 |
+
outline: none;
|
| 104 |
+
font-family: 'Courier New', monospace; /* "Code" feel for data input */
|
| 105 |
+
font-size: 14px;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
#search-input:focus {
|
| 109 |
+
border-color: #00cec9;
|
| 110 |
+
box-shadow: 0 0 8px rgba(0, 206, 201, 0.2);
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
.search-mode {
|
| 114 |
+
display: flex;
|
| 115 |
+
align-items: center;
|
| 116 |
+
gap: 8px;
|
| 117 |
+
font-size: 11px;
|
| 118 |
+
color: #94a3b8;
|
| 119 |
+
padding-left: 2px;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
.search-mode input { cursor: pointer; accent-color: #00cec9; }
|
| 123 |
+
.search-mode label { cursor: pointer; }
|
| 124 |
+
|
| 125 |
+
/* Small Utility Buttons */
|
| 126 |
+
.cyber-btn-small {
|
| 127 |
+
background: rgba(0, 206, 201, 0.1);
|
| 128 |
+
color: #00cec9;
|
| 129 |
+
border: 1px solid #00cec9;
|
| 130 |
+
padding: 0 12px;
|
| 131 |
+
border-radius: 4px;
|
| 132 |
+
cursor: pointer;
|
| 133 |
+
font-weight: bold;
|
| 134 |
+
transition: all 0.2s;
|
| 135 |
+
display: flex;
|
| 136 |
+
align-items: center;
|
| 137 |
+
justify-content: center;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
.cyber-btn-small:hover {
|
| 141 |
+
background: #00cec9;
|
| 142 |
+
color: #000;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
#clear-search-btn {
|
| 146 |
+
border-color: #fd79a8;
|
| 147 |
+
color: #fd79a8;
|
| 148 |
+
background: rgba(253, 121, 168, 0.1);
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
#clear-search-btn:hover {
|
| 152 |
+
background: #fd79a8;
|
| 153 |
+
color: #fff;
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
/* Cards & Charts */
|
| 157 |
+
.card {
|
| 158 |
+
background: rgba(255, 255, 255, 0.03);
|
| 159 |
+
padding: 20px;
|
| 160 |
+
border-radius: 12px;
|
| 161 |
+
border: 1px solid rgba(255,255,255,0.02);
|
| 162 |
+
transition: transform 0.2s, background 0.2s;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
.card:hover {
|
| 166 |
+
background: rgba(255, 255, 255, 0.05);
|
| 167 |
+
border-color: rgba(0, 206, 201, 0.2);
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
.card h2 {
|
| 171 |
+
font-size: 12px;
|
| 172 |
+
color: #94a3b8;
|
| 173 |
+
margin-bottom: 15px;
|
| 174 |
+
border-bottom: 1px solid rgba(255,255,255,0.05);
|
| 175 |
+
padding-bottom: 8px;
|
| 176 |
+
text-transform: uppercase;
|
| 177 |
+
font-weight: 600;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
.chart-container {
|
| 181 |
+
height: 140px;
|
| 182 |
+
width: 100%;
|
| 183 |
+
position: relative;
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
.stat-row {
|
| 187 |
+
display: flex;
|
| 188 |
+
justify-content: space-between;
|
| 189 |
+
margin-bottom: 5px;
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
.value {
|
| 193 |
+
font-size: 18px;
|
| 194 |
+
font-weight: 500;
|
| 195 |
+
color: #f1f5f9;
|
| 196 |
+
font-family: 'Courier New', monospace;
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
.value.highlight { color: #00cec9; }
|
| 200 |
+
|
| 201 |
+
.details-content p {
|
| 202 |
+
font-size: 13px;
|
| 203 |
+
line-height: 1.8;
|
| 204 |
+
color: #cbd5e1;
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
/* =========================================
|
| 208 |
+
4. Main Content & Map Area
|
| 209 |
+
========================================= */
|
| 210 |
+
.main-content {
|
| 211 |
+
flex: 1;
|
| 212 |
+
height: 100%;
|
| 213 |
+
position: relative;
|
| 214 |
+
overflow: hidden;
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
#map { width: 100%; height: 100%; }
|
| 218 |
+
|
| 219 |
+
/* Mapbox Popup Overrides */
|
| 220 |
+
.mapboxgl-popup-content {
|
| 221 |
+
background: rgba(15, 23, 42, 0.95) !important;
|
| 222 |
+
border: 1px solid rgba(0, 206, 201, 0.5);
|
| 223 |
+
box-shadow: 0 0 15px rgba(0, 206, 201, 0.2);
|
| 224 |
+
padding: 8px 12px !important;
|
| 225 |
+
border-radius: 6px !important;
|
| 226 |
+
color: #e0e0e0;
|
| 227 |
+
min-width: 120px;
|
| 228 |
+
}
|
| 229 |
+
.mapboxgl-popup-tip {
|
| 230 |
+
border-top-color: rgba(0, 206, 201, 0.5) !important;
|
| 231 |
+
margin-bottom: -1px;
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
/* =========================================
|
| 235 |
+
5. Top-Left Controls
|
| 236 |
+
========================================= */
|
| 237 |
+
.controls-container {
|
| 238 |
+
position: absolute;
|
| 239 |
+
top: 20px;
|
| 240 |
+
left: 20px;
|
| 241 |
+
z-index: 100;
|
| 242 |
+
display: flex;
|
| 243 |
+
gap: 10px;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
.cyber-btn {
|
| 247 |
+
background: rgba(15, 23, 42, 0.9);
|
| 248 |
+
color: #00cec9;
|
| 249 |
+
border: 1px solid #00cec9;
|
| 250 |
+
padding: 10px 20px;
|
| 251 |
+
font-size: 14px;
|
| 252 |
+
font-weight: 600;
|
| 253 |
+
border-radius: 4px;
|
| 254 |
+
cursor: pointer;
|
| 255 |
+
box-shadow: 0 0 10px rgba(0, 206, 201, 0.2);
|
| 256 |
+
transition: all 0.3s ease;
|
| 257 |
+
display: flex;
|
| 258 |
+
align-items: center;
|
| 259 |
+
gap: 8px;
|
| 260 |
+
text-transform: uppercase;
|
| 261 |
+
letter-spacing: 1px;
|
| 262 |
+
white-space: nowrap;
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
.cyber-btn:hover {
|
| 266 |
+
background: rgba(0, 206, 201, 0.15);
|
| 267 |
+
box-shadow: 0 0 20px rgba(0, 206, 201, 0.4);
|
| 268 |
+
transform: translateY(-1px);
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
/* Filter Menu */
|
| 272 |
+
.filter-wrapper { position: relative; display: inline-block; }
|
| 273 |
+
|
| 274 |
+
.filter-menu {
|
| 275 |
+
position: absolute;
|
| 276 |
+
top: 50px;
|
| 277 |
+
left: 0;
|
| 278 |
+
width: 170px;
|
| 279 |
+
background: rgba(15, 23, 42, 0.95);
|
| 280 |
+
border: 1px solid #00cec9;
|
| 281 |
+
border-radius: 4px;
|
| 282 |
+
padding: 8px;
|
| 283 |
+
display: flex;
|
| 284 |
+
flex-direction: column;
|
| 285 |
+
gap: 6px;
|
| 286 |
+
opacity: 0;
|
| 287 |
+
visibility: hidden;
|
| 288 |
+
transform: translateY(-10px);
|
| 289 |
+
transition: all 0.3s ease;
|
| 290 |
+
box-shadow: 0 5px 20px rgba(0,0,0,0.5);
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
.filter-menu.active {
|
| 294 |
+
opacity: 1;
|
| 295 |
+
visibility: visible;
|
| 296 |
+
transform: translateY(0);
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
.filter-item {
|
| 300 |
+
display: flex;
|
| 301 |
+
align-items: center;
|
| 302 |
+
gap: 10px;
|
| 303 |
+
padding: 8px 10px;
|
| 304 |
+
border-radius: 4px;
|
| 305 |
+
cursor: pointer;
|
| 306 |
+
transition: background 0.2s;
|
| 307 |
+
font-size: 12px;
|
| 308 |
+
color: #ccc;
|
| 309 |
+
border: 1px solid transparent;
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
.filter-item:hover { background: rgba(255, 255, 255, 0.1); }
|
| 313 |
+
|
| 314 |
+
.filter-item.selected {
|
| 315 |
+
background: rgba(0, 206, 201, 0.2);
|
| 316 |
+
border-color: rgba(0, 206, 201, 0.5);
|
| 317 |
+
color: #fff; font-weight: bold;
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
.color-box { width: 12px; height: 12px; border-radius: 2px; }
|
| 321 |
+
|
| 322 |
+
/* =========================================
|
| 323 |
+
6. Bottom Time Panel
|
| 324 |
+
========================================= */
|
| 325 |
+
.time-panel {
|
| 326 |
+
position: absolute;
|
| 327 |
+
bottom: 40px;
|
| 328 |
+
left: 50%;
|
| 329 |
+
transform: translateX(-50%);
|
| 330 |
+
width: 60%;
|
| 331 |
+
min-width: 500px;
|
| 332 |
+
height: 70px;
|
| 333 |
+
background: rgba(15, 23, 42, 0.9);
|
| 334 |
+
border: 1px solid #00cec9;
|
| 335 |
+
box-shadow: 0 0 20px rgba(0, 206, 201, 0.15), inset 0 0 50px rgba(0,0,0,0.6);
|
| 336 |
+
backdrop-filter: blur(10px);
|
| 337 |
+
border-radius: 50px;
|
| 338 |
+
z-index: 100;
|
| 339 |
+
display: flex;
|
| 340 |
+
align-items: center;
|
| 341 |
+
padding: 0 25px;
|
| 342 |
+
gap: 20px;
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
.play-control {
|
| 346 |
+
min-width: 45px;
|
| 347 |
+
height: 45px;
|
| 348 |
+
border-radius: 50%;
|
| 349 |
+
justify-content: center;
|
| 350 |
+
padding: 0;
|
| 351 |
+
font-size: 18px;
|
| 352 |
+
border-width: 2px;
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
.digital-clock {
|
| 356 |
+
font-family: 'Courier New', monospace;
|
| 357 |
+
font-size: 20px;
|
| 358 |
+
font-weight: bold;
|
| 359 |
+
color: #00cec9;
|
| 360 |
+
text-shadow: 0 0 8px rgba(0, 206, 201, 0.8);
|
| 361 |
+
background: rgba(0, 0, 0, 0.4);
|
| 362 |
+
padding: 5px 12px;
|
| 363 |
+
border-radius: 6px;
|
| 364 |
+
border: 1px solid rgba(0, 206, 201, 0.2);
|
| 365 |
+
min-width: 80px;
|
| 366 |
+
text-align: center;
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
.slider-wrapper {
|
| 370 |
+
flex: 1;
|
| 371 |
+
display: flex;
|
| 372 |
+
flex-direction: column;
|
| 373 |
+
justify-content: center;
|
| 374 |
+
position: relative;
|
| 375 |
+
margin-top: -2px;
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
.slider-ticks {
|
| 379 |
+
display: flex;
|
| 380 |
+
justify-content: space-between;
|
| 381 |
+
margin-top: 8px;
|
| 382 |
+
font-size: 10px;
|
| 383 |
+
color: #64748b;
|
| 384 |
+
font-family: monospace;
|
| 385 |
+
padding: 0 2px;
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
#time-slider {
|
| 389 |
+
-webkit-appearance: none;
|
| 390 |
+
width: 100%;
|
| 391 |
+
height: 4px;
|
| 392 |
+
background: rgba(255, 255, 255, 0.1);
|
| 393 |
+
border-radius: 2px;
|
| 394 |
+
outline: none;
|
| 395 |
+
cursor: pointer;
|
| 396 |
+
transition: background 0.3s;
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
#time-slider:hover { background: rgba(255, 255, 255, 0.2); }
|
| 400 |
+
|
| 401 |
+
#time-slider::-webkit-slider-thumb {
|
| 402 |
+
-webkit-appearance: none;
|
| 403 |
+
width: 22px;
|
| 404 |
+
height: 22px;
|
| 405 |
+
border-radius: 50%;
|
| 406 |
+
background: #0f172a;
|
| 407 |
+
border: 2px solid #00cec9;
|
| 408 |
+
box-shadow: 0 0 10px #00cec9;
|
| 409 |
+
margin-top: 0px;
|
| 410 |
+
transition: transform 0.1s;
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
#time-slider::-webkit-slider-thumb:hover {
|
| 414 |
+
transform: scale(1.2);
|
| 415 |
+
background: #00cec9;
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
/* =========================================
|
| 419 |
+
7. Right Sidebar (Prediction Mode)
|
| 420 |
+
========================================= */
|
| 421 |
+
/* Container for the sliding panel */
|
| 422 |
+
.sidebar-right {
|
| 423 |
+
position: fixed;
|
| 424 |
+
top: 0;
|
| 425 |
+
right: -450px; /* Hidden by default */
|
| 426 |
+
width: 400px;
|
| 427 |
+
height: 100vh;
|
| 428 |
+
background: rgba(10, 10, 30, 0.95);
|
| 429 |
+
border-left: 1px solid #f39c12;
|
| 430 |
+
backdrop-filter: blur(10px);
|
| 431 |
+
transition: right 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
| 432 |
+
z-index: 1000;
|
| 433 |
+
padding: 25px;
|
| 434 |
+
color: #fff;
|
| 435 |
+
box-shadow: -10px 0 30px rgba(0,0,0,0.5);
|
| 436 |
+
display: flex;
|
| 437 |
+
flex-direction: column;
|
| 438 |
+
overflow-y: auto;
|
| 439 |
+
}
|
| 440 |
+
/* Active state (Slid in) */
|
| 441 |
+
.sidebar-right.active {
|
| 442 |
+
right: 0;
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
/* Custom Cyber-Scrollbar for Prediction Sidebar */
|
| 446 |
+
.sidebar-right::-webkit-scrollbar {
|
| 447 |
+
width: 6px;
|
| 448 |
+
}
|
| 449 |
+
.sidebar-right::-webkit-scrollbar-track {
|
| 450 |
+
background: rgba(0, 0, 0, 0.3);
|
| 451 |
+
}
|
| 452 |
+
.sidebar-right::-webkit-scrollbar-thumb {
|
| 453 |
+
background: #f39c12;
|
| 454 |
+
border-radius: 3px;
|
| 455 |
+
}
|
| 456 |
+
.sidebar-right::-webkit-scrollbar-thumb:hover {
|
| 457 |
+
background: #e67e22;
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
/* Header within right sidebar */
|
| 461 |
+
.sidebar-right .header {
|
| 462 |
+
display: flex;
|
| 463 |
+
justify-content: space-between;
|
| 464 |
+
align-items: center;
|
| 465 |
+
border-bottom: 1px solid rgba(243, 156, 18, 0.3);
|
| 466 |
+
padding-bottom: 15px;
|
| 467 |
+
margin-bottom: 20px;
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
.sidebar-right h1 {
|
| 471 |
+
font-size: 20px;
|
| 472 |
+
background: linear-gradient(to right, #f39c12, #f1c40f);
|
| 473 |
+
-webkit-background-clip: text;
|
| 474 |
+
-webkit-text-fill-color: transparent;
|
| 475 |
+
text-transform: uppercase;
|
| 476 |
+
letter-spacing: 1px;
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
/* Close button for right sidebar */
|
| 480 |
+
#close-pred-btn {
|
| 481 |
+
border-color: #666;
|
| 482 |
+
color: #aaa;
|
| 483 |
+
background: transparent;
|
| 484 |
+
}
|
| 485 |
+
#close-pred-btn:hover {
|
| 486 |
+
border-color: #fff;
|
| 487 |
+
color: #fff;
|
| 488 |
+
background: rgba(255,255,255,0.1);
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
/* Prediction specific button states */
|
| 492 |
+
#predict-toggle.predict-on {
|
| 493 |
+
background: rgba(243, 156, 18, 0.2);
|
| 494 |
+
box-shadow: 0 0 20px rgba(243, 156, 18, 0.4);
|
| 495 |
+
border-color: #f39c12;
|
| 496 |
+
color: #f39c12;
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
/* Legend items in prediction panel */
|
| 500 |
+
.legend-box {
|
| 501 |
+
margin-top: auto; /* Push to bottom if needed, or just normal flow */
|
| 502 |
+
border: 1px solid rgba(255,255,255,0.1);
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
.dot {
|
| 506 |
+
width: 10px;
|
| 507 |
+
height: 10px;
|
| 508 |
+
display: inline-block;
|
| 509 |
+
border-radius: 50%;
|
| 510 |
+
margin-right: 8px;
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
/* =========================================
|
| 515 |
+
8. Map Custom Pins & AI Log Visuals
|
| 516 |
+
========================================= */
|
| 517 |
+
.optimal-pulse-pin {
|
| 518 |
+
width: 20px;
|
| 519 |
+
height: 20px;
|
| 520 |
+
background-color: #2ecc71;
|
| 521 |
+
border-radius: 50%;
|
| 522 |
+
border: 3px solid #ffffff;
|
| 523 |
+
box-shadow: 0 0 15px #2ecc71, 0 0 30px #2ecc71;
|
| 524 |
+
animation: optimal-pulse 1.5s infinite cubic-bezier(0.66, 0, 0, 1);
|
| 525 |
+
cursor: pointer;
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
@keyframes optimal-pulse {
|
| 529 |
+
to {
|
| 530 |
+
box-shadow: 0 0 0 20px rgba(46, 204, 113, 0);
|
| 531 |
+
background-color: rgba(46, 204, 113, 0.8);
|
| 532 |
+
}
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
.cyber-explanation {
|
| 536 |
+
margin-top: 12px;
|
| 537 |
+
padding: 12px;
|
| 538 |
+
background: rgba(0, 206, 201, 0.05);
|
| 539 |
+
border-left: 3px solid #00cec9;
|
| 540 |
+
border-radius: 0 4px 4px 0;
|
| 541 |
+
font-size: 11px;
|
| 542 |
+
color: #a29bfe;
|
| 543 |
+
font-family: 'Courier New', Courier, monospace;
|
| 544 |
+
line-height: 1.6;
|
| 545 |
+
box-shadow: inset 0 0 10px rgba(0, 206, 201, 0.05);
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
.cyber-explanation strong {
|
| 549 |
+
color: #00cec9;
|
| 550 |
+
font-weight: bold;
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
/* =========================================
|
| 554 |
+
9. Panel Toggle Buttons & Navigation UI
|
| 555 |
+
========================================= */
|
| 556 |
+
.sidebar {
|
| 557 |
+
transition: margin-left 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
| 558 |
+
position: relative;
|
| 559 |
+
}
|
| 560 |
+
.sidebar.collapsed {
|
| 561 |
+
margin-left: -380px;
|
| 562 |
+
}
|
| 563 |
+
.sidebar-right.active.collapsed {
|
| 564 |
+
right: -400px;
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
/* Sticky toggle buttons on panel edges */
|
| 568 |
+
.panel-toggle-btn {
|
| 569 |
+
position: absolute;
|
| 570 |
+
top: 50%;
|
| 571 |
+
transform: translateY(-50%);
|
| 572 |
+
width: 22px;
|
| 573 |
+
height: 60px;
|
| 574 |
+
background: rgba(15, 23, 42, 0.9);
|
| 575 |
+
border: 1px solid #00cec9;
|
| 576 |
+
color: #00cec9;
|
| 577 |
+
cursor: pointer;
|
| 578 |
+
z-index: 1000;
|
| 579 |
+
display: flex;
|
| 580 |
+
align-items: center;
|
| 581 |
+
justify-content: center;
|
| 582 |
+
font-size: 10px;
|
| 583 |
+
box-shadow: 0 0 10px rgba(0, 206, 201, 0.2);
|
| 584 |
+
backdrop-filter: blur(5px);
|
| 585 |
+
transition: all 0.2s ease;
|
| 586 |
+
outline: none;
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
.panel-toggle-btn:hover {
|
| 590 |
+
background: #00cec9;
|
| 591 |
+
color: #000;
|
| 592 |
+
box-shadow: 0 0 15px rgba(0, 206, 201, 0.5);
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
/* Positioning logic for sidebars toggles */
|
| 596 |
+
.left-toggle {
|
| 597 |
+
left: 380px; /* Aligned with sidebar width */
|
| 598 |
+
border-left: none;
|
| 599 |
+
border-radius: 0 6px 6px 0;
|
| 600 |
+
transition: left 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
| 601 |
+
}
|
| 602 |
+
.left-toggle.collapsed {
|
| 603 |
+
left: 0;
|
| 604 |
+
}
|
| 605 |
+
|
| 606 |
+
.right-toggle {
|
| 607 |
+
right: -22px; /* Starts hidden */
|
| 608 |
+
border-right: none;
|
| 609 |
+
border-radius: 6px 0 0 6px;
|
| 610 |
+
transition: right 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
| 611 |
+
}
|
| 612 |
+
.right-toggle.active {
|
| 613 |
+
right: 400px;
|
| 614 |
+
}
|
| 615 |
+
.right-toggle.collapsed {
|
| 616 |
+
right: 0;
|
| 617 |
+
}
|