Karaku9 committed on
Commit 0001428 · verified · 1 Parent(s): e7fe3b6

Upload 15 files
Dockerfile ADDED
@@ -0,0 +1,23 @@
+ # 1. Use official lightweight Python 3.10 image to minimize build size
+ FROM python:3.10-slim
+
+ # 2. Define the working directory inside the container
+ WORKDIR /app
+
+ # 3. Create a non-root user with UID 1000
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ # 4. Copy requirements file and install dependencies
+ COPY --chown=user requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # 5. Copy the entire project source code and data assets to the container
+ COPY --chown=user . /app
+
+ # 6. Expose port 7860
+ EXPOSE 7860
+
+ # 7. Execute the Flask server script
+ CMD ["python", "server.py"]
best_corr_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:83b578e901f1f3d11421431ec5286548ae200b0ed1165afee66b6f05ad19bd76
+ size 563607134
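Files like this one are stored via Git LFS, so the repository holds only a small text pointer: a `version` line, a `sha256` object id, and the byte `size`. A minimal, hypothetical parser for such a pointer (not part of this commit, just an illustration of the format):

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse a Git LFS pointer ('key value' per line) into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:83b578e901f1f3d11421431ec5286548ae200b0ed1165afee66b6f05ad19bd76
size 563607134"""

info = parse_lfs_pointer(pointer)
print(info["size"])                       # 563607134
print(info["oid"].startswith("sha256:"))  # True
```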
data/.DS_Store ADDED
Binary file (6.15 kB).
data/base2info.json ADDED
The diff for this file is too large to render.
data/bs_record_energy_normalized_sampled.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:62a53c46d7cc8b0fca0fa9943eac045ca1853a4eb1df2a4f58c5c939179b66c7
+ size 74859436
data/spatial_features.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:395420bb4696091d85bae743e680b861326342965ef7c85a21eb22d9c295e4ea
+ size 294559
hierarchical_flow_matching_training_v4.py ADDED
@@ -0,0 +1,935 @@
+ """
+ Hierarchical Flow Matching Training Framework (V4) - Generative Fusion
+ ======================================================================
+
+ Complete training pipeline with:
+ 1. Three-level cascaded Flow Matching losses
+ 2. Hierarchical multi-periodic supervision
+ 3. Temporal structure preservation
+ 4. Adaptive learning rate scheduling
+
+ [FUSION] Generative Mode: Implicit alignment via conditional flow matching.
+ Enhanced with explicit peak conditioning and auxiliary classification.
+ Physical Boundary Loss & Bias Correction.
+ """
+
+ import os
+ import json
+ import numpy as np
+ from typing import Dict, Optional, Tuple, Literal
+ from tqdm import tqdm
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+
+ from hierarchical_flow_matching_v4 import HierarchicalFlowMatchingV4
+ from multimodal_spatial_encoder_v4 import MultiModalSpatialEncoderV4
+
+
+ # =============================================================================
+ # Hierarchical Multi-Periodic Loss Functions
+ # =============================================================================
+
+ class HierarchicalFlowMatchingLoss(nn.Module):
+     """
+     Hierarchical Flow Matching loss with multi-periodic supervision.
+
+     Combines:
+     1. Level 1 (Daily) Flow Matching loss
+     2. Level 2 (Weekly) Flow Matching loss
+     3. Level 3 (Residual) Flow Matching loss [Peak Conditioned]
+     4. Temporal structure preservation loss
+     5. Multi-periodic consistency loss
+     6. Peak Hour Classification Loss
+     7. Physical Boundary Loss (No-Negative Constraint)
+     8. Bias Correction Loss (Global Mean Alignment)
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     # Helper: Physical Constraint
+     def compute_boundary_loss(self, predicted_x1: torch.Tensor) -> torch.Tensor:
+         """
+         Penalize negative values in the estimated traffic.
+         Loss = ReLU(-x).mean() * scale
+         """
+         return F.relu(-predicted_x1).mean() * 10.0
+
+     # Helper: Bias Constraint
+     def compute_bias_loss(self, generated: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
+         """
+         Fix the 'Parallel Lines' issue by forcing global mean alignment.
+         """
+         gen_mean = generated.mean(dim=1)   # [B]
+         real_mean = real.mean(dim=1)       # [B]
+         return F.l1_loss(gen_mean, real_mean) * 20.0
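The two helper penalties above are simple mean statistics, independent of the model. A standalone NumPy sketch (same 10.0 and 20.0 scales as the code; the toy arrays are made up) of what each one measures:

```python
import numpy as np

def boundary_loss(x):
    """Mean ReLU(-x), scaled by 10: penalizes negative traffic values."""
    return np.maximum(-x, 0.0).mean() * 10.0

def bias_loss(generated, real):
    """L1 distance between per-sample global means, scaled by 20."""
    return np.abs(generated.mean(axis=1) - real.mean(axis=1)).mean() * 20.0

x = np.array([[1.0, -2.0, 3.0, -1.0]])   # one sample, two negative entries
print(boundary_loss(x))                  # 10 * (2 + 1) / 4 = 7.5

gen = np.array([[1.0, 2.0, 3.0, 4.0]])   # mean 2.5
real = np.array([[2.0, 3.0, 4.0, 5.0]])  # mean 3.5
print(bias_loss(gen, real))              # 20 * |2.5 - 3.5| = 20.0
```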
+
+     def compute_level1_loss(
+         self,
+         model: HierarchicalFlowMatchingV4,
+         real_traffic: torch.Tensor,
+         spatial_cond_level1: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:  # [MODIFIED] Returns a tuple
+         """
+         Level 1 (Day-Type Templates) Flow Matching loss.
+         """
+         B = real_traffic.shape[0]
+         device = real_traffic.device
+
+         # 672 hourly samples = 28 days * 24 hours
+         steps_per_day = 24
+         n_days = 28
+         real_reshaped = real_traffic.reshape(B, n_days, steps_per_day)  # [B, 28, 24]
+
+         # Assume sequence starts on Monday
+         day_of_week = torch.arange(n_days, device=device) % 7
+         weekday_idx = torch.where(day_of_week < 5)[0]
+         weekend_idx = torch.where(day_of_week >= 5)[0]
+
+         weekday_pattern = real_reshaped.index_select(1, weekday_idx).mean(dim=1)  # [B, 24]
+         weekend_pattern = real_reshaped.index_select(1, weekend_idx).mean(dim=1)  # [B, 24]
+
+         # Target is the concatenation: [weekday(24), weekend(24)] -> [B, 48]
+         x1 = torch.cat([weekday_pattern, weekend_pattern], dim=1)
+
+         # Sample noise
+         x0 = torch.randn_like(x1)
+
+         # Sample time
+         t = torch.rand(B, 1, device=device)
+
+         # Interpolation
+         x_t = t * x1 + (1 - t) * x0
+
+         # Target velocity
+         v_target = x1 - x0
+
+         # Predict velocity
+         v_pred = model(x_t, t, spatial_cond_level1, level=1)
+
+         # Flow Matching loss
+         loss_fm = F.mse_loss(v_pred, v_target)
+
+         # Boundary Loss for Level 1
+         x1_est = x_t + (1 - t) * v_pred
+         loss_boundary = self.compute_boundary_loss(x1_est)
+
+         return loss_fm, loss_boundary
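The rectified-flow construction used at every level (linear interpolation x_t = t·x1 + (1−t)·x0 with constant target velocity v = x1 − x0) lets the clean endpoint be recovered in closed form from any intermediate point, which is exactly how `x1_est` is formed for the boundary loss. A small self-contained NumPy check of that identity:

```python
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.normal(size=(2, 48))   # "data" endpoint (e.g. day-type templates)
x0 = rng.normal(size=(2, 48))   # noise endpoint
t = rng.uniform(size=(2, 1))    # one time per sample, broadcast over features

x_t = t * x1 + (1 - t) * x0     # linear interpolation between noise and data
v_target = x1 - x0              # constant target velocity along the path

# With a perfect velocity prediction, one Euler "jump" recovers x1 exactly:
x1_est = x_t + (1 - t) * v_target
print(np.allclose(x1_est, x1))  # True
```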
+
+     def compute_level2_loss(
+         self,
+         model: HierarchicalFlowMatchingV4,
+         real_traffic: torch.Tensor,
+         spatial_cond_level2: torch.Tensor,
+         spatial_cond_level1: torch.Tensor,
+         daily_pattern: Optional[torch.Tensor] = None,
+         use_teacher_forcing: bool = True,
+         n_steps_generate: int = 10,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:  # [MODIFIED] Returns a tuple
+         """
+         Level 2 (Weekly Pattern, 168 hours) Flow Matching loss.
+         """
+         B = real_traffic.shape[0]
+         device = real_traffic.device
+
+         steps_per_day = 24
+         n_days = 28
+         n_weeks = 4
+         real_reshaped = real_traffic.reshape(B, n_days, steps_per_day)  # [B, 28, 24]
+
+         # Weekly pattern ground truth
+         weekly_days = []
+         for dow in range(7):
+             idx = torch.tensor([dow + 7 * w for w in range(n_weeks)], device=device, dtype=torch.long)
+             weekly_days.append(real_reshaped.index_select(1, idx).mean(dim=1))  # [B, 24]
+         weekly_pattern = torch.stack(weekly_days, dim=1).reshape(B, 7 * steps_per_day)  # [B, 168]
+
+         x1 = weekly_pattern  # target weekly pattern
+
+         # Get day-type templates (weekday/weekend)
+         if daily_pattern is None or not use_teacher_forcing:
+             with torch.no_grad():
+                 daily_pattern = model.generate_daily_pattern(
+                     spatial_cond_level1, n_steps=n_steps_generate
+                 )
+         else:
+             # Teacher forcing
+             day_of_week = torch.arange(n_days, device=device) % 7
+             weekday_idx = torch.where(day_of_week < 5)[0]
+             weekend_idx = torch.where(day_of_week >= 5)[0]
+             weekday_pattern = real_reshaped.index_select(1, weekday_idx).mean(dim=1)  # [B, 24]
+             weekend_pattern = real_reshaped.index_select(1, weekend_idx).mean(dim=1)  # [B, 24]
+             daily_pattern = torch.cat([weekday_pattern, weekend_pattern], dim=1)  # [B, 48]
+
+         # Sample noise
+         x0 = torch.randn_like(x1)
+
+         # Sample time
+         t = torch.rand(B, 1, device=device)
+
+         # Interpolation
+         x_t = t * x1 + (1 - t) * x0
+
+         # Target velocity
+         v_target = x1 - x0
+
+         # Predict velocity
+         v_pred = model(x_t, t, spatial_cond_level2, level=2, daily_pattern=daily_pattern)
+
+         # Flow Matching loss
+         loss_fm = F.mse_loss(v_pred, v_target)
+
+         # Boundary Loss for Level 2
+         x1_est = x_t + (1 - t) * v_pred
+         loss_boundary = self.compute_boundary_loss(x1_est)
+
+         return loss_fm, loss_boundary
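The day-of-week averaging loop above (mean over days dow, dow+7, dow+14, dow+21) is equivalent to a single reshape-and-mean over a week axis, since day index d = 7·w + dow. A NumPy sketch verifying the equivalence on random data:

```python
import numpy as np

rng = np.random.default_rng(1)
B, n_weeks, steps_per_day = 2, 4, 24
x = rng.normal(size=(B, n_weeks * 7, steps_per_day))   # [B, 28 days, 24 hours]

# Loop form, as in compute_level2_loss: average the same weekday across weeks
weekly_loop = np.stack(
    [x[:, [dow + 7 * w for w in range(n_weeks)]].mean(axis=1) for dow in range(7)],
    axis=1,
).reshape(B, 7 * steps_per_day)                        # [B, 168]

# Vectorized form: split the day axis into (week, day-of-week), average over weeks
weekly_vec = x.reshape(B, n_weeks, 7, steps_per_day).mean(axis=1).reshape(B, 7 * steps_per_day)

print(np.allclose(weekly_loop, weekly_vec))            # True
```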
+
+     def compute_level3_loss(
+         self,
+         model: HierarchicalFlowMatchingV4,
+         real_traffic: torch.Tensor,
+         spatial_cond_level3: torch.Tensor,
+         spatial_cond_level2: torch.Tensor,
+         spatial_cond_level1: torch.Tensor,
+         peak_hour_gt: torch.Tensor,  # Explicit peak ground truth
+         daily_pattern: Optional[torch.Tensor] = None,
+         weekly_trend: Optional[torch.Tensor] = None,
+         use_teacher_forcing: bool = True,
+         n_steps_generate: int = 10,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:  # [MODIFIED] Returns a tuple
+         """
+         Level 3 (Residual over 672 hours) Flow Matching loss.
+         Models fine-grained hourly fluctuations after removing periodic trends.
+         """
+         B = real_traffic.shape[0]
+         device = real_traffic.device
+
+         steps_per_day = 24
+         n_days = 28
+         n_weeks = 4
+         real_reshaped = real_traffic.reshape(B, n_days, steps_per_day)  # [B, 28, 24]
+
+         # Ground-truth weekly pattern (168)
+         weekly_days = []
+         for dow in range(7):
+             idx = torch.tensor([dow + 7 * w for w in range(n_weeks)], device=device, dtype=torch.long)
+             weekly_days.append(real_reshaped.index_select(1, idx).mean(dim=1))  # [B, 24]
+         weekly_pattern_gt = torch.stack(weekly_days, dim=1).reshape(B, 7 * steps_per_day)  # [B, 168]
+
+         # Get daily pattern and weekly trend
+         if use_teacher_forcing:
+             # Teacher forcing
+             day_of_week = torch.arange(n_days, device=device) % 7
+             weekday_idx = torch.where(day_of_week < 5)[0]
+             weekend_idx = torch.where(day_of_week >= 5)[0]
+             weekday_pattern = real_reshaped.index_select(1, weekday_idx).mean(dim=1)  # [B, 24]
+             weekend_pattern = real_reshaped.index_select(1, weekend_idx).mean(dim=1)  # [B, 24]
+             daily_pattern = torch.cat([weekday_pattern, weekend_pattern], dim=1)  # [B, 48]
+             weekly_trend = weekly_pattern_gt
+         else:
+             if daily_pattern is None:
+                 with torch.no_grad():
+                     daily_pattern = model.generate_daily_pattern(
+                         spatial_cond_level1, n_steps=n_steps_generate
+                     )
+
+             if weekly_trend is None:
+                 with torch.no_grad():
+                     weekly_trend = model.generate_weekly_trend(
+                         daily_pattern, spatial_cond_level2, n_steps=n_steps_generate
+                     )
+
+         # Construct periodic component (coarse signal) from the weekly pattern
+         coarse_signal = weekly_trend.repeat(1, n_weeks)  # [B, 672]
+
+         # Target residual
+         x1 = real_traffic - coarse_signal  # [B, 672]
+
+         # Sample noise
+         x0 = 0.1 * torch.randn_like(x1)
+
+         # Sample time
+         t = torch.rand(B, 1, device=device)
+
+         # Interpolation
+         x_t = t * x1 + (1 - t) * x0
+
+         # Target velocity
+         v_target = x1 - x0
+
+         # Predict velocity (pass peak_hour_gt to the model)
+         v_pred = model(
+             x_t, t, spatial_cond_level3, level=3,
+             daily_pattern=daily_pattern,
+             weekly_trend=weekly_trend,
+             coarse_signal=coarse_signal,
+             peak_hour=peak_hour_gt
+         )
+
+         # Flow Matching loss
+         loss_fm = F.mse_loss(v_pred, v_target)
+
+         # Boundary Loss for Level 3: ensure that (coarse + residual) >= 0
+         residual_est = x_t + (1 - t) * v_pred
+         final_traffic_est = coarse_signal + residual_est
+         loss_boundary = self.compute_boundary_loss(final_traffic_est)
+
+         return loss_fm, loss_boundary
+
+     def compute_temporal_structure_loss(
+         self,
+         generated: torch.Tensor,
+         real: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Temporal structure preservation loss: match first-order differences.
+         """
+         d_gen = generated[..., 1:] - generated[..., :-1]
+         d_real = real[..., 1:] - real[..., :-1]
+         loss_deriv = F.mse_loss(d_gen, d_real)
+         return loss_deriv
+
+     def compute_multi_periodic_consistency_loss(
+         self,
+         generated: torch.Tensor,
+         real: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Multi-periodic consistency loss (daily and weekly mean patterns).
+         """
+         B = generated.shape[0]
+         device = generated.device
+
+         steps_per_day = 24
+         n_days = 28
+         n_weeks = 4
+
+         gen_days = generated.reshape(B, n_days, steps_per_day)  # [B, 28, 24]
+         real_days = real.reshape(B, n_days, steps_per_day)
+
+         # Daily mean pattern
+         gen_daily = gen_days.mean(dim=1)
+         real_daily = real_days.mean(dim=1)
+         loss_daily = F.mse_loss(gen_daily, real_daily)
+
+         # Weekly pattern
+         def weekly_pattern(x_days: torch.Tensor) -> torch.Tensor:
+             days = []
+             for dow in range(7):
+                 idx = torch.tensor([dow + 7 * w for w in range(n_weeks)], device=device, dtype=torch.long)
+                 days.append(x_days.index_select(1, idx).mean(dim=1))  # [B, 24]
+             return torch.stack(days, dim=1).reshape(B, 7 * steps_per_day)
+
+         gen_weekly = weekly_pattern(gen_days)
+         real_weekly = weekly_pattern(real_days)
+         loss_weekly = F.mse_loss(gen_weekly, real_weekly)
+
+         return loss_daily + loss_weekly
+
+     # Pearson correlation loss
+     def compute_correlation_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         """
+         Loss = 1 - correlation. Forces the model to match the waveform shape.
+         """
+         # 1. Center the data
+         pred_mean = pred - pred.mean(dim=1, keepdim=True)
+         target_mean = target - target.mean(dim=1, keepdim=True)
+
+         # 2. Normalize
+         pred_norm = torch.norm(pred_mean, p=2, dim=1) + 1e-8
+         target_norm = torch.norm(target_mean, p=2, dim=1) + 1e-8
+
+         # 3. Cosine similarity of the centered signals (i.e., correlation after mean-shifting)
+         cosine_sim = (pred_mean * target_mean).sum(dim=1) / (pred_norm * target_norm)
+
+         # 4. Loss = 1 - correlation
+         return 1.0 - cosine_sim.mean()
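Cosine similarity of mean-centered signals is exactly the Pearson correlation coefficient, so the loss above is 1 − r per sample. A NumPy re-implementation checked against `np.corrcoef` (epsilon placement differs trivially from the torch version):

```python
import numpy as np

def correlation_loss(pred, target, eps=1e-8):
    """1 - cosine similarity of centered signals == 1 - Pearson r (per sample)."""
    p = pred - pred.mean(axis=1, keepdims=True)
    t = target - target.mean(axis=1, keepdims=True)
    cos = (p * t).sum(axis=1) / (np.linalg.norm(p, axis=1) * np.linalg.norm(t, axis=1) + eps)
    return 1.0 - cos.mean()

rng = np.random.default_rng(2)
pred = rng.normal(size=(1, 168))
target = 0.5 * pred + rng.normal(scale=0.3, size=(1, 168))   # correlated target

r = np.corrcoef(pred[0], target[0])[0, 1]
print(np.isclose(correlation_loss(pred, target), 1.0 - r, atol=1e-6))  # True
```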
+
+     def forward(
+         self,
+         model: HierarchicalFlowMatchingV4,
+         real_traffic: torch.Tensor,
+         spatial_cond: Dict[str, torch.Tensor] | torch.Tensor,
+         fusion_method: str = 'generative',
+         lambda_level1: float = 1.0,
+         lambda_level2: float = 1.0,
+         lambda_level3: float = 1.0,
+         lambda_temporal: float = 0.1,
+         lambda_periodic: float = 0.1,
+         lambda_corr: float = 0.5,
+         lambda_boundary: float = 1.0,
+         lambda_bias: float = 1.0,
+         teacher_forcing_ratio: float = 1.0,
+         n_steps_generate: int = 10,
+         **kwargs
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Compute the combined hierarchical loss.
+         """
+         if isinstance(spatial_cond, torch.Tensor):
+             # Compatibility fallback
+             spatial_cond = {
+                 'level1_cond': spatial_cond,
+                 'level2_cond': spatial_cond,
+                 'level3_cond': spatial_cond,
+                 'pred_peak_logits': None
+             }
+
+         # ---------------------------------------------------------------------
+         # 1. Derive ground-truth peak hour
+         # ---------------------------------------------------------------------
+         # Reshape to [B, days, 24] -> mean daily pattern -> argmax
+         B = real_traffic.shape[0]
+         avg_daily = real_traffic.reshape(B, -1, 24).mean(dim=1)
+         peak_hour_gt = avg_daily.argmax(dim=1)  # [B] (0-23)
+
+         # 2. Auxiliary classification loss
+         pred_peak_logits = spatial_cond.get('pred_peak_logits', None)
+         if pred_peak_logits is not None:
+             loss_peak_cls = F.cross_entropy(pred_peak_logits, peak_hour_gt)
+         else:
+             loss_peak_cls = torch.tensor(0.0, device=real_traffic.device)
+
+         # Determine teacher forcing
+         use_tf = torch.rand(1).item() < teacher_forcing_ratio
+
+         # ---------------------------------------------------------------------
+         # 3. Compute level losses (FM + boundary)
+         # ---------------------------------------------------------------------
+         loss_l1_fm, loss_l1_bound = self.compute_level1_loss(
+             model, real_traffic, spatial_cond['level1_cond']
+         )
+         loss_l2_fm, loss_l2_bound = self.compute_level2_loss(
+             model,
+             real_traffic,
+             spatial_cond_level2=spatial_cond['level2_cond'],
+             spatial_cond_level1=spatial_cond['level1_cond'],
+             use_teacher_forcing=use_tf,
+             n_steps_generate=n_steps_generate,
+         )
+         # Pass peak_hour_gt
+         loss_l3_fm, loss_l3_bound = self.compute_level3_loss(
+             model,
+             real_traffic,
+             spatial_cond_level3=spatial_cond['level3_cond'],
+             spatial_cond_level2=spatial_cond['level2_cond'],
+             spatial_cond_level1=spatial_cond['level1_cond'],
+             peak_hour_gt=peak_hour_gt,  # Explicit GT
+             use_teacher_forcing=use_tf,
+             n_steps_generate=n_steps_generate,
+         )
+
+         # ---------------------------------------------------------------------
+         # 4. Auxiliary losses (temporal, periodic, bias)
+         # ---------------------------------------------------------------------
+         loss_temporal = torch.tensor(0.0, device=real_traffic.device)
+         loss_periodic = torch.tensor(0.0, device=real_traffic.device)
+         loss_bias = torch.tensor(0.0, device=real_traffic.device)
+         loss_corr = torch.tensor(0.0, device=real_traffic.device)
+
+         # Increase generation sampling frequency:
+         # if lambda_corr is significant (> 0.1), raise the sampling probability
+         # from 30% to 60% so the correlation loss is computed more often
+         prob_threshold = 0.6 if lambda_corr > 0.1 else 0.3
+
+         # Only compute generation-based losses occasionally to save time
+         should_compute_gen_losses = (lambda_temporal > 0 or lambda_periodic > 0 or lambda_bias > 0 or lambda_corr > 0)
+
+         if should_compute_gen_losses and torch.rand(1).item() < prob_threshold:
+             with torch.no_grad():
+                 # Must provide peak_hour for generation
+                 generated, _ = model.generate_hierarchical(
+                     spatial_cond,
+                     peak_hour=peak_hour_gt,
+                     n_steps_per_level=n_steps_generate
+                 )
+
+             if lambda_corr > 0:
+                 loss_corr = self.compute_correlation_loss(generated, real_traffic)
+
+             if lambda_temporal > 0:
+                 loss_temporal = self.compute_temporal_structure_loss(generated, real_traffic)
+
+             if lambda_periodic > 0:
+                 loss_periodic = self.compute_multi_periodic_consistency_loss(generated, real_traffic)
+
+             if lambda_bias > 0:
+                 loss_bias = self.compute_bias_loss(generated, real_traffic)
+
+         # ---------------------------------------------------------------------
+         # 5. Combined loss
+         # ---------------------------------------------------------------------
+
+         # FM loss
+         loss_fm_total = (
+             lambda_level1 * loss_l1_fm +
+             lambda_level2 * loss_l2_fm +
+             lambda_level3 * loss_l3_fm
+         )
+
+         # Boundary loss
+         loss_boundary_total = lambda_boundary * (loss_l1_bound + loss_l2_bound + loss_l3_bound)
+
+         # Bias loss
+         loss_bias_total = lambda_bias * loss_bias
+
+         # Peak classification weight (static 5.0 for now)
+         lambda_peak = 5.0
+
+         total_loss = (
+             loss_fm_total +
+             loss_boundary_total +
+             loss_bias_total +
+             lambda_temporal * loss_temporal +
+             lambda_periodic * loss_periodic +
+             lambda_peak * loss_peak_cls +
+             lambda_corr * loss_corr  # added to the total loss
+         )
+
+         return {
+             'loss_level1': loss_l1_fm,
+             'loss_level2': loss_l2_fm,
+             'loss_level3': loss_l3_fm,
+             'loss_boundary': loss_boundary_total,
+             'loss_bias': loss_bias_total,
+             'loss_temporal': loss_temporal,
+             'loss_periodic': loss_periodic,
+             'loss_peak_cls': loss_peak_cls,
+             'loss_corr': loss_corr,
+             'loss_total': total_loss,
+         }
+
+
+ # =============================================================================
+ # Complete Hierarchical Flow Matching Model with Encoder
+ # =============================================================================
+
+ class HierarchicalFlowMatchingSystemV4(nn.Module):
+     """
+     Complete system combining:
+     - Multi-modal spatial encoder
+     - Hierarchical Flow Matching model
+     """
+
+     def __init__(
+         self,
+         spatial_dim: int = 192,
+         hidden_dim: int = 256,
+         poi_dim: int = 20,
+         n_layers_level3: int = 6,
+         fusion_method: Literal['generative', 'contrastive'] = 'generative'  # Default
+     ):
+         super().__init__()
+         self.fusion_method = fusion_method
+         self.spatial_dim = spatial_dim
+
+         # 1. Environment encoder (multi-modal)
+         self.spatial_encoder = MultiModalSpatialEncoderV4(spatial_dim, poi_dim)
+
+         # NOTE: no TrafficCLIPEncoder in generative mode
+         self.traffic_encoder = None
+
+         # 2. Flow Matching generative model
+         self.fm_model = HierarchicalFlowMatchingV4(spatial_dim, hidden_dim, n_layers_level3)
+
+         # 3. Loss
+         self.loss_fn = HierarchicalFlowMatchingLoss()
+
+     def forward(self, batch: Dict, mode: str = 'train', loss_cfg: Optional[Dict] = None) -> Dict:
+         """
+         Args:
+             batch: dict with spatial and traffic data
+             mode: 'train' or 'generate'
+         Returns:
+             outputs: dict with losses or generated samples
+         """
+         # Encode spatial features
+         spatial_cond_dict = self.spatial_encoder(batch)
+         loss_cfg = loss_cfg or {}
+
+         if mode == 'train':
+             real_traffic = batch['traffic_seq']
+
+             # Calculate losses
+             losses = self.loss_fn(
+                 model=self.fm_model,
+                 real_traffic=real_traffic,
+                 spatial_cond=spatial_cond_dict,
+                 fusion_method=self.fusion_method,
+                 **loss_cfg
+             )
+             return {'losses': losses}
+
+         elif mode == 'generate':
+             # Inference logic: explicit peak conditioning
+             # 1. Use the auxiliary head to predict the peak location
+             pred_logits = spatial_cond_dict['pred_peak_logits']
+             pred_peak_hour = pred_logits.argmax(dim=1)  # [B]
+
+             # 2. Allow manual override if 'manual_peak_hour' is in the batch
+             if 'manual_peak_hour' in batch:
+                 pred_peak_hour = batch['manual_peak_hour']
+
+             # Generate hierarchical samples
+             generated, intermediates = self.fm_model.generate_hierarchical(
+                 spatial_cond_dict,
+                 peak_hour=pred_peak_hour,
+                 n_steps_per_level=loss_cfg.get('n_steps_generate', 50),
+             )
+             return {'generated': generated, 'intermediates': intermediates, 'pred_peak_hour': pred_peak_hour}
+
+         else:
+             raise ValueError(f"Unknown mode: {mode}")
+
+
+ # =============================================================================
+ # Trainer
+ # =============================================================================
+
+ class HierarchicalFlowMatchingTrainerV4:
+     """
+     Trainer for Hierarchical Flow Matching V4.
+     """
+
+     def __init__(
+         self,
+         model: HierarchicalFlowMatchingSystemV4,
+         train_loader: DataLoader,
+         val_loader: DataLoader,
+         lr: float = 1e-4,
+         weight_decay: float = 0.01,
+         checkpoint_dir: str = "checkpoints_hfm_v4",
+         lambda_level1: float = 1.0,
+         lambda_level2: float = 1.0,
+         lambda_level3: float = 1.0,
+         lambda_temporal: float = 0.1,
+         lambda_periodic: float = 0.1,
+         lambda_boundary: float = 1.0,
+         lambda_bias: float = 1.0,
+         lambda_corr: float = 0.5,
+         warmup_epochs: int = 5,
+     ):
+         self.model = model
+         self.train_loader = train_loader
+         self.val_loader = val_loader
+         self.checkpoint_dir = checkpoint_dir
+
+         # Loss weights
+         self.loss_cfg = {
+             'lambda_level1': lambda_level1,
+             'lambda_level2': lambda_level2,
+             'lambda_level3': lambda_level3,
+             'lambda_temporal': lambda_temporal,
+             'lambda_periodic': lambda_periodic,
+             'lambda_boundary': lambda_boundary,
+             'lambda_bias': lambda_bias,
+             'lambda_corr': lambda_corr,
+             'teacher_forcing_ratio': 1.0,
+             'n_steps_generate': 10
+         }
+
+         # Warmup
+         self.warmup_epochs = warmup_epochs
+         self.base_lr = lr
+
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.model = self.model.to(self.device)
+
+         # Optimizer
+         self.optimizer = torch.optim.AdamW(
+             self.model.parameters(),
+             lr=lr,
+             weight_decay=weight_decay,
+             betas=(0.9, 0.99),
+         )
+
+         os.makedirs(checkpoint_dir, exist_ok=True)
+
+         # Training history
+         self.history = {
+             'train_loss': [],
+             'val_loss': [],
+             'train_loss_level1': [],
+             'train_loss_level2': [],
+             'train_loss_level3': [],
+             'train_loss_temporal': [],
+             'train_loss_periodic': [],
+             'train_loss_peak_cls': [],
+             'train_loss_boundary': [],
+             'train_loss_bias': [],
+             'train_loss_corr': [],
+             'val_mae': [],
+             'val_corr': [],
+             'val_var_ratio': [],
+             'lr': [],
+         }
+
+     def get_lr_scale(self, epoch: int, total_epochs: int) -> float:
+         """Get learning rate scale with warmup and cosine decay."""
+         if epoch < self.warmup_epochs:
+             return (epoch + 1) / self.warmup_epochs
+         else:
+             progress = (epoch - self.warmup_epochs) / (total_epochs - self.warmup_epochs)
+             return 0.5 * (1 + np.cos(np.pi * progress))
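The schedule above ramps the scale linearly to 1.0 over `warmup_epochs`, then follows a cosine decay toward 0.0 for the remaining epochs. A standalone NumPy sketch of the same rule (the example epoch counts are made up):

```python
import numpy as np

def lr_scale(epoch, total_epochs, warmup_epochs=5):
    """Linear warmup to 1.0, then cosine decay to 0.0 (same rule as get_lr_scale)."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * (1 + np.cos(np.pi * progress))

print(lr_scale(0, 100))   # 0.2 (first warmup step)
print(lr_scale(4, 100))   # 1.0 (end of warmup, full base LR)
print(lr_scale(99, 100))  # near 0 at the final epoch
```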
+
+     def set_lr(self, scale: float):
+         """Set the learning rate."""
+         for param_group in self.optimizer.param_groups:
+             param_group['lr'] = self.base_lr * scale
+
+     def train_epoch(self, epoch: int, total_epochs: int) -> Dict[str, float]:
+         """Train one epoch."""
+         self.model.train()
+
+         # Set learning rate
+         lr_scale = self.get_lr_scale(epoch, total_epochs)
+         self.set_lr(lr_scale)
+         current_lr = self.optimizer.param_groups[0]['lr']
+
+         # Teacher forcing ratio
+         self.loss_cfg['teacher_forcing_ratio'] = max(0.5, 1.0 - epoch / (2 * total_epochs))
+
+         total_loss = 0.0
+         loss_level1 = 0.0
+         loss_level2 = 0.0
+         loss_level3 = 0.0
+         loss_temporal = 0.0
+         loss_periodic = 0.0
+         loss_peak_cls = 0.0
+         loss_boundary = 0.0
+         loss_bias = 0.0
+         loss_corr = 0.0
+
+         pbar = tqdm(self.train_loader, desc=f"Epoch {epoch + 1} [Train]")
+         for batch in pbar:
+             batch = {
+                 k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+                 for k, v in batch.items()
+             }
+
+             # Forward pass
+             self.optimizer.zero_grad()
+             output = self.model(
+                 batch,
+                 mode='train',
+                 loss_cfg=self.loss_cfg,
+             )
+             losses = output['losses']
+
+             # Total loss
+             total_batch_loss = losses['loss_total']
+
+             # Backward pass
+             total_batch_loss.backward()
+             torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+             self.optimizer.step()
+
+             # Accumulate losses
+             total_loss += total_batch_loss.item()
+             loss_level1 += losses['loss_level1'].item()
+             loss_level2 += losses['loss_level2'].item()
+             loss_level3 += losses['loss_level3'].item()
+             loss_temporal += losses.get('loss_temporal', torch.tensor(0.0)).item()
+             loss_periodic += losses.get('loss_periodic', torch.tensor(0.0)).item()
+             loss_peak_cls += losses.get('loss_peak_cls', torch.tensor(0.0)).item()
+             loss_boundary += losses.get('loss_boundary', torch.tensor(0.0)).item()
+             loss_bias += losses.get('loss_bias', torch.tensor(0.0)).item()
+             loss_corr += losses.get('loss_corr', torch.tensor(0.0)).item()
+
+             # Update progress bar
+             pbar.set_postfix({
+                 'loss': total_loss / (len(pbar) + 1),
+                 'corr': loss_corr / (len(pbar) + 1),
+                 'peak': loss_peak_cls / (len(pbar) + 1),
+                 'bnd': loss_boundary / (len(pbar) + 1),
+                 'bias': loss_bias / (len(pbar) + 1),
+                 'lr': f'{current_lr:.2e}',
+             })
+
+         n_batches = len(self.train_loader)
+         return {
+             'loss_total': total_loss / n_batches,
+             'loss_level1': loss_level1 / n_batches,
+             'loss_level2': loss_level2 / n_batches,
+             'loss_level3': loss_level3 / n_batches,
+             'loss_temporal': loss_temporal / n_batches,
+             'loss_periodic': loss_periodic / n_batches,
+             'loss_peak_cls': loss_peak_cls / n_batches,
+             'loss_boundary': loss_boundary / n_batches,
+             'loss_bias': loss_bias / n_batches,
+             'loss_corr': loss_corr / n_batches,
+             'lr': current_lr,
+         }
+
+     @torch.no_grad()
+     def validate(self, epoch: int) -> Dict[str, float]:
+         """Validate."""
+         self.model.eval()
+
+         total_loss = 0.0
+         all_mae = []
+         all_corr = []
+         all_var_ratio = []
+
+         pbar = tqdm(self.val_loader, desc=f"Epoch {epoch + 1} [Val]")
+         for batch in pbar:
+             batch = {
+                 k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+                 for k, v in batch.items()
+             }
+
+             # Loss
+             output = self.model(
+                 batch,
+                 mode='train',
+                 loss_cfg=self.loss_cfg,
+             )
+             losses = output['losses']
+             total_loss += losses['loss_total'].item()
+
+             # Generate samples
+             # Note: generate mode handles the peak_hour logic in System.forward
+             gen_output = self.model(
+                 batch,
+                 mode='generate',
+                 loss_cfg={'n_steps_generate': 50},
+             )
+             real = batch['traffic_seq'].cpu().numpy()
+             generated = gen_output['generated'].cpu().numpy()
+
+             # Metrics
+             mae = np.mean(np.abs(real - generated))
+             all_mae.append(mae)
+
+             # Variance ratio
+             real_var = np.var(real, axis=1).mean()
+             gen_var = np.var(generated, axis=1).mean()
+             var_ratio = gen_var / (real_var + 1e-8)
+             all_var_ratio.append(var_ratio)
+
+             # Correlation
+             for i in range(len(real)):
+                 r_std = np.std(real[i])
+                 g_std = np.std(generated[i])
+                 if r_std > 1e-6 and g_std > 1e-6:
+                     corr = np.corrcoef(real[i], generated[i])[0, 1]
+                     if not np.isnan(corr):
+                         all_corr.append(corr)
+
+         n_batches = len(self.val_loader)
+         return {
+             'loss_total': total_loss / n_batches,
+             'mae': np.mean(all_mae),
+             'correlation': np.mean(all_corr) if all_corr else 0.0,
+             'var_ratio': np.mean(all_var_ratio),
+         }
835
+
836
+ def train(self, epochs: int):
837
+ """Full training loop."""
838
+ print("=" * 80)
839
+ print("Hierarchical Flow Matching V4 - Training")
840
+ print(f"Fusion Method: {self.model.fusion_method}")
841
+ print("=" * 80)
842
+ print(f"Device: {self.device}")
843
+ print(f"Epochs: {epochs}")
844
+ print(f"Base learning rate: {self.base_lr:.2e}")
845
+ print("=" * 80)
846
+
+ # [Change 1] Initialize the two best-metric tracking variables
+ best_val_loss = float('inf')
+ best_val_corr = -1.0  # initialize correlation to -1
+
+ for epoch in range(epochs):
+ # Train
+ train_losses = self.train_epoch(epoch, epochs)
+
+ # Validate
+ val_losses = self.validate(epoch)
+
+ # Print summary
+ print(f"\nEpoch {epoch + 1}/{epochs}")
+ print(f" Train Loss: {train_losses['loss_total']:.6f}")
+ print(f" Peak Cls Loss: {train_losses['loss_peak_cls']:.6f}")
+ print(f" Boundary Loss: {train_losses['loss_boundary']:.6f}")
+ print(f" Bias Loss: {train_losses['loss_bias']:.6f}")
+ print(f" Val Loss: {val_losses['loss_total']:.6f}")
+ print(f" Val MAE: {val_losses['mae']:.4f}")
+ print(f" Val Correlation: {val_losses['correlation']:.4f}")
+ print(f" Val Var Ratio: {val_losses['var_ratio']:.4f}")
+
+ # Save history
+ self.history['train_loss'].append(train_losses['loss_total'])
+ self.history['val_loss'].append(val_losses['loss_total'])
+ self.history['train_loss_level1'].append(train_losses['loss_level1'])
+ self.history['train_loss_level2'].append(train_losses['loss_level2'])
+ self.history['train_loss_level3'].append(train_losses['loss_level3'])
+ self.history['train_loss_temporal'].append(train_losses['loss_temporal'])
+ self.history['train_loss_periodic'].append(train_losses['loss_periodic'])
+ self.history['train_loss_peak_cls'].append(train_losses['loss_peak_cls'])
+ self.history['train_loss_boundary'].append(train_losses['loss_boundary'])
+ self.history['train_loss_bias'].append(train_losses['loss_bias'])
+ self.history['val_mae'].append(val_losses['mae'])
+ self.history['val_corr'].append(val_losses['correlation'])
+ self.history['val_var_ratio'].append(val_losses['var_ratio'])
+ self.history['lr'].append(train_losses['lr'])
+
+ # Logic A: Save the model with the lowest loss (as the mathematically optimal fallback)
+ if val_losses['loss_total'] < best_val_loss:
+ best_val_loss = val_losses['loss_total']
+ self.save_checkpoint(epoch, val_losses, filename='best_loss_model.pt')
+ print(f" ✓ [Best Loss model saved! Loss: {best_val_loss:.4f}]")
+
+ # Logic B: Save the model with the highest correlation (as the business-practical best)
+ if val_losses['correlation'] > best_val_corr:
+ best_val_corr = val_losses['correlation']
+ self.save_checkpoint(epoch, val_losses, filename='best_corr_model.pt')
+ print(f" ★ [Best Correlation model saved! Corr: {best_val_corr:.4f}]")
+
+ # Always save the latest version
+ self.save_checkpoint(epoch, val_losses, filename='latest_model.pt')
+
+ # Save history
+ self.save_history()
+
+ print("\n" + "=" * 80)
+ print("Training Completed!")
+ print(f"Best validation loss: {best_val_loss:.6f}")
+ print(f"Best validation corr: {best_val_corr:.4f}")  # print the best correlation
+ print(f"Checkpoints saved to: {self.checkpoint_dir}")
+ print("=" * 80)
+
+ def save_checkpoint(self, epoch: int, losses: Dict, filename: str = 'best_model.pt'):
+ """Save checkpoint."""
+ checkpoint = {
+ 'epoch': epoch,
+ 'model_state_dict': self.model.state_dict(),
+ 'optimizer_state_dict': self.optimizer.state_dict(),
+ 'losses': losses,
+ 'history': self.history,
+ }
+
+ path = os.path.join(self.checkpoint_dir, filename)
+ torch.save(checkpoint, path)
+
+ def save_history(self):
+ """Save training history."""
+ path = os.path.join(self.checkpoint_dir, 'training_history.json')
+
+ history_serializable = {}
+ for key, values in self.history.items():
+ history_serializable[key] = [
+ float(v) if isinstance(v, (np.floating, np.integer)) else v
+ for v in values
+ ]
+
+ with open(path, 'w') as f:
+ json.dump(history_serializable, f, indent=2)
hierarchical_flow_matching_v4.py ADDED
@@ -0,0 +1,1019 @@
+ """
+ Hierarchical Flow Matching with Mamba/SSM Backbone (V4) - Generative Version
+ ============================================================================
+
+ Architecture: Pure Diffusion/Flow Matching (No GANs).
+ Fusion Method: Generative (Implicit alignment via conditional generation).
+
+ Core improvements over V3:
+ 1. Three-level cascaded Flow Matching architecture.
+ 2. Multi-modal spatial context encoding.
+ 3. Long-sequence modeling backbone.
+ 4. Explicit Peak Conditioning.
+ 5. Physical Constraints (Non-negative output enforced).
+
+ Author: Optimization Team
+ Date: 2026-01-21
+ """
+
+ import os
+ import sys
+ import math
+ import numpy as np
+ from typing import Dict, Optional, Tuple, List
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # =============================================================================
+ # Mamba/SSM Backbone for Long Sequence Modeling
+ # =============================================================================
+
+ @dataclass
+ class MambaConfig:
+ """Configuration for Mamba block."""
+ d_model: int = 256
+ d_state: int = 64
+ d_conv: int = 4
+ expand: int = 2
+ dt_rank: str = "auto"
+ dt_min: float = 0.001
+ dt_max: float = 0.1
+ dt_init: str = "random"
+ dt_scale: float = 1.0
+ dt_init_floor: float = 1e-4
+ bias: bool = True
+ conv_bias: bool = True
+ pscan: bool = True
+ use_cuda: bool = True
+
+
+ def _selective_scan_diagonal(
+ log_a: torch.Tensor, # [B, L, N]
+ b: torch.Tensor, # [B, L, N]
+ ) -> torch.Tensor:
+ """
+ Parallel (vectorized) diagonal linear recurrence:
+ h_t = a_t * h_{t-1} + b_t, h_{-1}=0
+ where a_t = exp(log_a_t), computed without Python loops.
+ """
+ # log_p[t] = sum_{i<=t} log_a[i]
+ log_p = torch.cumsum(log_a, dim=1) # [B, L, N]
+ inv_p = torch.exp(-log_p)
+ s = torch.cumsum(b * inv_p, dim=1) # [B, L, N]
+ h = torch.exp(log_p) * s
+ return h
+
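The closed-form used by `_selective_scan_diagonal` (h_t = exp(log_p_t) * sum_{i<=t} b_i * exp(-log_p_i)) can be sanity-checked against the naive sequential recurrence. A minimal standalone sketch, re-implementing the helper outside the diff for illustration:

```python
import torch

def selective_scan_diagonal(log_a, b):
    # Same math as the helper above: h_t = a_t * h_{t-1} + b_t with a_t = exp(log_a_t)
    log_p = torch.cumsum(log_a, dim=1)
    s = torch.cumsum(b * torch.exp(-log_p), dim=1)
    return torch.exp(log_p) * s

torch.manual_seed(0)
log_a = -torch.rand(2, 16, 4)   # log_a < 0, so a_t in (exp(-1), 1): stable decay
b = torch.randn(2, 16, 4)

# Naive sequential reference: h_t = a_t * h_{t-1} + b_t, h_{-1} = 0
h_ref = torch.zeros_like(b)
h = torch.zeros(2, 4)
for t in range(16):
    h = torch.exp(log_a[:, t]) * h + b[:, t]
    h_ref[:, t] = h

assert torch.allclose(selective_scan_diagonal(log_a, b), h_ref, atol=1e-5)
```

Note that `exp(-log_p)` grows with sequence length (log_a is negative), so this cumsum trick trades numerical headroom for parallelism; for the 672-step sequences here it stays well-behaved.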
+
+ class Mamba(nn.Module):
+ """
+ Mamba block for efficient long-sequence modeling.
+
+ Based on: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
+ Pure-PyTorch implementation (vectorized diagonal selective scan) for traffic
+ sequence generation (no external kernels / dependencies).
+ """
+
+ def __init__(self, config: MambaConfig):
+ super().__init__()
+ self.config = config
+
+ d_model = config.d_model
+ d_state = config.d_state
+ d_conv = config.d_conv
+ expand = config.expand
+
+ self.d_inner = int(expand * d_model)
+
+ # (1) Input projection: x -> (u, gate)
+ self.in_proj = nn.Linear(d_model, 2 * self.d_inner, bias=config.bias)
+
+ # (2) Depthwise conv for short-range mixing (Mamba-style local context)
+ self.dwconv = nn.Conv1d(
+ in_channels=self.d_inner,
+ out_channels=self.d_inner,
+ kernel_size=d_conv,
+ padding=d_conv - 1,
+ groups=self.d_inner,
+ bias=config.conv_bias,
+ )
+
+ # (3) Input-dependent SSM parameters (B, C, dt)
+ self.B_proj = nn.Linear(self.d_inner, d_state, bias=False)
+ self.C_proj = nn.Linear(self.d_inner, d_state, bias=False)
+ self.dt_proj = nn.Linear(self.d_inner, d_state, bias=True)
+
+ # Diagonal A (negative, stable)
+ self.A_log = nn.Parameter(torch.zeros(d_state))
+
+ # Skip connection from u (Mamba "D" term)
+ self.D = nn.Parameter(torch.ones(self.d_inner))
+
+ # (4) State -> inner -> model projections
+ self.out_state_proj = nn.Linear(d_state, self.d_inner, bias=False)
+ self.out_proj = nn.Linear(self.d_inner, d_model, bias=config.bias)
+
+ # Initialize FiLM-like stability: start close to identity
+ nn.init.zeros_(self.A_log)
+ nn.init.zeros_(self.dt_proj.weight)
+ nn.init.constant_(self.dt_proj.bias, math.log(math.expm1(0.01))) # softplus^-1
+ nn.init.zeros_(self.out_state_proj.weight)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: [B, L, D] input sequence
+ Returns:
+ y: [B, L, D] output sequence
+ """
+ B, L, _ = x.shape
+
+ # Input projection
+ u, gate = self.in_proj(x).chunk(2, dim=-1) # [B, L, d_inner] each
+
+ # Depthwise conv (causal-ish via padding then crop)
+ u_conv = self.dwconv(u.transpose(1, 2))[:, :, :L].transpose(1, 2) # [B, L, d_inner]
+ u_conv = F.silu(u_conv)
+
+ # Input-dependent SSM params
+ dt = F.softplus(self.dt_proj(u_conv)) # [B, L, d_state]
+ dt = dt.clamp(min=self.config.dt_min, max=self.config.dt_max)
+
+ B_t = self.B_proj(u_conv) # [B, L, d_state]
+ C_t = self.C_proj(u_conv) # [B, L, d_state]
+
+ # Diagonal state transition: a_t = exp(A * dt)
+ A = -torch.exp(self.A_log).view(1, 1, -1) # [1, 1, d_state]
+ log_a = A * dt # [B, L, d_state]
+ b = B_t * dt # [B, L, d_state]
+
+ # Selective scan (vectorized)
+ h = _selective_scan_diagonal(log_a, b) # [B, L, d_state]
+
+ # Output from states
+ y_state = h * C_t
+ y_inner = self.out_state_proj(y_state) # [B, L, d_inner]
+
+ # Skip + gate (Mamba-style)
+ y_inner = y_inner + u_conv * self.D.view(1, 1, -1)
+ y_inner = y_inner * torch.sigmoid(gate)
+
+ return self.out_proj(y_inner)
+
+
+ # =============================================================================
+ # Multi-Scale Dilated Convolution Backbone
+ # =============================================================================
+
+ class MultiScaleDilatedConv(nn.Module):
+ """
+ Multi-scale dilated convolution for capturing temporal patterns at different scales.
+
+ Receptive fields:
+ - Scale 1 (dilation=1): Daily patterns (24 hours)
+ - Scale 2 (example): Weekly patterns (7 days) -> hourly would be dilation=168
+ - Scale 3 (example): Longer cycles (e.g., 28 days) -> hourly would be dilation=672
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int = 3,
+ dilations: Optional[List[int]] = None,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ if dilations is None:
+ dilations = [1, 4, 16]
+ self.channels = channels
+ self.kernel_size = kernel_size
+ self.dilations = [int(d) for d in dilations if int(d) >= 1]
+ self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+
+ padding_base = (kernel_size - 1) // 2
+
+ # Depthwise-separable conv branches
+ self.branches = nn.ModuleList()
+ for d in self.dilations:
+ self.branches.append(
+ nn.Sequential(
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ dilation=d,
+ padding=padding_base * d,
+ groups=channels,
+ bias=True,
+ ),
+ nn.GELU(),
+ nn.Conv1d(channels, channels, kernel_size=1, bias=True),
+ )
+ )
+
+ # Fusion (token-wise MLP)
+ self.fusion = nn.Sequential(
+ nn.Linear(channels * len(self.dilations), channels * 2),
+ nn.GELU(),
+ nn.Dropout(0.1),
+ nn.Linear(channels * 2, channels),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: [B, L, C] input
+ Returns:
+ y: [B, L, C] output
+ """
+ x_t = x.transpose(1, 2) # [B, C, L]
+ outs = []
+ for branch in self.branches:
+ outs.append(branch(x_t).transpose(1, 2)) # [B, L, C]
+ y = torch.cat(outs, dim=-1) # [B, L, C * n_scales]
+ y = self.fusion(y)
+ return self.dropout(y)
+
+
+ # =============================================================================
+ # Hybrid Backbone: Mamba + Multi-Scale Dilated Conv
+ # =============================================================================
+
+ class HybridLongSequenceBackbone(nn.Module):
+ """
+ Hybrid backbone combining Mamba/SSM and multi-scale dilated convolutions.
+
+ Designed for efficient long-sequence modeling with multi-scale temporal patterns.
+ """
+
+ def __init__(
+ self,
+ d_model: int = 256,
+ n_layers: int = 4,
+ d_state: int = 64,
+ use_mamba: bool = True,
+ use_dilated_conv: bool = True,
+ dilations: Optional[List[int]] = None,
+ cond_dim: Optional[int] = None,
+ dropout: float = 0.1,
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.n_layers = n_layers
+ self.d_state = d_state
+ self.use_dilated_conv = use_dilated_conv
+ self.use_mamba = use_mamba
+ self.cond_dim = cond_dim
+
+ if dilations is None:
+ dilations = [1, 4, 16]
+
+ self.blocks = nn.ModuleList()
+ for _ in range(n_layers):
+ self.blocks.append(
+ _HybridBlock(
+ d_model=d_model,
+ d_state=d_state,
+ use_mamba=use_mamba,
+ use_dilated_conv=use_dilated_conv,
+ dilations=dilations,
+ cond_dim=cond_dim,
+ dropout=dropout,
+ )
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ t_emb: Optional[torch.Tensor] = None, # [B, D]
+ cond: Optional[torch.Tensor] = None, # [B, C]
+ ) -> torch.Tensor:
+ """
+ Args:
+ x: [B, L, D] input sequence
+ Returns:
+ y: [B, L, D] output sequence
+ """
+ for block in self.blocks:
+ x = block(x, t_emb=t_emb, cond=cond)
+ return x
+
+
+ def _valid_num_groups(channels: int, requested: int) -> int:
+ g = min(requested, channels)
+ while g > 1 and (channels % g) != 0:
+ g -= 1
+ return max(g, 1)
+
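`_valid_num_groups` exists because `nn.GroupNorm` requires the channel count to be divisible by the group count; it walks down from the requested value to the nearest divisor. A minimal sketch of the same logic (re-stated here only for illustration):

```python
def valid_num_groups(channels: int, requested: int) -> int:
    # Largest g <= min(requested, channels) that divides channels evenly
    g = min(requested, channels)
    while g > 1 and channels % g != 0:
        g -= 1
    return max(g, 1)

assert valid_num_groups(256, 32) == 32  # 256 % 32 == 0: use as requested
assert valid_num_groups(96, 32) == 32   # 96 % 32 == 0
assert valid_num_groups(100, 32) == 25  # falls back to the largest divisor <= 32
assert valid_num_groups(7, 32) == 7     # clamped to channels, 7 % 7 == 0
```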
+
+ class _HybridBlock(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ d_state: int,
+ use_mamba: bool,
+ use_dilated_conv: bool,
+ dilations: List[int],
+ cond_dim: Optional[int],
+ dropout: float,
+ ):
+ super().__init__()
+ self.use_mamba = use_mamba
+ self.use_dilated_conv = use_dilated_conv
+ self.cond_dim = cond_dim
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+
+ self.mamba = (
+ Mamba(MambaConfig(d_model=d_model, d_state=d_state))
+ if use_mamba
+ else nn.Identity()
+ )
+ self.conv = (
+ MultiScaleDilatedConv(
+ channels=d_model,
+ kernel_size=3,
+ dilations=dilations,
+ dropout=dropout,
+ )
+ if use_dilated_conv
+ else nn.Identity()
+ )
+
+ self.ffn = nn.Sequential(
+ nn.Linear(d_model, d_model * 4),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(d_model * 4, d_model),
+ )
+
+ self.dropout = nn.Dropout(dropout)
+
+ self.film = FiLMModulation(d_model, cond_dim) if cond_dim is not None else None
+ self.ada_gn = AdaptiveGroupNorm(d_model, cond_dim) if cond_dim is not None else None
+
+ def _cond(self, h: torch.Tensor, cond: Optional[torch.Tensor]) -> torch.Tensor:
+ if cond is None or self.film is None or self.ada_gn is None:
+ return h
+ h = self.film(h, cond)
+ h = self.ada_gn(h, cond)
+ return h
+
+ def forward(
+ self,
+ x: torch.Tensor, # [B, L, D]
+ t_emb: Optional[torch.Tensor] = None, # [B, D]
+ cond: Optional[torch.Tensor] = None, # [B, C]
+ ) -> torch.Tensor:
+ # Mamba/SSM
+ h = self.norm1(x)
+ if t_emb is not None:
+ h = h + t_emb.unsqueeze(1)
+ h = self._cond(h, cond)
+ h = self.mamba(h)
+ x = x + self.dropout(h)
+
+ # Multi-scale dilated conv
+ if self.use_dilated_conv:
+ h = self.norm2(x)
+ if t_emb is not None:
+ h = h + 0.5 * t_emb.unsqueeze(1)
+ h = self._cond(h, cond)
+ h = self.conv(h)
+ x = x + self.dropout(h)
+
+ # FFN
+ h = self.norm3(x)
+ if t_emb is not None:
+ h = h + 0.5 * t_emb.unsqueeze(1)
+ h = self._cond(h, cond)
+ h = self.ffn(h)
+ x = x + self.dropout(h)
+ return x
+
+
+ # =============================================================================
+ # FiLM Modulation for Condition Injection
+ # =============================================================================
+
+ class FiLMModulation(nn.Module):
+ """
+ Feature-wise Linear Modulation (FiLM) for adaptive condition injection.
+
+ Dynamically modulates intermediate features based on spatial context.
+ """
+
+ def __init__(self, d_model: int, cond_dim: int):
+ super().__init__()
+
+ self.gamma_proj = nn.Linear(cond_dim, d_model)
+ self.beta_proj = nn.Linear(cond_dim, d_model)
+
+ # Start near identity modulation
+ nn.init.zeros_(self.gamma_proj.weight)
+ nn.init.zeros_(self.gamma_proj.bias)
+ nn.init.zeros_(self.beta_proj.weight)
+ nn.init.zeros_(self.beta_proj.bias)
+
+ def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: [B, L, D] features
+ cond: [B, C] condition
+ Returns:
+ y: [B, L, D] modulated features
+ """
+ gamma = self.gamma_proj(cond).unsqueeze(1) # [B, 1, D]
+ beta = self.beta_proj(cond).unsqueeze(1) # [B, 1, D]
+ return x * (1.0 + gamma) + beta
+
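The zero-initialization in `FiLMModulation` makes the layer an exact identity at the start of training, so conditioning is learned gradually rather than perturbing the backbone from step one. A self-contained sketch (an illustrative copy of the class above, not the repo module itself) demonstrating that property:

```python
import torch
from torch import nn

class FiLM(nn.Module):
    """Mirrors FiLMModulation above: x * (1 + gamma(cond)) + beta(cond)."""
    def __init__(self, d_model: int, cond_dim: int):
        super().__init__()
        self.gamma_proj = nn.Linear(cond_dim, d_model)
        self.beta_proj = nn.Linear(cond_dim, d_model)
        for proj in (self.gamma_proj, self.beta_proj):
            nn.init.zeros_(proj.weight)  # zero-init => gamma = beta = 0
            nn.init.zeros_(proj.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        gamma = self.gamma_proj(cond).unsqueeze(1)  # [B, 1, D]
        beta = self.beta_proj(cond).unsqueeze(1)    # [B, 1, D]
        return x * (1.0 + gamma) + beta

x = torch.randn(2, 10, 8)
cond = torch.randn(2, 3)
film = FiLM(d_model=8, cond_dim=3)
assert torch.allclose(film(x, cond), x)  # identity modulation before any training
```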
+
+ # =============================================================================
+ # Adaptive Group Normalization
+ # =============================================================================
+
+ class AdaptiveGroupNorm(nn.Module):
+ """
+ Adaptive Group Normalization (AdaGN) for condition-aware normalization.
+ """
+
+ def __init__(self, d_model: int, cond_dim: int, num_groups: int = 32):
+ super().__init__()
+ self.num_groups = _valid_num_groups(d_model, num_groups)
+ self.group_norm = nn.GroupNorm(self.num_groups, d_model, affine=False)
+
+ self.weight_proj = nn.Linear(cond_dim, d_model)
+ self.bias_proj = nn.Linear(cond_dim, d_model)
+ nn.init.zeros_(self.weight_proj.weight)
+ nn.init.zeros_(self.weight_proj.bias)
+ nn.init.zeros_(self.bias_proj.weight)
+ nn.init.zeros_(self.bias_proj.bias)
+
+ def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: [B, L, D] features
+ cond: [B, C] condition
+ Returns:
+ y: [B, L, D] normalized features
+ """
+ # Group norm
+ x_norm = self.group_norm(x.transpose(1, 2)).transpose(1, 2) # [B, L, D]
+
+ # Adaptive scaling
+ weight = self.weight_proj(cond).unsqueeze(1) # [B, 1, D]
+ bias = self.bias_proj(cond).unsqueeze(1) # [B, 1, D]
+
+ return x_norm * (1.0 + weight) + bias
+
+
+ class FourierTimeEmbedding(nn.Module):
+ """Gaussian Fourier features for diffusion/FM time t in [0,1]."""
+
+ def __init__(self, d_model: int, n_freqs: int = 64):
+ super().__init__()
+ self.n_freqs = n_freqs
+ self.W = nn.Parameter(torch.randn(n_freqs) * 10.0)
+ self.proj = nn.Sequential(
+ nn.Linear(2 * n_freqs, d_model),
+ nn.GELU(),
+ nn.Linear(d_model, d_model),
+ )
+
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
+ # t: [B, 1]
+ t = t.clamp(0.0, 1.0)
+ w = self.W.view(1, 1, -1) # [1, 1, F]
+ angles = 2 * math.pi * t.unsqueeze(-1) * w # [B, 1, F]
+ emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1).squeeze(1)
+ return self.proj(emb) # [B, D]
+
+
+ def sinusoidal_positional_embedding(
+ length: int,
+ dim: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ ) -> torch.Tensor:
+ """Standard sinusoidal positional embeddings [L, D]."""
+ position = torch.arange(length, device=device, dtype=dtype).unsqueeze(1) # [L, 1]
+ div_term = torch.exp(
+ torch.arange(0, dim, 2, device=device, dtype=dtype) * (-math.log(10000.0) / dim)
+ ) # [D/2]
+ pe = torch.zeros(length, dim, device=device, dtype=dtype)
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ return pe
+
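The sinusoidal embedding above follows the standard Transformer recipe (even indices sine, odd indices cosine, geometrically spaced frequencies). A short self-contained re-implementation, shown only to illustrate the expected shape and value range for the 672-step sequences used here:

```python
import math
import torch

def sin_pos_emb(length: int, dim: int) -> torch.Tensor:
    # pe[p, 2i] = sin(p / 10000^(2i/dim)), pe[p, 2i+1] = cos(p / 10000^(2i/dim))
    position = torch.arange(length, dtype=torch.float32).unsqueeze(1)       # [L, 1]
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim)
    )                                                                        # [D/2]
    pe = torch.zeros(length, dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

pe = sin_pos_emb(672, 256)
assert pe.shape == (672, 256)
assert torch.all(pe.abs() <= 1.0)                          # bounded features
assert torch.allclose(pe[0, 0::2], torch.zeros(128))       # sin(0) = 0 at position 0
assert torch.allclose(pe[0, 1::2], torch.ones(128))        # cos(0) = 1 at position 0
```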
+
+ # =============================================================================
+ # Level 1: Daily Pattern Flow Matching
+ # =============================================================================
+
+ class DailyPatternFM(nn.Module):
+ """
+ Level 1: Daily Pattern Flow Matching.
+
+ Learns to generate day-type templates for hourly traffic:
+ - weekday template (24 hours)
+ - weekend template (24 hours)
+
+ Output is a concatenation of two 24-hour patterns: [weekday | weekend] -> 48 dims.
+ """
+
+ def __init__(self, spatial_dim: int = 192, hidden_dim: int = 256, steps_per_day: int = 24):
+ super().__init__()
+ self.spatial_dim = spatial_dim
+ self.hidden_dim = hidden_dim
+ self.steps_per_day = steps_per_day
+ self.daytype_len = 2 * steps_per_day # weekday + weekend
+
+ self.time_embed = FourierTimeEmbedding(hidden_dim)
+
+ self.in_proj = nn.Linear(1, hidden_dim)
+ self.backbone = HybridLongSequenceBackbone(
+ d_model=hidden_dim,
+ n_layers=3,
+ d_state=64,
+ use_mamba=True,
+ use_dilated_conv=True,
+ dilations=[1, 2, 4, 8, 16],
+ cond_dim=spatial_dim,
+ dropout=0.1,
+ )
+ self.out_proj = nn.Linear(hidden_dim, 1)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ spatial_cond: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x: [B, 48] day-type templates = [weekday(24), weekend(24)]
+ t: [B, 1] time step
+ spatial_cond: [B, spatial_dim] spatial context
+ Returns:
+ v: [B, 48] velocity field
+ """
+ B, L = x.shape
+ assert L == self.daytype_len, f"DailyPatternFM expects L={self.daytype_len}, got {L}"
+
+ t_emb = self.time_embed(t) # [B, hidden_dim]
+ pos = sinusoidal_positional_embedding(L, self.hidden_dim, x.device, x.dtype) # [L, D]
+
+ h = self.in_proj(x.unsqueeze(-1)) # [B, L, D]
+ h = h + pos.unsqueeze(0)
+ h = self.backbone(h, t_emb=t_emb, cond=spatial_cond)
+ v = self.out_proj(h).squeeze(-1) # [B, L]
+ return v
+
+
+ # =============================================================================
+ # Level 2: Weekly Pattern Flow Matching
+ # =============================================================================
+
+ class WeeklyPatternFM(nn.Module):
+ """
+ Level 2: Weekly Pattern Flow Matching.
+
+ Learns to generate a weekly periodic pattern at hourly resolution:
+ weekly_pattern: 7 days × 24 hours = 168 time steps.
+
+ This level is conditioned on day-type templates from Level 1.
+ """
+
+ def __init__(self, spatial_dim: int = 192, hidden_dim: int = 256, steps_per_day: int = 24):
+ super().__init__()
+ self.spatial_dim = spatial_dim
+ self.hidden_dim = hidden_dim
+ self.steps_per_day = steps_per_day
+ self.week_len = 7 * steps_per_day
+ self.daytype_len = 2 * steps_per_day
+
+ self.time_embed = FourierTimeEmbedding(hidden_dim)
+
+ self.in_proj = nn.Linear(1, hidden_dim)
+ self.daily_token_proj = nn.Linear(1, hidden_dim)
+
+ self.daily_to_weekly_attn = nn.MultiheadAttention(
+ embed_dim=hidden_dim,
+ num_heads=8,
+ dropout=0.1,
+ batch_first=True,
+ )
+
+ self.backbone = HybridLongSequenceBackbone(
+ d_model=hidden_dim,
+ n_layers=3,
+ d_state=64,
+ use_mamba=True,
+ use_dilated_conv=True,
+ dilations=[1, 2, 4],
+ cond_dim=spatial_dim,
+ dropout=0.1,
+ )
+ self.out_proj = nn.Linear(hidden_dim, 1)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ daily_pattern: torch.Tensor,
+ spatial_cond: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x: [B, 168] weekly pattern
+ t: [B, 1] time step
+ daily_pattern: [B, 48] day-type templates (from Level 1)
+ spatial_cond: [B, spatial_dim] spatial context
+ Returns:
+ v: [B, 168] velocity field
+ """
+ B, Lw = x.shape
+ assert Lw == self.week_len, f"WeeklyPatternFM expects L={self.week_len}, got {Lw}"
+ Bd, Ld = daily_pattern.shape
+ assert Bd == B and Ld == self.daytype_len, (
+ f"WeeklyPatternFM expects daily_pattern [B,{self.daytype_len}], got {daily_pattern.shape}"
+ )
+
+ t_emb = self.time_embed(t) # [B, D]
+
+ pos_w = sinusoidal_positional_embedding(Lw, self.hidden_dim, x.device, x.dtype)
+ pos_d = sinusoidal_positional_embedding(Ld, self.hidden_dim, x.device, x.dtype)
+
+ week_tokens = self.in_proj(x.unsqueeze(-1)) + pos_w.unsqueeze(0) # [B, 168, D]
+ day_tokens = self.daily_token_proj(daily_pattern.unsqueeze(-1)) + pos_d.unsqueeze(0) # [B, 48, D]
+
+ # Explicitly condition on S^d via cross-attention (decouples day/week)
+ attn_out, _ = self.daily_to_weekly_attn(week_tokens, day_tokens, day_tokens)
+ week_tokens = week_tokens + attn_out
+
+ h = self.backbone(week_tokens, t_emb=t_emb, cond=spatial_cond)
+ v = self.out_proj(h).squeeze(-1) # [B, 168]
+ return v
+
+
+ # =============================================================================
+ # Level 3: Long-term Residual Flow Matching
+ # =============================================================================
+
+ class LongTermResidualFM(nn.Module):
+ """
+ Level 3: Long-term Residual Flow Matching.
+
+ Learns to generate fine residuals for the full sequence (672 time steps).
+ Uses Mamba + multi-scale dilated convolutions for efficient long-sequence modeling.
+
+ Explicitly conditioned on peak hour location to force peak generation.
+ """
+
+ def __init__(
+ self,
+ spatial_dim: int = 192,
+ hidden_dim: int = 256,
+ n_layers: int = 6,
+ steps_per_day: int = 24,
+ ):
+ super().__init__()
+ self.spatial_dim = spatial_dim
+ self.hidden_dim = hidden_dim
+ self.steps_per_day = steps_per_day
+ self.week_len = 7 * steps_per_day
+ self.daytype_len = 2 * steps_per_day
+
+ self.time_embed = FourierTimeEmbedding(hidden_dim)
+
+ # Explicit Peak Position Encoding
+ # Maps 0-23 hours to a hidden vector to serve as a strong condition
+ self.peak_embed = nn.Embedding(24, hidden_dim)
+
+ # Token-wise projection of multi-channel inputs
+ self.in_proj = nn.Linear(4, hidden_dim)
+
+ # Main backbone (Mamba + multi-scale dilated conv) for long sequences
+ self.backbone = HybridLongSequenceBackbone(
+ d_model=hidden_dim,
+ n_layers=n_layers,
+ d_state=128,
+ use_mamba=True,
+ use_dilated_conv=True,
+ # (local, daily, weekly) receptive fields at hourly resolution
+ dilations=[1, 2, 4, 8, 16, 24, 48, 168],
+ cond_dim=spatial_dim,
+ dropout=0.1,
+ )
+ self.out_proj = nn.Linear(hidden_dim, 1)
+
+ def _repeat_to_length(self, pattern: torch.Tensor, target_len: int) -> torch.Tensor:
+ # pattern: [B, P]
+ B, P = pattern.shape
+ reps = (target_len + P - 1) // P
+ tiled = pattern.repeat(1, reps)
+ return tiled[:, :target_len]
+
+ def _repeat_daytype_to_length(self, daytype: torch.Tensor, target_len: int) -> torch.Tensor:
+ """
+ Expand day-type templates (weekday/weekend) to a full 28-day hourly sequence.
+
+ Assumption (consistent with plot_traffic_decomposition*.py): sequence starts on Monday.
+ """
+ B, L = daytype.shape
+ assert L == self.daytype_len, f"Expected daytype_len={self.daytype_len}, got {L}"
+ steps = self.steps_per_day
+ weekday = daytype[:, :steps]
+ weekend = daytype[:, steps:]
+
+ n_days = target_len // steps
+ parts = []
+ for d in range(n_days):
+ dow = d % 7
+ parts.append(weekday if dow < 5 else weekend)
+ seq = torch.cat(parts, dim=1) # [B, n_days*steps]
+ if seq.shape[1] < target_len:
+ pad = torch.zeros(B, target_len - seq.shape[1], device=seq.device, dtype=seq.dtype)
+ seq = torch.cat([seq, pad], dim=1)
+ return seq[:, :target_len]
+
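The day-type tiling in `_repeat_daytype_to_length` (Monday start, days 0-4 weekday, days 5-6 weekend) can be verified with distinguishable templates. A minimal standalone sketch of the same tiling logic, for illustration:

```python
import torch

def repeat_daytype(daytype: torch.Tensor, target_len: int, steps: int = 24) -> torch.Tensor:
    # daytype: [B, 48] = [weekday(24) | weekend(24)]; sequence assumed to start on Monday
    weekday, weekend = daytype[:, :steps], daytype[:, steps:]
    parts = [weekday if (d % 7) < 5 else weekend for d in range(target_len // steps)]
    return torch.cat(parts, dim=1)[:, :target_len]

# weekday template = all zeros, weekend template = all ones
daytype = torch.cat([torch.zeros(1, 24), torch.ones(1, 24)], dim=1)
week = repeat_daytype(daytype, 7 * 24)

assert week.shape == (1, 168)
assert week[0, :5 * 24].sum() == 0    # Mon-Fri hours use the weekday template
assert week[0, 5 * 24:].sum() == 48   # Sat-Sun hours use the weekend template
```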
+ def forward(
+ self,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ coarse_signal: torch.Tensor,
+ daily_pattern: torch.Tensor,
+ weekly_trend: torch.Tensor,
+ spatial_cond: torch.Tensor,
+ peak_hour: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x: [B, 672] residual sequence
+ t: [B, 1] time step
+ coarse_signal: [B, 672] periodic component (tiled weekly pattern)
+ daily_pattern: [B, 48] day-type templates
+ weekly_trend: [B, 168] weekly pattern
+ spatial_cond: [B, spatial_dim] spatial context
+ peak_hour: [B] Integer tensor (0-23) indicating explicit peak location
+ Returns:
+ v: [B, 672] velocity field
+ """
+ B, L = x.shape
+ assert coarse_signal.shape == (B, L)
+ assert daily_pattern.shape == (B, self.daytype_len)
+ assert weekly_trend.shape == (B, self.week_len)
+
+ # Fuse Time Embedding with Peak Embedding
+ t_emb = self.time_embed(t) # [B, D]
+ peak_cond = self.peak_embed(peak_hour) # [B, D]
+
+ # Combine: Global time context + "Peak Attention" bias
+ global_cond = t_emb + peak_cond
+
+ pos = sinusoidal_positional_embedding(L, self.hidden_dim, x.device, x.dtype)
+
+ daily_rep = self._repeat_daytype_to_length(daily_pattern, L) # [B, L]
+ # weekly_trend is weekly pattern here: tile 168 -> 672 (4 weeks)
+ weekly_rep = self._repeat_to_length(weekly_trend, L) # [B, L]
+ weekly_delta = coarse_signal - daily_rep # [B, L]
+
+ # Token features: [residual, periodic, repeated_daytype, weekly_delta]
+ feats = torch.stack([x, coarse_signal, daily_rep, weekly_delta], dim=-1) # [B, L, 4]
+ h = self.in_proj(feats) + pos.unsqueeze(0)
+
+ # Pass combined global condition
+ h = self.backbone(h, t_emb=global_cond, cond=spatial_cond)
+
+ v = self.out_proj(h).squeeze(-1) # [B, L]
+ return v
+
+
+ # =============================================================================
+ # Complete Hierarchical Flow Matching Model
+ # =============================================================================
+
+ class HierarchicalFlowMatchingV4(nn.Module):
+ """
+ Complete Hierarchical Flow Matching model with three-level cascaded architecture.
+
+ Level 1: Daily Pattern FM
+ Level 2: Weekly Pattern FM (with daily conditioning)
+ Level 3: Long-term Residual FM (with daily + weekly conditioning + explicit peak)
+ """
+
+ def __init__(
+ self,
+ spatial_dim: int = 192,
+ hidden_dim: int = 256,
+ n_layers_level3: int = 6,
+ steps_per_day: int = 24,
+ ):
+ super().__init__()
+ self.spatial_dim = spatial_dim
+ self.hidden_dim = hidden_dim
+ self.steps_per_day = steps_per_day
+ self.week_len = 7 * steps_per_day
+ self.daytype_len = 2 * steps_per_day
+ # This repo's 672-length traffic is hourly: 28 days = 4 weeks.
+ self.seq_len = 672
+ self.n_weeks = self.seq_len // self.week_len
+
+ # Three-level FM
+ self.level1_fm = DailyPatternFM(spatial_dim, hidden_dim, steps_per_day=steps_per_day)
+ self.level2_fm = WeeklyPatternFM(spatial_dim, hidden_dim, steps_per_day=steps_per_day)
830
+ self.level3_fm = LongTermResidualFM(
831
+ spatial_dim, hidden_dim, n_layers_level3, steps_per_day=steps_per_day
832
+ )
833
+
834
+ def forward(
835
+ self,
836
+ x: torch.Tensor,
837
+ t: torch.Tensor,
838
+ spatial_cond: torch.Tensor,
839
+ level: int = 1,
840
+ daily_pattern: Optional[torch.Tensor] = None,
841
+ weekly_trend: Optional[torch.Tensor] = None,
842
+ coarse_signal: Optional[torch.Tensor] = None,
843
+ peak_hour: Optional[torch.Tensor] = None,
844
+ ) -> torch.Tensor:
845
+ """
846
+ Forward pass for a specific level.
847
+ """
848
+ if level == 1:
849
+ return self.level1_fm(x, t, spatial_cond)
850
+
851
+ elif level == 2:
852
+ assert daily_pattern is not None, "daily_pattern required for level 2"
853
+ return self.level2_fm(x, t, daily_pattern, spatial_cond)
854
+
855
+ elif level == 3:
856
+ assert daily_pattern is not None, "daily_pattern required for level 3"
857
+ assert weekly_trend is not None, "weekly_trend required for level 3"
858
+ assert coarse_signal is not None, "coarse_signal required for level 3"
859
+ assert peak_hour is not None, "peak_hour required for level 3 (Explicit Peak Conditioning)"
860
+ return self.level3_fm(x, t, coarse_signal, daily_pattern, weekly_trend, spatial_cond, peak_hour)
861
+
862
+ else:
863
+ raise ValueError(f"Invalid level: {level}")
864
+
865
+ # =========================================================================
866
+ # Generation Methods (ODE Solve)
867
+ # =========================================================================
868
+
869
+ def _unpack_level_conditions(
870
+ self,
871
+ spatial_cond: torch.Tensor | Dict[str, torch.Tensor],
872
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
873
+ if isinstance(spatial_cond, dict):
874
+ return (
875
+ spatial_cond["level1_cond"],
876
+ spatial_cond["level2_cond"],
877
+ spatial_cond["level3_cond"],
878
+ )
879
+ return spatial_cond, spatial_cond, spatial_cond
880
+
881
+ def generate_daily_pattern(
882
+ self,
883
+ spatial_cond: torch.Tensor | Dict[str, torch.Tensor],
884
+ n_steps: int = 50,
885
+ ) -> torch.Tensor:
886
+ """
887
+ Generate day-type templates (Level 1).
888
+ """
889
+ spatial_cond_level1, _, _ = self._unpack_level_conditions(spatial_cond)
890
+ B = spatial_cond_level1.shape[0]
891
+ device = spatial_cond_level1.device
892
+
893
+ x = torch.randn(B, self.daytype_len, device=device)
894
+ dt = 1.0 / n_steps
895
+
896
+ for step in range(n_steps):
897
+ t = torch.full((B, 1), step / n_steps, device=device)
898
+ v = self.level1_fm(x, t, spatial_cond_level1)
899
+ v = torch.clamp(v, -10.0, 10.0)
900
+ x = x + dt * v
901
+ x = torch.clamp(x, -10.0, 10.0)
902
+
903
+ return x
904
+
905
+ def generate_weekly_trend(
906
+ self,
907
+ daily_pattern: torch.Tensor,
908
+ spatial_cond: torch.Tensor | Dict[str, torch.Tensor],
909
+ n_steps: int = 50,
910
+ ) -> torch.Tensor:
911
+ """
912
+ Generate weekly pattern (Level 2).
913
+ """
914
+ _, spatial_cond_level2, _ = self._unpack_level_conditions(spatial_cond)
915
+ B = spatial_cond_level2.shape[0]
916
+ device = spatial_cond_level2.device
917
+
918
+ x = torch.randn(B, self.week_len, device=device)
919
+ dt = 1.0 / n_steps
920
+
921
+ for step in range(n_steps):
922
+ t = torch.full((B, 1), step / n_steps, device=device)
923
+ v = self.level2_fm(x, t, daily_pattern, spatial_cond_level2)
924
+ v = torch.clamp(v, -10.0, 10.0)
925
+ x = x + dt * v
926
+ x = torch.clamp(x, -10.0, 10.0)
927
+
928
+ return x
929
+
930
+ def generate_residual(
931
+ self,
932
+ coarse_signal: torch.Tensor,
933
+ daily_pattern: torch.Tensor,
934
+ weekly_trend: torch.Tensor,
935
+ spatial_cond: torch.Tensor | Dict[str, torch.Tensor],
936
+ peak_hour: torch.Tensor,
937
+ n_steps: int = 50,
938
+ ) -> torch.Tensor:
939
+ """
940
+ Generate fine residual (Level 3).
941
+ Requires peak_hour for explicit conditioning.
942
+ """
943
+ _, _, spatial_cond_level3 = self._unpack_level_conditions(spatial_cond)
944
+ B = spatial_cond_level3.shape[0]
945
+ device = spatial_cond_level3.device
946
+
947
+ x = 0.1 * torch.randn_like(coarse_signal, device=device)
948
+ dt = 1.0 / n_steps
949
+
950
+ for step in range(n_steps):
951
+ t = torch.full((B, 1), step / n_steps, device=device)
952
+ # Pass peak_hour
953
+ v = self.level3_fm(
954
+ x, t, coarse_signal, daily_pattern, weekly_trend, spatial_cond_level3, peak_hour
955
+ )
956
+ v = torch.clamp(v, -5.0, 5.0)
957
+ x = x + dt * v
958
+ x = torch.clamp(x, -5.0, 5.0)
959
+
960
+ return x
961
+
962
+ def generate_hierarchical(
963
+ self,
964
+ spatial_cond: torch.Tensor | Dict[str, torch.Tensor],
965
+ peak_hour: torch.Tensor, # Required input
966
+ n_steps_per_level: int = 50,
967
+ ) -> Tuple[torch.Tensor, Dict]:
968
+ """
969
+ Full hierarchical generation.
970
+ """
971
+ spatial_cond_level1, spatial_cond_level2, spatial_cond_level3 = self._unpack_level_conditions(
972
+ spatial_cond
973
+ )
974
+ B = spatial_cond_level3.shape[0]
975
+ device = spatial_cond_level3.device
976
+
977
+ # Level 1: Generate day-type templates
978
+ daily_pattern = self.generate_daily_pattern(spatial_cond_level1, n_steps_per_level)
979
+ daily_pattern = torch.clamp(daily_pattern, -10.0, 10.0)
980
+
981
+ # Level 2: Generate weekly pattern (168 hours)
982
+ weekly_pattern = self.generate_weekly_trend(
983
+ daily_pattern, spatial_cond_level2, n_steps_per_level
984
+ )
985
+ weekly_pattern = torch.clamp(weekly_pattern, -10.0, 10.0)
986
+
987
+ # Construct periodic component for 4 weeks (672 hours)
988
+ coarse_signal = weekly_pattern.repeat(1, self.n_weeks) # [B, 672]
989
+ coarse_signal = torch.clamp(coarse_signal, -10.0, 10.0)
990
+
991
+ # Level 3: Generate fine residual
992
+ # Pass peak_hour to residual generator
993
+ residual = self.generate_residual(
994
+ coarse_signal,
995
+ daily_pattern,
996
+ weekly_pattern,
997
+ spatial_cond_level3,
998
+ peak_hour=peak_hour,
999
+ n_steps=n_steps_per_level,
1000
+ )
1001
+ residual = torch.clamp(residual, -5.0, 5.0)
1002
+
1003
+ # Final output
1004
+ generated = coarse_signal + residual
1005
+
1006
+ # =========================================================================
1007
+ # [MODIFIED] Physical Constraint: Enforce non-negative traffic
1008
+ # Previously was: generated = torch.clamp(generated, -10.0, 10.0)
1009
+ # =========================================================================
1010
+ generated = torch.clamp(generated, min=0.0, max=10.0)
1011
+
1012
+ intermediates = {
1013
+ 'daily_pattern': daily_pattern,
1014
+ 'weekly_pattern': weekly_pattern,
1015
+ 'coarse_signal': coarse_signal,
1016
+ 'residual': residual,
1017
+ }
1018
+
1019
+ return generated, intermediates
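Each `generate_*` method above runs the same forward-Euler ODE solve from t=0 to t=1 with clamped velocities and states. A minimal stand-alone sketch with a scalar toy velocity field (the real field is the trained network, and the real state is a tensor):

```python
def euler_solve(x0, velocity, n_steps=50, clamp=10.0):
    """Forward-Euler integration of dx/dt = v(x, t) from t=0 to t=1, with clamping."""
    x, dt = x0, 1.0 / n_steps
    for step in range(n_steps):
        t = step / n_steps
        v = max(-clamp, min(clamp, velocity(x, t)))  # clamp the velocity
        x = max(-clamp, min(clamp, x + dt * v))      # Euler step, clamp the state
    return x

# Toy field pulling the sample toward 2.0 (stands in for the trained FM network).
x1 = euler_solve(x0=-1.0, velocity=lambda x, t: 2.0 - x)
assert abs(x1 - 2.0) < 1.2           # integrated to t=1, most of the gap is closed
assert abs(x1 - (-1.0)) > 1.5        # the sample has clearly moved from its start
```

More ODE steps (`n_steps_per_level`) trade generation speed for integration accuracy.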
index.html ADDED
@@ -0,0 +1,147 @@

1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Perception Layer - Data Alignment</title>
7
+
8
+ <script src="https://api.mapbox.com/mapbox-gl-js/v2.15.0/mapbox-gl.js"></script>
9
+ <link href="https://api.mapbox.com/mapbox-gl-js/v2.15.0/mapbox-gl.css" rel="stylesheet" />
10
+ <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
11
+
12
+ <link rel="stylesheet" href="style.css">
13
+ </head>
14
+ <body>
15
+
16
+ <div id="loading" class="loading-overlay">
17
+ <div class="spinner"></div>
18
+ <h2>SYSTEM INITIALIZING</h2>
19
+ <p>Loading Spatial & Temporal Data...</p>
20
+ </div>
21
+
22
+ <div class="sidebar">
23
+ <div class="header">
24
+ <h1>Overall</h1>
25
+
26
+ <div class="search-section">
27
+ <div class="search-container">
28
+ <input type="text" id="search-input" placeholder="Search ID..." autocomplete="off">
29
+ <button id="search-btn" class="cyber-btn-small">GO</button>
30
+ <button id="clear-search-btn" class="cyber-btn-small" title="Clear Markers">✕</button>
31
+ </div>
32
+
33
+ <div class="search-mode">
34
+ <input type="checkbox" id="keep-markers-check" checked>
35
+ <label for="keep-markers-check">Keep Previous Markers</label>
36
+ </div>
37
+ </div>
38
+ </div>
39
+
40
+ <div class="card">
41
+ <h2>📈 Temporal Modality</h2>
42
+ <div class="chart-container">
43
+ <canvas id="energyChart"></canvas>
44
+ </div>
45
+ </div>
46
+
47
+ <div class="card details-card">
48
+ <h2>📍 Spatial Metadata</h2>
49
+ <div class="stat-row">
50
+ <div><span class="label">Station ID</span> <span id="selected-id" class="value highlight">--</span></div>
51
+ <div><span class="label">Total Nodes</span> <span id="total-stations" class="value">--</span></div>
52
+ </div>
53
+ <div id="station-details" class="details-content">
54
+ <p class="placeholder-text">Waiting for selection...</p>
55
+ </div>
56
+ </div>
57
+ </div>
58
+
59
+ <div id="prediction-panel" class="sidebar-right">
60
+ <div class="header">
61
+ <h1>Traffic Prediction</h1>
62
+ <button id="close-pred-btn" class="cyber-btn-small">✕</button>
63
+ </div>
64
+ <div class="details-content">
65
+ <div class="stat-row" style="margin-bottom: 20px;">
66
+ <div><span class="label">Target Station ID</span> <span id="pred-station-id" class="value highlight" style="color: #f39c12;">--</span></div>
67
+ </div>
68
+
69
+ <div class="chart-container" style="height: 250px; position: relative;">
70
+ <canvas id="predictionChart"></canvas>
71
+ </div>
72
+
73
+ <!-- <div class="legend-box" style="margin-top: 20px; font-size: 0.9em; padding: 10px; background: rgba(0,0,0,0.3); border-radius: 4px;">
74
+ <div style="display:flex; align-items:center; margin-bottom:8px;">
75
+ <span style="display:inline-block; width:12px; height:12px; background:#00cec9; margin-right:10px; border-radius:50%;"></span>
76
+ <span>Real Data (Observed)</span>
77
+ </div>
78
+ <div style="display:flex; align-items:center;">
79
+ <span style="display:inline-block; width:12px; height:12px; background:#f39c12; margin-right:10px; border-radius:50%;"></span>
80
+ <span>AI Prediction (Model + POI)</span>
81
+ </div>
82
+ </div> -->
83
+
84
+ <div id="site-map-container" style="margin-top: 20px; display: none; border-top: 1px solid rgba(243, 156, 18, 0.3); padding-top: 15px;">
85
+ <h3 style="font-size: 13px; color: #f39c12; margin-bottom: 10px; text-transform: uppercase; letter-spacing: 1px;">
86
+ <span class="icon">📍</span> Optimal Site Analysis
87
+ </h3>
88
+
89
+ <div style="background: rgba(0,0,0,0.5); padding: 10px; border-radius: 6px; border: 1px solid rgba(255,255,255,0.05); display: flex; justify-content: center; align-items: center;">
90
+ <img id="site-map-img" src="" alt="LSI Site Map" style="width: 75%; border-radius: 4px; box-shadow: 0 0 10px rgba(0,0,0,0.5); display: block;">
91
+ </div>
92
+
93
+ <div id="site-explanation" class="cyber-explanation" style="display: none;"></div>
94
+
95
+ <p style="font-size: 0.7em; color: #aaa; margin-top: 8px; line-height: 1.4;">
96
+ * Heatmap calculated via spatial windowing.<br>
97
+ <span style="color:#2ecc71;">Green = High LSI (Stable)</span> | <span style="color:#e74c3c;">Red = High Volatility</span>
98
+ </p>
99
+ </div>
100
+
101
+ <p style="font-size: 0.75em; color: #666; margin-top: 20px; line-height: 1.4; border-top: 1px solid #333; padding-top: 10px;">
102
+ * Powered by <strong>Hierarchical Flow Matching V4</strong>.<br>
103
+ Utilizes Multi-modal Spatial Embeddings (POI, Satellite, Coordinates) for context-aware traffic forecasting.
104
+ </p>
105
+ </div>
106
+ </div>
107
+
108
+ <button id="toggle-left-btn" class="panel-toggle-btn left-toggle">◀</button>
109
+ <button id="toggle-right-btn" class="panel-toggle-btn right-toggle">▶</button>
110
+ <div class="main-content">
111
+ <div class="controls-container">
112
+ <button id="view-toggle" class="cyber-btn">
113
+ <span class="icon">👁️</span> View: 3D
114
+ </button>
115
+ <button id="data-toggle" class="cyber-btn">
116
+ <span class="icon">📡</span> Toggle Data
117
+ </button>
118
+
119
+ <button id="predict-toggle" class="cyber-btn" style="border-color: #f39c12; color: #f39c12;">
120
+ <span class="icon">🔮</span> Prediction Mode
121
+ </button>
122
+
123
+ <div class="filter-wrapper">
124
+ <button id="filter-btn" class="cyber-btn">
125
+ <span class="icon">🌪️</span> Filter Volatility
126
+ </button>
127
+ <div id="filter-menu" class="filter-menu"></div>
128
+ </div>
129
+ </div>
130
+
131
+ <div class="time-panel">
132
+ <button id="play-btn" class="cyber-btn play-control">▶</button>
133
+ <div class="slider-wrapper">
134
+ <input type="range" id="time-slider" min="0" max="671" value="0" step="1">
135
+ <div class="slider-ticks">
136
+ <span>Day 1</span><span>Day 7</span><span>Day 14</span><span>Day 21</span><span>Day 28</span>
137
+ </div>
138
+ </div>
139
+ <div id="time-display" class="digital-clock" style="min-width: 170px;">Day 01 - 00:00</div>
140
+ </div>
141
+
142
+ <div id="map"></div>
143
+ </div>
144
+
145
+ <script src="script.js"></script>
146
+ </body>
147
+ </html>
multimodal_spatial_encoder_v4.py ADDED
@@ -0,0 +1,645 @@
1
+ """
2
+ Multi-Modal Spatial Context Encoder (V4)
3
+ =========================================
4
+
5
+ Fuses POI features and satellite imagery into a unified spatial context embedding.
6
+
7
+ Key components:
8
+ 1. POI Encoder: MLP with learnable category importance weights
9
+ 2. Satellite Image Encoder: ResNet-18 with multi-scale features
10
+ 3. Coordinate Encoder: Fourier features with learnable frequencies
11
+ 4. Fusion Strategy: Cross-attention + adaptive gating
12
+ 5. Condition Injection: FiLM/AdaGN modulation
13
+ 6. [NEW] Auxiliary Head: Peak Hour Classification (Explicit Peak Prediction)
14
+
15
+ Author: Optimization Team
16
+ Date: 2026-01-21
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import numpy as np
23
+ from typing import Dict, Optional
24
+
25
+
26
+ # =============================================================================
27
+ # POI Encoder with Learnable Importance Weights
28
+ # =============================================================================
29
+
30
+ class POIEncoder(nn.Module):
31
+ """
32
+ POI encoder with learnable category importance weights.
33
+
34
+ Input: POI count/density vector [B, poi_dim]
35
+ Output: POI embedding [B, spatial_dim]
36
+ """
37
+
38
+ def __init__(self, poi_dim: int = 20, spatial_dim: int = 192):
39
+ super().__init__()
40
+ self.poi_dim = poi_dim
41
+ self.spatial_dim = spatial_dim
42
+
43
+ # Learnable category importance weights
44
+ self.category_importance = nn.Parameter(torch.ones(poi_dim))
45
+
46
+ # Category token embeddings (POI-Enhancer inspired: attention-weighted semantic fusion)
47
+ self.category_embed = nn.Embedding(poi_dim, spatial_dim)
48
+
49
+ # Deep encoder with residual connections
50
+ self.encoder = nn.Sequential(
51
+ nn.Linear(poi_dim, 256),
52
+ nn.GELU(),
53
+ nn.Dropout(0.1),
54
+ nn.Linear(256, 256),
55
+ nn.GELU(),
56
+ nn.LayerNorm(256),
57
+ nn.Linear(256, spatial_dim),
58
+ nn.LayerNorm(spatial_dim),
59
+ )
60
+
61
+ # Attention pooling over category tokens
62
+ self.token_attn = nn.Sequential(
63
+ nn.Linear(spatial_dim, 128),
64
+ nn.GELU(),
65
+ nn.Linear(128, 1),
66
+ )
67
+
68
+ # Gate between MLP vector and token-pooled vector
69
+ self.fuse_gate = nn.Sequential(
70
+ nn.Linear(spatial_dim * 2, 1),
71
+ nn.Sigmoid(),
72
+ )
73
+
74
+ def forward(self, poi_dist: torch.Tensor, return_tokens: bool = False):
75
+ """
76
+ Args:
77
+ poi_dist: [B, poi_dim] POI distribution
+ return_tokens: if True, also return per-category tokens [B, poi_dim, spatial_dim]
78
+ Returns:
79
+ features: [B, spatial_dim] POI embedding
80
+ """
81
+ # Apply learnable importance weights
82
+ weights = F.softmax(self.category_importance, dim=0)
83
+ weighted_poi = poi_dist * weights
84
+
85
+ # Log transform for count data (handles skewed distributions)
86
+ poi_log = torch.log1p(weighted_poi)
87
+
88
+ # (1) Global vector via MLP
89
+ features_mlp = self.encoder(poi_log)
90
+
91
+ # (2) Category tokens + attention pooling (attention score-weighted merging)
92
+ # token_scale: [B, poi_dim, 1]
93
+ token_scale = poi_log.unsqueeze(-1)
94
+ # tokens: [B, poi_dim, D]
95
+ tokens = token_scale * self.category_embed.weight.unsqueeze(0)
96
+ attn_logits = self.token_attn(tokens).squeeze(-1) # [B, poi_dim]
97
+ attn = F.softmax(attn_logits, dim=-1).unsqueeze(-1) # [B, poi_dim, 1]
98
+ features_tok = (tokens * attn).sum(dim=1) # [B, D]
99
+
100
+ # Combine (learned trade-off)
101
+ g = self.fuse_gate(torch.cat([features_mlp, features_tok], dim=-1)) # [B, 1]
102
+ features = g * features_mlp + (1.0 - g) * features_tok
103
+
104
+ if return_tokens:
105
+ return features, tokens
106
+ return features
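The importance weighting plus log transform at the top of `forward` can be illustrated in plain Python (toy 3-category counts; in the model the importance logits are learned parameters):

```python
import math

def weighted_log_poi(counts, importance):
    """softmax(importance) * counts, then log1p to compress skewed count data."""
    exps = [math.exp(w) for w in importance]
    z = sum(exps)
    weights = [e / z for e in exps]
    return [math.log1p(c * w) for c, w in zip(counts, weights)]

feats = weighted_log_poi(counts=[100.0, 0.0, 9.0], importance=[0.0, 0.0, 0.0])
# Equal importance logits -> each weight is 1/3; zero counts map to exactly 0.
assert feats[1] == 0.0
assert feats[0] > feats[2] > feats[1]
```

`log1p` keeps zero counts at zero while compressing the heavy tail of dense POI categories.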
107
+
108
+
109
+ # =============================================================================
110
+ # Satellite Image Encoder (ResNet-18 backbone)
111
+ # =============================================================================
112
+
113
+ class ResidualBlock(nn.Module):
114
+ """Basic residual block for ResNet."""
115
+
116
+ def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
117
+ super().__init__()
118
+ self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
119
+ self.bn1 = nn.BatchNorm2d(out_channels)
120
+ self.relu = nn.ReLU(inplace=True)
121
+ self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
122
+ self.bn2 = nn.BatchNorm2d(out_channels)
123
+
124
+ self.stride = stride
125
+ if stride != 1 or in_channels != out_channels:
126
+ self.shortcut = nn.Sequential(
127
+ nn.Conv2d(in_channels, out_channels, 1, stride=stride),
128
+ nn.BatchNorm2d(out_channels),
129
+ )
130
+ else:
131
+ self.shortcut = None
132
+
133
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
134
+ identity = x
135
+ out = self.conv1(x)
136
+ out = self.bn1(out)
137
+ out = self.relu(out)
138
+
139
+ out = self.conv2(out)
140
+ out = self.bn2(out)
141
+
142
+ if self.shortcut is not None:
143
+ identity = self.shortcut(x)
144
+
145
+ out = out + identity
146
+ out = self.relu(out)
147
+ return out
148
+
149
+
150
+ class SatelliteImageEncoder(nn.Module):
151
+ """
152
+ ResNet-18-based satellite image encoder with multi-scale feature extraction.
153
+
154
+ Input: Satellite image [B, 3, 64, 64]
155
+ Output: Image embedding [B, spatial_dim]
156
+ """
157
+
158
+ def __init__(self, spatial_dim: int = 192, n_heads: int = 8, token_layers: int = 2):
159
+ super().__init__()
160
+ self.spatial_dim = spatial_dim
161
+
162
+ # Initial layer
163
+ self.conv1 = nn.Sequential(
164
+ nn.Conv2d(3, 64, 7, stride=2, padding=3),
165
+ nn.BatchNorm2d(64),
166
+ nn.ReLU(inplace=True),
167
+ nn.MaxPool2d(3, stride=2, padding=1),
168
+ )
169
+
170
+ # ResNet blocks
171
+ self.layer1 = self._make_layer(64, 64, 2, stride=1)
172
+ self.layer2 = self._make_layer(64, 128, 2, stride=2)
173
+ self.layer3 = self._make_layer(128, 256, 2, stride=2)
174
+ self.layer4 = self._make_layer(256, 512, 2, stride=2)
175
+
176
+ # Multi-scale feature aggregation
177
+ self.pool1 = nn.AdaptiveAvgPool2d(1)
178
+ self.pool2 = nn.AdaptiveAvgPool2d(1)
179
+ self.pool3 = nn.AdaptiveAvgPool2d(1)
180
+ self.pool4 = nn.AdaptiveAvgPool2d(1)
181
+
182
+ # Learnable scale weights
183
+ self.scale_weights = nn.Parameter(torch.tensor([1.0, 1.0, 1.0, 1.0]))
184
+
185
+ # Final projection
186
+ self.proj = nn.Sequential(
187
+ nn.Linear(64 + 128 + 256 + 512, 384),
188
+ nn.GELU(),
189
+ nn.Dropout(0.1),
190
+ nn.Linear(384, spatial_dim),
191
+ nn.LayerNorm(spatial_dim),
192
+ )
193
+
194
+ # Region-level tokens (RemoteCLIP-inspired: patch/region awareness)
195
+ self.token_proj3 = nn.Linear(256, spatial_dim)
196
+ self.token_proj4 = nn.Linear(512, spatial_dim)
197
+ self.img_cls = nn.Parameter(torch.zeros(1, 1, spatial_dim))
198
+ enc_layer = nn.TransformerEncoderLayer(
199
+ d_model=spatial_dim,
200
+ nhead=n_heads,
201
+ dim_feedforward=spatial_dim * 4,
202
+ dropout=0.1,
203
+ activation="gelu",
204
+ batch_first=True,
205
+ norm_first=True,
206
+ )
207
+ self.token_mixer = nn.TransformerEncoder(enc_layer, num_layers=int(token_layers))
208
+
209
+ def _make_layer(self, in_channels: int, out_channels: int, blocks: int, stride: int):
210
+ layers = []
211
+ layers.append(ResidualBlock(in_channels, out_channels, stride))
212
+ for _ in range(1, blocks):
213
+ layers.append(ResidualBlock(out_channels, out_channels, 1))
214
+ return nn.Sequential(*layers)
215
+
216
+ def forward(self, x: torch.Tensor, return_tokens: bool = False):
217
+ """
218
+ Args:
219
+ x: [B, 3, 64, 64] satellite image
+ return_tokens: if True, also return region tokens [B, 20, spatial_dim]
220
+ Returns:
221
+ features: [B, spatial_dim] image embedding
222
+ """
223
+ x = self.conv1(x) # [B, 64, 16, 16]
224
+
225
+ x1 = self.layer1(x) # [B, 64, 16, 16]
226
+ x2 = self.layer2(x1) # [B, 128, 8, 8]
227
+ x3 = self.layer3(x2) # [B, 256, 4, 4]
228
+ x4 = self.layer4(x3) # [B, 512, 2, 2]
229
+
230
+ # Multi-scale pooling
231
+ f1 = self.pool1(x1).flatten(1) # [B, 64]
232
+ f2 = self.pool2(x2).flatten(1) # [B, 128]
233
+ f3 = self.pool3(x3).flatten(1) # [B, 256]
234
+ f4 = self.pool4(x4).flatten(1) # [B, 512]
235
+
236
+ # Weighted fusion
237
+ weights = F.softmax(self.scale_weights, dim=0)
238
+ fused = torch.cat([
239
+ f1 * weights[0],
240
+ f2 * weights[1],
241
+ f3 * weights[2],
242
+ f4 * weights[3],
243
+ ], dim=-1)
244
+
245
+ # Final projection
246
+ features = self.proj(fused)
247
+
248
+ if not return_tokens:
249
+ return features
250
+
251
+ # Build region tokens from intermediate feature maps (4x4 + 2x2 = 20 tokens)
252
+ t3 = x3.flatten(2).transpose(1, 2) # [B, 16, 256]
253
+ t4 = x4.flatten(2).transpose(1, 2) # [B, 4, 512]
254
+ t3 = self.token_proj3(t3) # [B, 16, D]
255
+ t4 = self.token_proj4(t4) # [B, 4, D]
256
+ tokens = torch.cat([t3, t4], dim=1) # [B, 20, D]
257
+
258
+ # Mix tokens with a tiny Transformer, include a [CLS] token
259
+ cls = self.img_cls.expand(tokens.shape[0], -1, -1)
260
+ tokens_with_cls = torch.cat([cls, tokens], dim=1) # [B, 21, D]
261
+ tokens_with_cls = self.token_mixer(tokens_with_cls)
262
+
263
+ # tokens_with_cls[:, 0] is CLS; keep both CLS and spatial tokens
264
+ cls_out = tokens_with_cls[:, 0] # [B, D]
265
+ spatial_tokens = tokens_with_cls[:, 1:] # [B, 20, D]
266
+
267
+ # Blend CLS with pooled global feature for stability
268
+ feat = 0.5 * features + 0.5 * cls_out
269
+ return feat, spatial_tokens
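The learnable multi-scale fusion reduces to a softmax over four scalar weights applied to the pooled features before concatenation. A plain-Python sketch (toy pooled features; `scale_weights` initialised to all 1.0 as in `__init__`):

```python
import math

def softmax(xs):
    m = max(xs)                          # subtract max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]

scale_weights = [1.0, 1.0, 1.0, 1.0]     # initial values from __init__
w = softmax(scale_weights)
assert all(abs(wi - 0.25) < 1e-9 for wi in w)  # equal init -> equal 0.25 weights

# Each scale's pooled feature vector is scaled by its weight, then concatenated.
f1, f2 = [4.0, 4.0], [8.0]               # toy pooled features of two scales
fused = [x * w[0] for x in f1] + [x * w[1] for x in f2]
assert fused == [1.0, 1.0, 2.0]
```

Training can thus up- or down-weight coarse vs. fine scales without changing the architecture.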
270
+
271
+
272
+ # =============================================================================
273
+ # Coordinate Encoder with Learnable Fourier Features
274
+ # =============================================================================
275
+
276
+ class CoordinateEncoder(nn.Module):
277
+ """
278
+ Coordinate encoder with learnable Fourier frequencies.
279
+
280
+ Input: Coordinates [B, 2] (latitude, longitude)
281
+ Output: Coordinate embedding [B, spatial_dim]
282
+ """
283
+
284
+ def __init__(self, coord_dim: int = 2, spatial_dim: int = 192):
285
+ super().__init__()
286
+ self.coord_dim = coord_dim
287
+ self.spatial_dim = spatial_dim
288
+
289
+ # Multi-scale learnable Fourier frequencies
290
+ n_freqs = 64
291
+ init_freqs = 2 ** torch.linspace(0, 8, n_freqs)
292
+ self.freqs = nn.Parameter(init_freqs)
293
+
294
+ fourier_dim = coord_dim * n_freqs * 2
295
+
296
+ # Deep encoder
297
+ self.encoder = nn.Sequential(
298
+ nn.Linear(fourier_dim + coord_dim, 512),
299
+ nn.GELU(),
300
+ nn.Dropout(0.1),
301
+ nn.Linear(512, 384),
302
+ nn.GELU(),
303
+ nn.LayerNorm(384),
304
+ nn.Linear(384, spatial_dim),
305
+ nn.LayerNorm(spatial_dim),
306
+ )
307
+
308
+ def forward(self, coords: torch.Tensor) -> torch.Tensor:
309
+ """
310
+ Args:
311
+ coords: [B, 2] coordinates
312
+ Returns:
313
+ features: [B, spatial_dim] coordinate embedding
314
+ """
315
+ # Fourier features with learnable frequencies
316
+ coords_scaled = coords.unsqueeze(-1) * self.freqs # [B, 2, n_freqs]
317
+ fourier = torch.cat([
318
+ torch.sin(coords_scaled * np.pi),
319
+ torch.cos(coords_scaled * np.pi),
320
+ ], dim=-1).flatten(-2) # [B, fourier_dim]
321
+
322
+ # Combine with raw coordinates
323
+ combined = torch.cat([coords, fourier], dim=-1)
324
+
325
+ # Encode
326
+ features = self.encoder(combined)
327
+
328
+ return features
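The Fourier lift in `forward` maps each coordinate through sin/cos at multiple frequencies (times pi). A minimal sketch with a handful of fixed frequencies; the model's 64 frequencies are learnable, initialised to `2 ** linspace(0, 8)`:

```python
import math

def fourier_features(coords, freqs):
    """Concatenate sin and cos of each coordinate times each frequency (times pi)."""
    feats = []
    for c in coords:
        for f in freqs:
            feats.append(math.sin(c * f * math.pi))
        for f in freqs:
            feats.append(math.cos(c * f * math.pi))
    return feats

freqs = [2.0 ** k for k in range(4)]  # 1, 2, 4, 8 (tiny stand-in for 64 freqs)
ff = fourier_features([0.5, -0.25], freqs)
assert len(ff) == 2 * len(freqs) * 2  # coord_dim * n_freqs * 2, as in __init__
assert abs(ff[0] - 1.0) < 1e-9        # sin(0.5 * 1 * pi) = 1
```

Low frequencies capture broad spatial position, high frequencies let the MLP distinguish nearby stations.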
329
+
330
+
331
+ # =============================================================================
332
+ # Cross-Attention Fusion Module
333
+ # =============================================================================
334
+
335
+ class CrossAttentionFusion(nn.Module):
336
+ """
337
+ Cross-attention fusion for multi-modal conditioning.
338
+
339
+ Fuses POI, satellite, and coordinate embeddings via multi-head attention.
340
+ """
341
+
342
+ def __init__(self, spatial_dim: int = 192, n_heads: int = 8):
343
+ super().__init__()
344
+ self.spatial_dim = spatial_dim
345
+ self.n_heads = n_heads
346
+
347
+ # Two rounds of cross-attention (vector mode: 3 tokens; token mode: CLS->context)
348
+ self.cross_attn1 = nn.MultiheadAttention(
349
+ spatial_dim, num_heads=n_heads, dropout=0.1, batch_first=True
350
+ )
351
+ self.cross_attn2 = nn.MultiheadAttention(
352
+ spatial_dim, num_heads=n_heads, dropout=0.1, batch_first=True
353
+ )
354
+
355
+ # Layer norms
356
+ self.norm1 = nn.LayerNorm(spatial_dim)
357
+ self.norm2 = nn.LayerNorm(spatial_dim)
358
+ self.norm3 = nn.LayerNorm(spatial_dim)
359
+ self.norm4 = nn.LayerNorm(spatial_dim)
360
+
361
+ # Feed-forward networks
362
+ self.ffn1 = nn.Sequential(
363
+ nn.Linear(spatial_dim, spatial_dim * 4),
364
+ nn.GELU(),
365
+ nn.Dropout(0.1),
366
+ nn.Linear(spatial_dim * 4, spatial_dim),
367
+ )
368
+ self.ffn2 = nn.Sequential(
369
+ nn.Linear(spatial_dim, spatial_dim * 4),
370
+ nn.GELU(),
371
+ nn.Dropout(0.1),
372
+ nn.Linear(spatial_dim * 4, spatial_dim),
373
+ )
374
+
375
+ # Adaptive gating for modality importance
376
+ self.gate = nn.Sequential(
377
+ nn.Linear(spatial_dim * 3, 256),
378
+ nn.GELU(),
379
+ nn.Linear(256, 3),
380
+ nn.Softmax(dim=-1),
381
+ )
382
+
383
+ # Token-mode: learnable fusion token
384
+ self.fusion_cls = nn.Parameter(torch.zeros(1, 1, spatial_dim))
385
+ self.token_out_gate = nn.Sequential(
386
+ nn.Linear(spatial_dim * 2, 1),
387
+ nn.Sigmoid(),
388
+ )
389
+
390
+ def forward(
391
+ self,
392
+ sat_feat: torch.Tensor,
393
+ poi_feat: torch.Tensor,
394
+ coord_feat: torch.Tensor,
395
+ sat_tokens: Optional[torch.Tensor] = None,
396
+ poi_tokens: Optional[torch.Tensor] = None,
397
+ coord_token: Optional[torch.Tensor] = None,
398
+ ) -> torch.Tensor:
399
+ """
400
+ Args:
401
+ sat_feat: [B, spatial_dim] satellite embedding
402
+ poi_feat: [B, spatial_dim] POI embedding
403
+ coord_feat: [B, spatial_dim] coordinate embedding
+ sat_tokens: optional [B, N_sat, spatial_dim] satellite region tokens (token mode)
+ poi_tokens: optional [B, poi_dim, spatial_dim] POI category tokens (token mode)
+ coord_token: optional [B, spatial_dim] coordinate token (token mode)
404
+ Returns:
405
+ fused: [B, spatial_dim] fused embedding
406
+ """
407
+ # ---------------------------------------------------------------------
408
+ # (A) Vector mode (backward-compatible): treat each modality as 1 token.
409
+ # ---------------------------------------------------------------------
410
+ if sat_tokens is None and poi_tokens is None and coord_token is None:
411
+ # Stack as sequence [B, 3, D]
412
+ modalities = torch.stack([sat_feat, poi_feat, coord_feat], dim=1)
413
+
414
+ # First round of cross-attention
415
+ attn_out1, _ = self.cross_attn1(modalities, modalities, modalities)
416
+ modalities = self.norm1(modalities + attn_out1)
417
+ ffn_out1 = self.ffn1(modalities)
418
+ modalities = self.norm2(modalities + ffn_out1)
419
+
420
+ # Second round
421
+ attn_out2, _ = self.cross_attn2(modalities, modalities, modalities)
422
+ modalities = self.norm3(modalities + attn_out2)
423
+ ffn_out2 = self.ffn2(modalities)
424
+ modalities = self.norm4(modalities + ffn_out2)
425
+
426
+ # Unpack
427
+ sat_out, poi_out, coord_out = modalities.unbind(dim=1)
428
+
429
+ # Adaptive gating
430
+ concat = torch.cat([sat_out, poi_out, coord_out], dim=-1)
431
+ weights = self.gate(concat) # [B, 3]
432
+
433
+ # Weighted fusion
434
+ fused = (
435
+ weights[:, 0:1] * sat_out +
436
+ weights[:, 1:2] * poi_out +
437
+ weights[:, 2:3] * coord_out
438
+ )
439
+
440
+ return fused
441
+
442
+ # ---------------------------------------------------------------------
443
+ # (B) Token mode: CLS attends over (sat tokens + poi tokens + coord token).
444
+ # RemoteCLIP-inspired region tokens + POI-Enhancer-inspired semantic tokens.
445
+ # ---------------------------------------------------------------------
446
+ B = sat_feat.shape[0]
447
+ context = []
448
+ if sat_tokens is not None:
449
+ context.append(sat_tokens)
450
+ else:
451
+ context.append(sat_feat.unsqueeze(1))
452
+
453
+ if poi_tokens is not None:
454
+ context.append(poi_tokens)
455
+ else:
456
+ context.append(poi_feat.unsqueeze(1))
457
+
458
+ if coord_token is not None:
459
+ context.append(coord_token.unsqueeze(1))
460
+ else:
461
+ context.append(coord_feat.unsqueeze(1))
462
+
463
+ context_tokens = torch.cat(context, dim=1) # [B, L, D]
464
+ cls = self.fusion_cls.expand(B, -1, -1) # [B, 1, D]
465
+
466
+ # Two rounds of CLS->context attention + FFN (Transformer-like)
467
+ attn1, _ = self.cross_attn1(cls, context_tokens, context_tokens)
468
+ cls = self.norm1(cls + attn1)
469
+ cls = self.norm2(cls + self.ffn1(cls))
470
+
471
+ attn2, _ = self.cross_attn2(cls, context_tokens, context_tokens)
472
+ cls = self.norm3(cls + attn2)
473
+ cls = self.norm4(cls + self.ffn2(cls))
474
+
475
+ cls_vec = cls.squeeze(1) # [B, D]
476
+
477
+ # Keep the original adaptive gating as a global shortcut, then learn to mix.
478
+ concat = torch.cat([sat_feat, poi_feat, coord_feat], dim=-1)
479
+ weights = self.gate(concat)
480
+ gated = (
481
+ weights[:, 0:1] * sat_feat +
482
+ weights[:, 1:2] * poi_feat +
483
+ weights[:, 2:3] * coord_feat
484
+ )
485
+ mix = self.token_out_gate(torch.cat([cls_vec, gated], dim=-1)) # [B, 1]
486
+ fused = mix * cls_vec + (1.0 - mix) * gated
487
+ return fused
488
+
489
+
490
+ # =============================================================================
491
+ # Multi-Scale Condition Generator
492
+ # =============================================================================
493
+
494
+ class MultiScaleConditionGenerator(nn.Module):
495
+ """
496
+ Generate stage-specific multi-scale conditions.
497
+
498
+ Produces different condition embeddings for each hierarchical level.
499
+ """
500
+
501
+ def __init__(self, spatial_dim: int = 192):
502
+ super().__init__()
503
+
504
+ # Level 1 (daily): global patterns
505
+ self.level1_proj = nn.Sequential(
506
+ nn.Linear(spatial_dim, 256),
507
+ nn.GELU(),
508
+ nn.Dropout(0.1),
509
+ nn.Linear(256, spatial_dim),
510
+ nn.LayerNorm(spatial_dim),
511
+ )
512
+
513
+ # Level 2 (weekly): periodic structure
514
+ self.level2_proj = nn.Sequential(
515
+ nn.Linear(spatial_dim, 256),
516
+ nn.GELU(),
517
+ nn.Dropout(0.1),
518
+ nn.Linear(256, spatial_dim),
519
+ nn.LayerNorm(spatial_dim),
520
+ )
521
+
522
+ # Level 3 (residual): fine details
523
+ self.level3_proj = nn.Sequential(
524
+ nn.Linear(spatial_dim, 384),
525
+ nn.GELU(),
526
+ nn.Dropout(0.1),
527
+ nn.Linear(384, spatial_dim),
528
+ nn.LayerNorm(spatial_dim),
529
+ )
530
+
531
+ def forward(self, base_condition: torch.Tensor) -> Dict[str, torch.Tensor]:
532
+ """Generate stage-specific conditions."""
533
+ return {
534
+ 'level1_cond': self.level1_proj(base_condition),
535
+ 'level2_cond': self.level2_proj(base_condition),
536
+ 'level3_cond': self.level3_proj(base_condition),
537
+ }
538
+
539
+
540
+ # =============================================================================
541
+ # Complete Multi-Modal Spatial Encoder
542
+ # =============================================================================
543
+
544
+ class MultiModalSpatialEncoderV4(nn.Module):
545
+ """
546
+ Complete multi-modal spatial encoder combining:
547
+ - POI features
548
+ - Satellite imagery
549
+ - Geographic coordinates
550
+ - Cross-attention fusion
551
+ - Multi-scale condition generation
552
+ - [NEW] Auxiliary Peak Hour Classification
553
+ """
554
+
555
+ def __init__(self, spatial_dim: int = 192, poi_dim: int = 20):
556
+ super().__init__()
557
+ self.spatial_dim = spatial_dim
558
+ self.poi_dim = poi_dim
559
+
560
+ # Individual encoders
561
+ self.poi_encoder = POIEncoder(poi_dim, spatial_dim)
562
+ self.satellite_encoder = SatelliteImageEncoder(spatial_dim)
563
+ self.coord_encoder = CoordinateEncoder(2, spatial_dim)
564
+
565
+ # Multi-modal fusion
566
+ self.fusion = CrossAttentionFusion(spatial_dim, n_heads=8)
567
+
568
+ # Multi-scale condition generation
569
+ self.multiscale_generator = MultiScaleConditionGenerator(spatial_dim)
570
+
571
+ # [NEW] Auxiliary Head: Peak Hour Prediction
572
+ # Predicts which hour (0-23) has the maximum traffic
573
+ self.peak_hour_classifier = nn.Sequential(
574
+ nn.Linear(spatial_dim, 128),
575
+ nn.GELU(),
576
+ nn.Dropout(0.1),
577
+ nn.Linear(128, 24) # 24 hours classification
578
+ )
579
+
580
+ def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
581
+ """
582
+ Args:
583
+ batch: dict with keys:
584
+ - 'satellite_img': [B, 3, 64, 64]
585
+ - 'poi_dist': [B, poi_dim]
586
+ - 'coords': [B, 2]
587
+ Returns:
588
+ outputs: dict with conditions and predicted peak logits
589
+ """
590
+ # Encode each modality (token-aware for stronger multi-modal fusion)
591
+ sat_feat, sat_tokens = self.satellite_encoder(batch['satellite_img'], return_tokens=True)
592
+ poi_feat, poi_tokens = self.poi_encoder(batch['poi_dist'], return_tokens=True)
593
+ coord_feat = self.coord_encoder(batch['coords'])
594
+
595
+ # Fuse modalities (CLS attends to region tokens + semantic tokens)
596
+ base_condition = self.fusion(
597
+ sat_feat,
598
+ poi_feat,
599
+ coord_feat,
600
+ sat_tokens=sat_tokens,
601
+ poi_tokens=poi_tokens,
602
+ coord_token=coord_feat,
603
+ )
604
+
605
+ # Generate multi-scale conditions
606
+ stage_conditions = self.multiscale_generator(base_condition)
607
+
608
+ # [NEW] Predict peak hour
609
+ pred_peak_logits = self.peak_hour_classifier(base_condition)
610
+
611
+ outputs = {
612
+ 'base_condition': base_condition,
613
+ 'pred_peak_logits': pred_peak_logits, # Auxiliary output
614
+ **stage_conditions,
615
+ }
616
+
617
+ return outputs
618
+
619
+
620
+ if __name__ == "__main__":
621
+ # Test the encoder
622
+ B = 4
623
+ spatial_dim = 192
624
+ poi_dim = 20
625
+
626
+ encoder = MultiModalSpatialEncoderV4(spatial_dim, poi_dim)
627
+
628
+ # Create dummy batch
629
+ batch = {
630
+ 'satellite_img': torch.randn(B, 3, 64, 64),
631
+ 'poi_dist': torch.randn(B, poi_dim),
632
+ 'coords': torch.randn(B, 2),
633
+ }
634
+
635
+ # Forward pass
636
+ outputs = encoder(batch)
637
+
638
+ print("Multi-Modal Spatial Encoder V4 Test:")
639
+ print(f" Base condition shape: {outputs['base_condition'].shape}")
640
+ print(f" Peak Logits shape: {outputs['pred_peak_logits'].shape}")
641
+ print(f" Level 1 condition shape: {outputs['level1_cond'].shape}")
642
+ print(f" Level 2 condition shape: {outputs['level2_cond'].shape}")
643
+ print(f" Level 3 condition shape: {outputs['level3_cond'].shape}")
644
+
645
+ print("\nEncoder test passed!")
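The adaptive gating in `CrossAttentionFusion`'s vector mode reduces to a learned softmax over the three modality embeddings. A minimal standalone sketch of just that gating step (the `GatedFusionSketch` class and its dimensions are illustrative, not the repo's actual module):

```python
import torch
import torch.nn as nn

class GatedFusionSketch(nn.Module):
    """Softmax-gated fusion of three modality vectors, as in vector mode (A)."""
    def __init__(self, dim: int = 192):
        super().__init__()
        # The gate maps the concatenated modalities to one weight per modality.
        self.gate = nn.Sequential(nn.Linear(3 * dim, 3), nn.Softmax(dim=-1))

    def forward(self, sat, poi, coord):
        w = self.gate(torch.cat([sat, poi, coord], dim=-1))  # [B, 3], rows sum to 1
        return w[:, 0:1] * sat + w[:, 1:2] * poi + w[:, 2:3] * coord

fusion = GatedFusionSketch(dim=192)
fused = fusion(torch.randn(4, 192), torch.randn(4, 192), torch.randn(4, 192))
print(fused.shape)  # torch.Size([4, 192])
```

Because the weights are a softmax, the fused vector stays a convex combination of the three modality embeddings, which keeps its scale comparable to the inputs.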
prediction_backend.py ADDED
@@ -0,0 +1,359 @@
+ import os
+ import torch
+ import math
+ import numpy as np
+ import requests
+ import random
+ import base64
+ import matplotlib
+ matplotlib.use('Agg')
+ import matplotlib.pyplot as plt
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
+ from io import BytesIO
+ from PIL import Image
+
+ # Import V4 Model System
+ from hierarchical_flow_matching_training_v4 import HierarchicalFlowMatchingSystemV4
+
+ # =============================================================================
+ # Config
+ # =============================================================================
+ MAPBOX_ACCESS_TOKEN = "pk.eyJ1IjoieXlhaXl5IiwiYSI6ImNtaTVpMTVlaTJmdzMybW9zcmFieGxpdHUifQ.181d6E5fzLw1CEZMEPU53Q"
+ MAPBOX_ZOOM = 15
+ FETCH_SIZE = 256
+ IMAGE_SIZE = 64
+ SEED = 42
+
+ SPATIAL_DIM = 192
+ HIDDEN_DIM = 256
+ POI_DIM = 20
+ N_LAYERS_LEVEL3 = 6
+ N_STEPS = 50
+
+
+ class MapboxSatelliteFetcher:
+     """
+     Dynamically fetches satellite imagery, strictly aligned with the
+     image preprocessing used during training.
+     """
+     def __init__(self, access_token=MAPBOX_ACCESS_TOKEN, zoom=MAPBOX_ZOOM, fetch_size=FETCH_SIZE, target_size=IMAGE_SIZE):
+         self.access_token = access_token
+         self.zoom = zoom
+         self.fetch_size = fetch_size
+         self.target_size = target_size
+
+     def fetch(self, lon, lat, station_id=None, return_pil=False):
+         """Fetches a static satellite map centered at [lon, lat]."""
+         url = f"https://api.mapbox.com/styles/v1/mapbox/satellite-v9/static/{lon},{lat},{self.zoom},0,0/{self.fetch_size}x{self.fetch_size}?access_token={self.access_token}"
+         try:
+             response = requests.get(url, timeout=10)
+             response.raise_for_status()
+
+             img = Image.open(BytesIO(response.content)).convert("RGB")
+             original_pil = img.copy()  # Store high-res original for micro-grid slicing
+
+             # Resize to model input size (64x64)
+             img_resized = img.resize((self.target_size, self.target_size), Image.BILINEAR)
+             arr = np.asarray(img_resized, dtype=np.float32) / 255.0
+
+             # Convert HWC to CHW format
+             chw = arr.transpose(2, 0, 1)
+             tensor_np = np.clip(chw, 0.0, 1.0).astype(np.float32, copy=False)
+
+             if return_pil:
+                 return tensor_np, original_pil
+             return tensor_np
+
+         except Exception as e:
+             print(f"[Mapbox Fetcher Error] Station {station_id}: {e}")
+             fallback_np = np.zeros((3, self.target_size, self.target_size), dtype=np.float32)
+             if return_pil:
+                 return fallback_np, Image.new('RGB', (self.fetch_size, self.fetch_size), color='black')
+             return fallback_np
+
+
+ class TrafficPredictor:
+     """Traffic generation model predictor and site selection analyzer."""
+     def __init__(self, model_path, spatial_path, traffic_path, local_sat_dir="real_spatial_data/satellite_png", device=None):
+         self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.local_sat_dir = local_sat_dir
+
+         # Load spatial features cache
+         if os.path.exists(spatial_path):
+             self.data_cache = np.load(spatial_path, allow_pickle=True)
+         else:
+             raise FileNotFoundError(f"Spatial path {spatial_path} not found!")
+
+         # Load traffic records for validation comparison
+         if os.path.exists(traffic_path):
+             self.traffic_data = np.load(traffic_path, allow_pickle=True)['bs_record']
+         else:
+             self.traffic_data = None
+
+         # ==========================================
+         # Strictly mirror the Dataset length truncation so that the
+         # normalization extrema are 100% aligned with training.
+         # ==========================================
+         n_traffic = len(self.traffic_data) if self.traffic_data is not None else float('inf')
+         n_poi = len(self.data_cache['poi_distributions'])
+         n_coords = len(self.data_cache['coordinates'])
+         self.n_valid = min(n_traffic, n_poi, n_coords)
+
+         # Calculate coordinate bounds for normalization
+         raw_coords_valid = self.data_cache['coordinates'][:self.n_valid].astype(np.float32)
+         self.coord_min = raw_coords_valid.min(axis=0)
+         self.coord_max = raw_coords_valid.max(axis=0)
+
+         self.satellite_fetcher = MapboxSatelliteFetcher()
+         self.model = self._load_model(model_path)
+
+     def _load_model(self, model_path):
+         print(f"Loading V4 Model on {self.device}...")
+         model = HierarchicalFlowMatchingSystemV4(
+             spatial_dim=SPATIAL_DIM,
+             hidden_dim=HIDDEN_DIM,
+             poi_dim=POI_DIM,
+             n_layers_level3=N_LAYERS_LEVEL3
+         ).to(self.device)
+
+         if os.path.exists(model_path):
+             ckpt = torch.load(model_path, map_location=self.device, weights_only=False)
+             state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
+             model.load_state_dict(state_dict)
+             print(f"Checkpoint loaded successfully from {model_path}.")
+         else:
+             raise FileNotFoundError(f"Model checkpoint not found at {model_path}")
+
+         model.eval()
+         return model
+
+     def predict(self, idx, use_local_img_for_debug=False):
+         """Generates traffic predictions and performs LSI spatial analysis."""
+         try:
+             torch.manual_seed(SEED)
+             idx = int(idx)
+
+             # 1. Process POI distributions
+             raw_poi = self.data_cache['poi_distributions'][idx].astype(np.float32).copy()
+             raw_poi = np.clip(raw_poi, 0.0, None)
+             poi_sum = float(raw_poi.sum())
+             if poi_sum > 1e-8:
+                 raw_poi = raw_poi / poi_sum
+             else:
+                 raw_poi = np.zeros_like(raw_poi)
+             poi_tensor = torch.from_numpy(raw_poi).unsqueeze(0).to(self.device)
+
+             # 2. Process geographical coordinates
+             raw_loc = self.data_cache['coordinates'][idx].astype(np.float32)
+             lon, lat = raw_loc[0], raw_loc[1]
+
+             # Fetch satellite image (local debug vs. remote API)
+             if use_local_img_for_debug and os.path.exists(f"{self.local_sat_dir}/{idx}.png"):
+                 img = Image.open(f"{self.local_sat_dir}/{idx}.png").convert("RGB")
+                 original_pil = img.copy()
+             else:
+                 _, original_pil = self.satellite_fetcher.fetch(lon, lat, station_id=str(idx), return_pil=True)
+
+             # 3. LSI heatmap generation (Location Stability Index)
+             lsi_grid, best_idx, best_traffic = self.generate_lsi_heatmap(original_pil, lat, lon, poi_tensor, grid_size=3)
+             site_map_b64 = self.create_site_map_base64(original_pil, lsi_grid, best_idx)
+
+             # 4. Map projection: convert grid indices back to physical lat/lon
+             grid_size = 3
+             best_row, best_col = best_idx
+
+             # Calculate spans based on the Web Mercator projection at this zoom
+             lon_span = 360.0 / (2 ** MAPBOX_ZOOM)
+             lat_span = lon_span * math.cos(math.radians(lat))
+             step_lon = lon_span / grid_size
+             step_lat = lat_span / grid_size
+
+             # Offset from base center
+             best_lat = lat - (best_row - grid_size // 2) * step_lat
+             best_lon = lon + (best_col - grid_size // 2) * step_lon
+             best_loc = [float(best_lon), float(best_lat)]
+
+             # 5. Multidimensional NLG (Natural Language Generation) engine
+             best_lsi_value = float(lsi_grid[best_idx])
+             avg_lsi_value = float(np.mean(lsi_grid))
+             min_lsi_value = float(np.min(lsi_grid))
+
+             # Core performance metrics
+             improvement_avg = ((best_lsi_value - avg_lsi_value) / avg_lsi_value) * 100 if avg_lsi_value > 0 else 0
+             spatial_contrast = ((best_lsi_value - min_lsi_value) / min_lsi_value) * 100 if min_lsi_value > 0 else 0
+
+             # Feature A: POI semantic mapping
+             poi_idx = int(torch.argmax(poi_tensor[0]))
+             poi_categories = [
+                 "Commercial/Retail", "Residential Complex", "Transit Hub",
+                 "Corporate/Office", "Public/Recreational", "Industrial Zone",
+                 "Mixed-Use Urban", "Educational/Campus"
+             ]
+             dominant_poi = poi_categories[poi_idx % len(poi_categories)]
+
+             # Feature B: Temporal tide analysis (extract the daily peak from the 672-hour sequence)
+             daily_pattern = best_traffic.reshape(-1, 24).mean(axis=0)
+             peak_hour = int(np.argmax(daily_pattern))
+
+             if 7 <= peak_hour <= 10:
+                 peak_type = "Morning Rush (07:00-10:00)"
+             elif 16 <= peak_hour <= 19:
+                 peak_type = "Evening Rush (16:00-19:00)"
+             elif 11 <= peak_hour <= 15:
+                 peak_type = "Midday Active (11:00-15:00)"
+             else:
+                 peak_type = "Night/Off-peak Active"
+
+             # Load description based on average volume
+             avg_load = float(np.mean(best_traffic))
+             if avg_load > 6.0: load_desc = "High-Capacity"
+             elif avg_load > 3.0: load_desc = "Moderate-Load"
+             else: load_desc = "Baseline/Sparse"
+
+             # Dynamic text assembly (4-stage structure)
+             # Stage 1: Spatial environment diagnosis
+             if spatial_contrast > 40:
+                 p1 = f"Spatial scan detects a highly heterogeneous {dominant_poi} sector with steep traffic gradients. "
+             else:
+                 p1 = f"Spatial scan indicates a relatively uniform {dominant_poi} matrix. "
+
+             # Stage 2: Temporal characteristics
+             p2 = f"Flow Matching model projects a {load_desc} demand curve, heavily anchored by a {peak_type} signature. "
+
+             # Stage 3: Decision output
+             p3 = f"Micro-grid ({best_row}, {best_col}) is isolated as the topological optimum, yielding a peak Location Stability Index (LSI) of {best_lsi_value:.2f}. "
+
+             # Stage 4: Business value assessment
+             if improvement_avg > 15:
+                 p4 = f"Deploying infrastructure here intercepts peak volatility, providing a {improvement_avg:.1f}% structural stability gain over the regional average."
+             else:
+                 p4 = f"This precise coordinate offers a marginal yet critical {improvement_avg:.1f}% variance reduction, ensuring optimal load-balancing."
+
+             explanation_text = p1 + p2 + p3 + p4
+             # ===============================================
+
+             # Finalize output sequence
+             gen_seq_real = np.clip(best_traffic, 0.0, 10.0)
+             real_seq = self.traffic_data[idx].tolist() if self.traffic_data is not None else []
+
+             return {
+                 "station_id": idx,
+                 "prediction": gen_seq_real.tolist(),
+                 "real": real_seq,
+                 "site_map_b64": site_map_b64,
+                 "best_loc": best_loc,
+                 "explanation": explanation_text,
+                 "status": "success"
+             }
+
+         except Exception as e:
+             import traceback
+             traceback.print_exc()
+             return {"error": str(e), "status": "failed"}
+
+     @torch.no_grad()
+     def generate_lsi_heatmap(self, img_pil, base_lat, base_lon, poi_tensor, grid_size=3):
+         w, h = img_pil.size
+         patch_w, patch_h = w // grid_size, h // grid_size
+
+         patches = []
+         coords = []
+
+         # Calculate precise lat/lon spans for the 256x256 image area
+         lon_span = 360.0 / (2 ** MAPBOX_ZOOM)
+         lat_span = lon_span * math.cos(math.radians(base_lat))
+         step_lon = lon_span / grid_size
+         step_lat = lat_span / grid_size
+
+         for i in range(grid_size):
+             for j in range(grid_size):
+                 # Slice and resize patches for the model
+                 box = (j * patch_w, i * patch_h, (j+1) * patch_w, (i+1) * patch_h)
+                 patch = img_pil.crop(box).resize((IMAGE_SIZE, IMAGE_SIZE), Image.BILINEAR).convert("RGB")
+                 arr = np.array(patch).transpose(2, 0, 1) / 255.0
+                 patches.append(torch.tensor(arr, dtype=torch.float32))
+
+                 # Normalize coordinates for the model
+                 offset_lat = base_lat - (i - grid_size//2) * step_lat
+                 offset_lon = base_lon + (j - grid_size//2) * step_lon
+                 raw_coord = np.array([offset_lon, offset_lat], dtype=np.float32)
+                 norm_coord = (raw_coord - self.coord_min) / (self.coord_max - self.coord_min + 1e-8)
+                 coords.append(torch.tensor(norm_coord, dtype=torch.float32))
+
+         batch_size = grid_size ** 2
+         # Assemble GPU batch
+         batch_gpu = {
+             'satellite_img': torch.stack(patches).to(self.device),
+             'poi_dist': poi_tensor.repeat(batch_size, 1),
+             'coords': torch.stack(coords).to(self.device),
+             'traffic_seq': torch.zeros(batch_size, 672, dtype=torch.float32).to(self.device)
+         }
+
+         # Batch inference (computes all 9 regions simultaneously)
+         output = self.model(batch_gpu, mode='generate', loss_cfg={'n_steps_generate': N_STEPS})
+         outputs = output['generated'].cpu().numpy()  # [9, 672]
+
+         # LSI calculation: 1 / (std + epsilon).
+         # Higher LSI means lower variance (more stable traffic).
+         stds = outputs.std(axis=1)
+         lsis = 1.0 / (stds + 1e-6)
+         lsi_grid = lsis.reshape(grid_size, grid_size)
+
+         # Find the most stable coordinate
+         best_idx = np.unravel_index(np.argmax(lsi_grid), lsi_grid.shape)
+         best_traffic = outputs[best_idx[0] * grid_size + best_idx[1]]
+
+         return lsi_grid, best_idx, best_traffic
+
+     def create_site_map_base64(self, img_pil, lsi_grid, best_idx):
+         """Generates a heatmap overlay visualization and encodes it to Base64."""
+         img_arr = np.array(img_pil)
+         h, w, _ = img_arr.shape
+         grid_h, grid_w = lsi_grid.shape
+
+         fig, ax = plt.subplots(figsize=(4, 4), dpi=120)
+
+         # Overlay heatmap on satellite image
+         ax.imshow(img_arr)
+         im = ax.imshow(lsi_grid, cmap='RdYlGn', alpha=0.45, extent=[0, w, h, 0], interpolation='bicubic')
+
+         cell_w, cell_h = w / grid_w, h / grid_h
+         best_row, best_col = best_idx
+         center_x = best_col * cell_w + cell_w / 2
+         center_y = best_row * cell_h + cell_h / 2
+
+         # Draw target star and LSI indicator
+         ax.plot(center_x, center_y, marker='*', color='red', markersize=20, markeredgecolor='white', markeredgewidth=1.5)
+         best_lsi = lsi_grid[best_idx]
+         ax.annotate(f"LSI: {best_lsi:.2f}", xy=(center_x, center_y), xytext=(10, 10),
+                     textcoords='offset points', color='white', fontweight='bold',
+                     bbox=dict(boxstyle="round,pad=0.3", facecolor="red", alpha=0.8, edgecolor="white"))
+
+         ax.axis('off')
+         plt.tight_layout()
+
+         # Convert to Base64 for API transmission
+         buf = BytesIO()
+         fig.savefig(buf, format="png", bbox_inches='tight', transparent=True)
+         plt.close(fig)
+         return base64.b64encode(buf.getvalue()).decode("utf-8")
+
+
+ if __name__ == "__main__":
+     MODEL_PATH = "best_corr_model.pt"
+     SPATIAL_PATH = "data/spatial_features.npz"
+     TRAFFIC_PATH = "data/bs_record_energy_normalized_sampled.npz"
+
+     predictor = TrafficPredictor(
+         model_path=MODEL_PATH,
+         spatial_path=SPATIAL_PATH,
+         traffic_path=TRAFFIC_PATH
+     )
+
+     test_id = 277
+     result = predictor.predict(test_id, use_local_img_for_debug=False)
+
+     if result.get("status") == "success":
+         print(f"Prediction successful for Station {test_id}!")
+     else:
+         print(f"Prediction failed: {result.get('error')}")
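The Location Stability Index computed in `generate_lsi_heatmap` is simply the reciprocal of each candidate cell's temporal standard deviation. A minimal NumPy sketch of that scoring step, using a synthetic `[9, 672]` array in place of the model's generated sequences (the `lsi_scores` helper is illustrative, not part of the repo):

```python
import numpy as np

def lsi_scores(generated: np.ndarray, grid_size: int = 3, eps: float = 1e-6):
    """Score each micro-grid cell: LSI = 1 / (std + eps); higher = more stable."""
    stds = generated.std(axis=1)                          # per-cell temporal std
    lsi_grid = (1.0 / (stds + eps)).reshape(grid_size, grid_size)
    best_idx = np.unravel_index(np.argmax(lsi_grid), lsi_grid.shape)
    return lsi_grid, best_idx

rng = np.random.default_rng(0)
demo = rng.random((9, 672))   # 9 candidate cells, 672 hourly values each
demo[4] *= 0.1                # shrink the center cell's variance 10x
grid, best = lsi_scores(demo)
print(best)  # (1, 1): the low-variance center cell wins
```

Because the score is monotone decreasing in the standard deviation, picking the argmax of the LSI grid is equivalent to picking the cell with the flattest generated traffic curve.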
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ Flask==3.0.0
+ flask-cors==4.0.0
+ numpy==1.26.4
+ Pillow==10.2.0
+ requests==2.31.0
+ matplotlib==3.8.2
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch==2.1.2+cpu
script.js ADDED
@@ -0,0 +1,954 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // ==========================================
2
+ // 1. Config (Focused on Shanghai)
3
+ // ==========================================
4
+ const CONFIG = {
5
+ MAPBOX_TOKEN: 'pk.eyJ1IjoieXlhaXl5IiwiYSI6ImNtaTVpMTVlaTJmdzMybW9zcmFieGxpdHUifQ.181d6E5fzLw1CEZMEPU53Q',
6
+ // API_BASE: 'http://127.0.0.1:5000/api', // Local
7
+ API_BASE: '/api', // Online
8
+
9
+ // Shanghai City Center
10
+ DEFAULT_CENTER: [121.4737, 31.2304],
11
+ DEFAULT_ZOOM: 10.5,
12
+
13
+ // Shanghai Coordinate Bounds [Southwest, Northeast]
14
+ SHANGHAI_BOUNDS: [
15
+ [120.80, 30.60], // Southwest
16
+ [122.50, 31.90] // Northeast
17
+ ]
18
+ };
19
+
20
+ // ==========================================
21
+ // 2. Globals
22
+ // ==========================================
23
+ let chartInstance = null;
24
+ let predictionChartInstance = null;
25
+ let currentMarker = null;
26
+ let mapInstance = null;
27
+ let globalStationData = [];
28
+ let animationFrameId = null;
29
+ let isPredictionMode = false;
30
+ let predictionMarker = null;
31
+ let optimalMarker = null;
32
+
33
+ // ==========================================
34
+ // 3. API Logic
35
+ // ==========================================
36
+ async function fetchLocations() {
37
+ console.log("Requesting backend data...");
38
+ const res = await fetch(`${CONFIG.API_BASE}/stations/locations`);
39
+ if (!res.ok) throw new Error(`API Error: ${res.status}`);
40
+ return await res.json();
41
+ }
42
+
43
+ async function fetchStationDetail(id) {
44
+ try {
45
+ const res = await fetch(`${CONFIG.API_BASE}/stations/detail/${id}`);
46
+ return await res.json();
47
+ } catch (e) {
48
+ console.error("Fetch Detail Error:", e);
49
+ return null;
50
+ }
51
+ }
52
+
53
+ // Fetch AI Prediction Data
54
+ async function fetchPrediction(id) {
55
+ try {
56
+ const res = await fetch(`${CONFIG.API_BASE}/predict/${id}?t=${Date.now()}`);
57
+ const data = await res.json();
58
+ if (data.error) throw new Error(data.error);
59
+ return data;
60
+ } catch (e) {
61
+ console.error("Prediction API Error:", e);
62
+ alert("Prediction failed: " + e.message);
63
+ return null;
64
+ }
65
+ }
66
+
67
+ function loadSatellitePatch(lng, lat) {
68
+ // Logic for loading static satellite imagery patch
69
+ const img = document.getElementById('satellite-patch');
70
+ const placeholder = document.getElementById('sat-placeholder');
71
+ if(!img) return;
72
+
73
+ img.style.display = 'none';
74
+ placeholder.style.display = 'flex';
75
+ placeholder.innerHTML = '<p>Loading...</p>';
76
+
77
+ img.src = `https://api.mapbox.com/styles/v1/mapbox/satellite-v9/static/${lng},${lat},16,0,0/320x200?access_token=${CONFIG.MAPBOX_TOKEN}`;
78
+ img.onload = () => { img.style.display = 'block'; placeholder.style.display = 'none'; };
79
+ }
80
+
81
+ // ==========================================
82
+ // 4. Chart Logic (Normal & Prediction)
83
+ // ==========================================
84
+ function renderChart(recordData) {
85
+ const ctx = document.getElementById('energyChart').getContext('2d');
86
+ if (chartInstance) chartInstance.destroy();
87
+
88
+ chartInstance = new Chart(ctx, {
89
+ type: 'line',
90
+ data: {
91
+ labels: recordData.map((_, i) => i),
92
+ datasets: [
93
+ {
94
+ label: 'Traffic', data: recordData,
95
+ borderColor: '#00cec9', backgroundColor: 'rgba(0, 206, 201, 0.1)',
96
+ borderWidth: 1.5, fill: true, pointRadius: 0, tension: 0.3
97
+ },
98
+ {
99
+ label: 'Current', data: [], type: 'scatter',
100
+ pointRadius: 6, pointBackgroundColor: '#ffffff',
101
+ pointBorderColor: '#e84393', pointBorderWidth: 3
102
+ }
103
+ ]
104
+ },
105
+ options: {
106
+ responsive: true, maintainAspectRatio: false, animation: false,
107
+ plugins: { legend: { display: false } },
108
+ scales: { x: { display: false }, y: { grid: { color: 'rgba(255,255,255,0.05)' }, ticks: { color: '#64748b', font: {size: 10} } } }
109
+ }
110
+ });
111
+ }
112
+
113
+ function updateChartCursor(timeIndex) {
114
+ if (chartInstance && chartInstance.data.datasets[0].data.length > timeIndex) {
115
+ const yValue = chartInstance.data.datasets[0].data[timeIndex];
116
+ chartInstance.data.datasets[1].data = [{x: timeIndex, y: yValue}];
117
+ chartInstance.update('none');
118
+ }
119
+ }
120
+
121
+ // Render AI Prediction Comparison Chart
122
+ function renderPredictionChart(realData, predData) {
123
+ const canvas = document.getElementById('predictionChart');
124
+ if (!canvas) return;
125
+ const ctx = canvas.getContext('2d');
126
+
127
+ if (predictionChartInstance) {
128
+ predictionChartInstance.destroy();
129
+ }
130
+
131
+ // Generate X-axis labels (e.g., H0, H1...)
132
+ const labels = realData.map((_, i) => `H${i}`);
133
+
134
+ predictionChartInstance = new Chart(ctx, {
135
+ type: 'line',
136
+ data: {
137
+ labels: labels,
138
+ datasets: [
139
+ {
140
+ label: 'Real Traffic',
141
+ data: realData,
142
+ borderColor: 'rgba(0, 206, 201, 0.8)', // Cyan
143
+ backgroundColor: 'rgba(0, 206, 201, 0.1)',
144
+ borderWidth: 1.5,
145
+ pointRadius: 0,
146
+ fill: true,
147
+ tension: 0.3
148
+ },
149
+ {
150
+ label: 'AI Prediction',
151
+ data: predData,
152
+ borderColor: '#f39c12', // Orange
153
+ backgroundColor: 'transparent',
154
+ borderWidth: 2,
155
+ borderDash: [5, 5], // Dashed line effect
156
+ pointRadius: 0,
157
+ fill: false,
158
+ tension: 0.3
159
+ }
160
+ ]
161
+ },
162
+ options: {
163
+ responsive: true,
164
+ maintainAspectRatio: false,
165
+ interaction: {
166
+ mode: 'index',
167
+ intersect: false, // Tooltip shows both values simultaneously
168
+ },
169
+ plugins: {
170
+ legend: {
171
+ display: true,
172
+ labels: { color: '#e0e0e0', font: { size: 10 } }
173
+ }
174
+ },
175
+ scales: {
176
+ x: {
177
+ display: true,
178
+ grid: { color: 'rgba(255,255,255,0.05)' },
+ ticks: { color: '#64748b', font: {size: 9}, maxTicksLimit: 14 }
+ },
+ y: {
+ grid: { color: 'rgba(255,255,255,0.1)' },
+ ticks: { color: '#888', font: {size: 10} },
+ beginAtZero: true
+ }
+ }
+ }
+ });
+ }
+
+ // ==========================================
+ // 5. Map Manager
+ // ==========================================
+ function initMap() {
+ mapboxgl.accessToken = CONFIG.MAPBOX_TOKEN;
+ mapInstance = new mapboxgl.Map({
+ container: 'map',
+ style: 'mapbox://styles/mapbox/satellite-streets-v12',
+ center: CONFIG.DEFAULT_CENTER,
+ zoom: CONFIG.DEFAULT_ZOOM,
+ pitch: 60,
+ bearing: -15,
+ antialias: true,
+ maxBounds: CONFIG.SHANGHAI_BOUNDS,
+ minZoom: 9
+ });
+ mapInstance.addControl(new mapboxgl.NavigationControl(), 'top-right');
+ return mapInstance;
+ }
+
+ function setupMapEnvironment(map) {
+ map.addSource('mapbox-dem', {
+ 'type': 'raster-dem',
+ 'url': 'mapbox://mapbox.mapbox-terrain-dem-v1',
+ 'tileSize': 512,
+ 'maxzoom': 14 });
+
+ map.setTerrain({ 'source': 'mapbox-dem',
+ 'exaggeration': 1.5 });
+
+ map.addLayer({
+ 'id': 'sky',
+ 'type': 'sky',
+ 'paint': { 'sky-type': 'atmosphere', 'sky-atmosphere-sun': [0.0, 0.0], 'sky-atmosphere-sun-intensity': 15 }
+ });
+
+ if (map.setFog) {
+ map.setFog({ 'range': [0.5, 10],
+ 'color': '#240b36',
+ 'horizon-blend': 0.1,
+ 'high-color': '#0f172a',
+ 'space-color': '#000000',
+ 'star-intensity': 0.6 });
+ }
+
+ const labelLayerId = map.getStyle().layers.find(l => l.type === 'symbol' && l.layout['text-field']).id;
+ if (!map.getLayer('3d-buildings')) {
+ map.addLayer({
+ 'id': '3d-buildings', 'source': 'composite',
+ 'source-layer': 'building', 'filter': ['==', 'extrude', 'true'],
+ 'type': 'fill-extrusion', 'minzoom': 11,
+ 'paint': {
+ 'fill-extrusion-color': ['interpolate', ['linear'], ['get', 'height'], 0, '#0f0c29', 30, '#1e2a4a', 200, '#4b6cb7'],
+ 'fill-extrusion-height': ['get', 'height'], 'fill-extrusion-base': ['get', 'min_height'], 'fill-extrusion-opacity': 0.6
+ }
+ }, labelLayerId);
+ }
+ }
+
+ function updateGeoJSONData(map, stations, mode = 'avg', timeIndex = 0) {
+ const pointFeatures = [];
+ const polygonFeatures = [];
+ const r = 0.00025; // Marker radius
+
+ stations.forEach(s => {
+ const lng = s.loc[0], lat = s.loc[1];
+ let valH = (mode === 'avg') ? (s.val_h || 0) : ((s.vals && s.vals[timeIndex]) !== undefined ? s.vals[timeIndex] : 0);
+ let valC = (s.val_c !== undefined) ? s.val_c : 0;
+
+ const props = { id: s.id, load_avg: valH, load_std: valC };
+
+ pointFeatures.push({ type: 'Feature', geometry: {
+ type: 'Point', coordinates: [lng, lat] }, properties: props });
+ polygonFeatures.push({ type: 'Feature', geometry: {
+ type: 'Polygon', coordinates: [[ [lng-r, lat-r], [lng+r, lat-r], [lng+r, lat+r], [lng-r, lat+r], [lng-r, lat-r] ]] }, properties: props });
+ });
+
+ if (map.getSource('stations-points')) {
+ map.getSource('stations-points').setData({
+ type: 'FeatureCollection',
+ features: pointFeatures });
+
+ map.getSource('stations-polygons').setData({
+ type: 'FeatureCollection',
+ features: polygonFeatures });
+ }
+ return { points: { type: 'FeatureCollection', features: pointFeatures }, polys: { type: 'FeatureCollection', features: polygonFeatures } };
+ }
+
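For reference, the square footprint that `updateGeoJSONData` builds around each station is a closed five-point ring of half-width `r`. A minimal Python sketch of the same ring construction (the station coordinate below is an arbitrary illustrative value):

```python
def square_ring(lng, lat, r=0.00025):
    """Closed GeoJSON-style ring (first point repeated last) around (lng, lat)."""
    return [[lng - r, lat - r], [lng + r, lat - r],
            [lng + r, lat + r], [lng - r, lat + r],
            [lng - r, lat - r]]

ring = square_ring(121.47, 31.23)
print(len(ring))            # 5 points: 4 corners plus the closing duplicate
print(ring[0] == ring[-1])  # the ring is closed
```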
+
280
+ function addStationLayers(map, geoData, statsLoad, statsColor) {
281
+ map.addSource('stations-points', { type: 'geojson', data: geoData.points });
282
+ map.addSource('stations-polygons', { type: 'geojson', data: geoData.polys });
283
+
284
+ map.addLayer({
285
+ id: 'stations-heatmap', type: 'heatmap', source: 'stations-points', maxzoom: 14,
286
+ paint: {
287
+ 'heatmap-weight': ['interpolate', ['linear'], ['get', 'load_avg'], statsLoad.min, 0, statsLoad.max, 1],
288
+ 'heatmap-intensity': ['interpolate', ['linear'], ['zoom'], 0, 1, 13, 3],
289
+ 'heatmap-color': ['interpolate', ['linear'], ['heatmap-density'], 0, 'rgba(0,0,0,0)', 0.2, '#0984e3', 0.4, '#00cec9', 0.6, '#a29bfe', 0.8, '#fd79a8', 1, '#ffffff'],
290
+ 'heatmap-radius': ['interpolate', ['linear'], ['zoom'], 0, 2, 13, 25],
291
+ 'heatmap-opacity': ['interpolate', ['linear'], ['zoom'], 12, 1, 14, 0]
292
+ }
293
+ });
294
+
295
+ map.addLayer({
296
+ id: 'stations-2d-dots', type: 'circle', source: 'stations-points', minzoom: 12,
297
+ paint: {
298
+ 'circle-radius': 3,
299
+ 'circle-color': ['step', ['get', 'load_std'], '#1e1e2e', statsColor.t1, '#0984e3', statsColor.t2, '#00cec9', statsColor.t3, '#fd79a8', statsColor.t4, '#e84393'],
300
+ 'circle-stroke-width': 1, 'circle-stroke-color': '#fff', 'circle-opacity': 0.8
301
+ }
302
+ });
303
+
304
+ map.addLayer({
305
+ id: 'stations-3d-pillars', type: 'fill-extrusion', source: 'stations-polygons', minzoom: 12,
306
+ paint: {
307
+ 'fill-extrusion-color': ['step', ['get', 'load_std'], '#1e1e2e', statsColor.t1, '#0984e3', statsColor.t2, '#00cec9', statsColor.t3, '#fd79a8', statsColor.t4, '#e84393'],
308
+ 'fill-extrusion-height': ['interpolate', ['linear'], ['get', 'load_avg'], 0, 0, statsLoad.min, 5, statsLoad.max, 300],
309
+ 'fill-extrusion-opacity': 0.7
310
+ }
311
+ });
312
+
313
+ map.addLayer({ id: 'stations-hitbox', type: 'circle', source: 'stations-points',
314
+ paint: { 'circle-radius': 10, 'circle-color': 'transparent', 'circle-opacity': 0 } });
315
+ }
316
+
317
+ // ==========================================
318
+ // 6. Map Interactions
319
+ // ==========================================
320
+ function setupInteraction(map) {
321
+ const popup = new mapboxgl.Popup({ closeButton: false, closeOnClick: false, className: 'cyber-popup' });
322
+
323
+ map.on('mouseenter', 'stations-hitbox', (e) => {
324
+ map.getCanvas().style.cursor = 'pointer';
325
+ if (isPredictionMode) return;
326
+
327
+ const props = e.features[0].properties;
328
+ const coordinates = e.features[0].geometry.coordinates.slice();
329
+
330
+ while (Math.abs(e.lngLat.lng - coordinates[0]) > 180) { coordinates[0] += e.lngLat.lng > coordinates[0] ? 360 : -360; }
331
+
332
+ popup.setLngLat(coordinates)
333
+ .setHTML(`
334
+ <div style="font-weight:bold; color:#fff; border-bottom:1px solid #444; padding-bottom:2px; margin-bottom:2px;">Station ${props.id}</div>
335
+ <div style="color:#00cec9;">Load: <span style="color:#fff;">${props.load_avg.toFixed(2)}</span></div>
336
+ <div style="color:#fd79a8;">Stability: <span style="color:#fff;">${props.load_std.toFixed(4)}</span></div>
337
+ `).addTo(map);
338
+ });
339
+
340
+ map.on('mouseleave', 'stations-hitbox', () => {
341
+ if (!isPredictionMode) map.getCanvas().style.cursor = '';
342
+ popup.remove();
343
+ });
344
+
345
+ // Core Interaction Logic
346
+ map.on('click', 'stations-hitbox', async (e) => {
347
+ const coordinates = e.features[0].geometry.coordinates.slice();
348
+ const id = e.features[0].properties.id;
349
+
350
+ // 1. Prediction Mode Logic
351
+ if (isPredictionMode) {
352
+ const predPanel = document.getElementById('prediction-panel');
353
+ const predIdDisplay = document.getElementById('pred-station-id');
354
+ const siteMapContainer = document.getElementById('site-map-container');
355
+ const siteMapImg = document.getElementById('site-map-img');
356
+
357
+ predPanel.classList.add('active');
358
+
359
+ const rightBtn = document.getElementById('toggle-right-btn');
360
+ if (rightBtn) rightBtn.classList.add('active');
361
+
362
+ predIdDisplay.innerText = `${id} (Calculating...)`;
363
+
364
+ // Clear previous optimal site marker when a new station is clicked
365
+ if (optimalMarker) {
366
+ optimalMarker.remove();
367
+ optimalMarker = null;
368
+ }
369
+
370
+ // Drop orange selection pin and draw 3x3 grid
371
+ if (!predictionMarker) {
372
+ predictionMarker = new mapboxgl.Marker({ color: '#f39c12' })
373
+ .setLngLat(coordinates).addTo(map);
374
+ } else {
375
+ predictionMarker.setLngLat(coordinates);
376
+ }
377
+ updatePredictionGrid(map, coordinates[0], coordinates[1]);
378
+
379
+ if (siteMapContainer) siteMapContainer.style.display = 'none';
380
+ if (siteMapImg) siteMapImg.src = '';
381
+
382
+ if(predictionChartInstance) {
383
+ predictionChartInstance.destroy();
384
+ predictionChartInstance = null;
385
+ }
386
+
387
+ // Call Prediction API
388
+ const result = await fetchPrediction(id);
389
+ if(result && result.status === "success") {
390
+ predIdDisplay.innerText = id;
391
+ renderPredictionChart(result.real, result.prediction);
392
+
393
+ // Render returned Base64 site heatmap and mark optimal location
394
+ if (result.site_map_b64 && siteMapContainer && siteMapImg) {
395
+ siteMapImg.src = `data:image/png;base64,${result.site_map_b64}`;
396
+ siteMapContainer.style.display = 'block';
397
+
398
+ // Typewriter Effect for AI Explanation
399
+ const explanationBox = document.getElementById('site-explanation');
400
+ if (explanationBox && result.explanation) {
401
+ explanationBox.style.display = 'block';
402
+
403
+ // Reset content and add blinking cursor
404
+ explanationBox.innerHTML = `<strong>> SYSTEM LOG: AI DECISION</strong><br><span id="typewriter-text"></span><span class="cursor" style="animation: blink 1s step-end infinite;">_</span>`;
405
+
406
+ const textTarget = document.getElementById('typewriter-text');
407
+ const fullText = result.explanation;
408
+ let charIndex = 0;
409
+
410
+ function typeWriter() {
411
+ if (charIndex < fullText.length) {
412
+ textTarget.innerHTML += fullText.charAt(charIndex);
413
+ charIndex++;
414
+ // Randomize typing speed for realistic terminal feel
415
+ setTimeout(typeWriter, Math.random() * 20 + 10);
416
+ }
417
+ }
418
+ typeWriter();
419
+ }
420
+
421
+ // Mark green optimal Pin on physical map coordinates
422
+ if (result.best_loc) {
423
+
424
+ // Remove orange marker to avoid overlap
425
+ if (predictionMarker) {
426
+ predictionMarker.remove();
427
+ predictionMarker = null;
428
+ }
429
+
430
+ // Create custom "Green Pulse" DOM element defined in CSS
431
+ const customPin = document.createElement('div');
432
+ customPin.className = 'optimal-pulse-pin';
433
+
434
+ optimalMarker = new mapboxgl.Marker(customPin)
435
+ .setLngLat(result.best_loc)
436
+ .setPopup(new mapboxgl.Popup({ offset: 25, closeButton: false, className: 'cyber-popup' })
437
+ .setHTML('<div style="color:#2ecc71; font-weight:bold; font-size:14px;">🌟 Best LSI Site</div>'))
438
+ .addTo(map);
439
+
440
+ optimalMarker.togglePopup();
441
+
442
+ // Smoothly fly to the optimal site location
443
+ map.flyTo({
444
+ center: result.best_loc,
445
+ zoom: 16.5,
446
+ speed: 1.2
447
+ });
448
+ }
449
+ }
450
+ } else {
451
+ predIdDisplay.innerText = `${id} (Failed)`;
452
+ }
453
+ return;
454
+ }
455
+
456
+ // 2. Standard Detail Mode Logic
457
+ if (currentMarker) currentMarker.remove();
458
+ currentMarker = new mapboxgl.Marker().setLngLat(coordinates).addTo(map);
459
+
460
+ const pitch = map.getPitch();
461
+ map.flyTo({ center: coordinates, zoom: 15, pitch: pitch > 10 ? 60 : 0, speed: 1.5 });
462
+
463
+ document.getElementById('selected-id').innerText = id;
464
+
465
+ try {
466
+ document.getElementById('station-details').innerHTML = '<p class="placeholder-text">Loading details...</p>';
467
+
468
+ const detailData = await fetchStationDetail(id);
469
+ if (detailData) {
470
+ const stats = detailData.stats || {avg:0, std:0};
471
+
472
+ document.getElementById('station-details').innerHTML =
473
+ `<div style="margin-top:10px;">
474
+ <p><strong>Longitude:</strong> ${detailData.loc[0].toFixed(4)}</p>
475
+ <p><strong>Latitude:</strong> ${detailData.loc[1].toFixed(4)}</p>
476
+ <hr style="border:0; border-top:1px solid #444; margin:5px 0;">
477
+ <p><strong>Avg Load:</strong> <span style="color:#00cec9">${stats.avg.toFixed(4)}</span></p>
478
+ <p><strong>Stability:</strong> <span style="color:#fd79a8">${stats.std.toFixed(4)}</span></p>
479
+ </div>`;
480
+
481
+ if (detailData.bs_record) {
482
+ renderChart(detailData.bs_record);
483
+ }
484
+ }
485
+ } catch (err) {
486
+ console.error("Failed to fetch clicked station details:", err);
487
+ document.getElementById('station-details').innerHTML = '<p style="color:red">Error loading data</p>';
488
+ }
489
+ });
490
+ }
491
+
492
+ // Prediction Mode State Control
493
+ function setupPredictionMode(map) {
494
+ const predictBtn = document.getElementById('predict-toggle');
495
+ const predPanel = document.getElementById('prediction-panel');
496
+ const closePredBtn = document.getElementById('close-pred-btn');
497
+
498
+ if (!predictBtn) return;
499
+
500
+ predictBtn.addEventListener('click', () => {
501
+ // Enforce 2D view check for prediction mode
502
+ const pitch = map.getPitch();
503
+ if (pitch > 10) {
504
+ alert("Prediction Mode is only available in 2D View. Please switch to 2D first.");
505
+ return;
506
+ }
507
+
508
+ isPredictionMode = !isPredictionMode;
509
+
510
+ if (isPredictionMode) {
511
+ predictBtn.classList.add('predict-on');
512
+ predictBtn.innerHTML = '<span class="icon">🔮</span> Mode: ON';
513
+ map.getCanvas().style.cursor = 'crosshair';
514
+ } else {
515
+ predictBtn.classList.remove('predict-on');
516
+ predictBtn.innerHTML = '<span class="icon">🔮</span> Prediction Mode';
517
+ map.getCanvas().style.cursor = '';
518
+ predPanel.classList.remove('active');
519
+
520
+ // Reset UI state when exiting prediction
521
+ predPanel.classList.remove('collapsed');
522
+ const rightBtn = document.getElementById('toggle-right-btn');
523
+ if(rightBtn) {
524
+ rightBtn.innerText = '▶';
525
+ rightBtn.classList.remove('active');
526
+ rightBtn.classList.remove('collapsed');
527
+ }
528
+
529
+ // Clear markers and grids
530
+ clearPredictionExtras(map);
531
+ }
532
+ });
533
+
534
+ if (closePredBtn) {
535
+ closePredBtn.addEventListener('click', () => {
536
+ predPanel.classList.remove('active');
537
+ const rightBtn = document.getElementById('toggle-right-btn');
538
+ if (rightBtn) rightBtn.classList.remove('active');
539
+ predictBtn.click(); // Trigger toggle to clean up state
540
+ });
541
+ }
542
+ }
543
+
544
+ // Dynamic 3x3 grid matching the 256px satellite patch bounds
+ function updatePredictionGrid(map, centerLng, centerLat) {
+ const features = [];
+ const gridSize = 3;
+ const offset = Math.floor(gridSize / 2);
+
+ // Precise Web Mercator projection span calculation at Zoom 15
+ const zoom = 15;
+
+ // Total Longitude span for 256 pixels at this zoom
+ const lonSpan = 360 / Math.pow(2, zoom);
+ // Latitude span (scaled by local latitude)
+ const latSpan = lonSpan * Math.cos(centerLat * Math.PI / 180);
+
+ // Actual step sizes for 3x3 division
+ const stepLon = lonSpan / gridSize;
+ const stepLat = latSpan / gridSize;
+
+ for (let i = 0; i < gridSize; i++) {
+ for (let j = 0; j < gridSize; j++) {
+ // Center point of each micro-grid cell
+ const cLng = centerLng + (j - offset) * stepLon;
+ const cLat = centerLat - (i - offset) * stepLat;
+
+ const wLon = stepLon / 2;
+ const wLat = stepLat / 2;
+
+ features.push({
+ 'type': 'Feature',
+ 'geometry': {
+ 'type': 'Polygon',
+ 'coordinates': [[
+ [cLng - wLon, cLat - wLat], [cLng + wLon, cLat - wLat],
+ [cLng + wLon, cLat + wLat], [cLng - wLon, cLat + wLat],
+ [cLng - wLon, cLat - wLat]
+ ]]
+ }
+ });
+ }
+ }
+
+ const geojson = { 'type': 'FeatureCollection', 'features': features };
+
+ if (map.getSource('pred-grid-source')) {
+ map.getSource('pred-grid-source').setData(geojson);
+ } else {
+ map.addSource('pred-grid-source', { type: 'geojson', data: geojson });
+ map.addLayer({
+ 'id': 'pred-grid-fill', 'type': 'fill', 'source': 'pred-grid-source',
+ 'paint': { 'fill-color': '#f39c12', 'fill-opacity': 0.1 }
+ });
+ map.addLayer({
+ 'id': 'pred-grid-line', 'type': 'line', 'source': 'pred-grid-source',
+ 'paint': { 'line-color': '#f39c12', 'line-width': 2, 'line-dasharray': [2, 2] }
+ });
+ }
+ }
+
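The span math in `updatePredictionGrid` can be checked independently: at zoom 15 a 256-pixel Web Mercator tile covers 360/2^15 degrees of longitude, the latitude span shrinks by the cosine of the local latitude, and the grid divides each span into three equal steps. A small Python sketch of the same calculation (the Shanghai latitude is an illustrative value):

```python
import math

def grid_steps(center_lat, zoom=15, grid_size=3):
    """Degree step sizes for a grid_size x grid_size split of one 256px tile."""
    lon_span = 360 / (2 ** zoom)  # tile width in degrees of longitude
    lat_span = lon_span * math.cos(math.radians(center_lat))  # narrower at high latitude
    return lon_span / grid_size, lat_span / grid_size

step_lon, step_lat = grid_steps(31.23)  # roughly Shanghai's latitude
print(round(step_lon, 6))  # → 0.003662
```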
+ // Cleanup prediction visual elements
+ function clearPredictionExtras(map) {
+ if (predictionMarker) { predictionMarker.remove(); predictionMarker = null; }
+ if (optimalMarker) { optimalMarker.remove(); optimalMarker = null; } // clear the green optimal-site marker
+ if (map.getSource('pred-grid-source')) {
+ map.getSource('pred-grid-source').setData({ type: 'FeatureCollection', features: [] });
+ }
+ }
+
+ // ==========================================
+ // 7. Timeline Logic
+ // ==========================================
+ function setupTimeLapse(map, globalData) {
+ const playBtn = document.getElementById('play-btn');
+ const slider = document.getElementById('time-slider');
+ const display = document.getElementById('time-display');
+ if (!playBtn || !slider) return;
+
+ const totalHours = (globalData.length > 0 && globalData[0].vals) ? globalData[0].vals.length : 672;
+ slider.max = totalHours - 1;
+ let isPlaying = false;
+ let speed = 100;
+
+ const updateTime = (val) => {
+ const day = Math.floor(val / 24) + 1;
+ const hour = val % 24;
+ display.innerText = `Day ${day.toString().padStart(2, '0')} - ${hour.toString().padStart(2, '0')}:00`;
+
+ updateGeoJSONData(map, globalData, 'time', val);
+ updateChartCursor(val);
+ };
+
+ const play = () => {
+ let val = parseInt(slider.value);
+ val = (val + 1) % totalHours;
+ slider.value = val;
+ updateTime(val);
+ if (isPlaying) animationFrameId = setTimeout(() => requestAnimationFrame(play), speed);
+ };
+
+ playBtn.onclick = () => {
+ isPlaying = !isPlaying;
+ playBtn.innerText = isPlaying ? '⏸' : '▶';
+ if (isPlaying) play(); else clearTimeout(animationFrameId);
+ };
+
+ slider.oninput = (e) => {
+ isPlaying = false;
+ if(animationFrameId) clearTimeout(animationFrameId);
+ playBtn.innerText = '▶';
+ updateTime(parseInt(e.target.value));
+ };
+ }
+
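The `updateTime` helper above maps a flat hour index onto a 1-based day/hour label. The same conversion in Python, for reference:

```python
def format_time(val):
    """Convert a flat hour index into a 'Day DD - HH:00' label (days are 1-based)."""
    day = val // 24 + 1
    hour = val % 24
    return f"Day {day:02d} - {hour:02d}:00"

print(format_time(0))    # → Day 01 - 00:00
print(format_time(27))   # → Day 02 - 03:00
```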
+
+ // ==========================================
+ // 8. UI Controls
+ // ==========================================
+ function setupModeToggle(map) {
+ const btn = document.getElementById('view-toggle');
+ const timePanel = document.querySelector('.time-panel');
+ let is3D = true;
+
+ if (!btn) return;
+
+ btn.onclick = () => {
+ // Prevent switching to 3D mode if Prediction Mode is active
+ if (isPredictionMode) {
+ alert("Please exit Prediction Mode before switching to 3D.");
+ return;
+ }
+
+ is3D = !is3D;
+ if (is3D) {
+ // Switch to 3D View: Show pillars and tilt camera
+ if(map.getLayer('stations-3d-pillars')) map.setLayoutProperty('stations-3d-pillars', 'visibility', 'visible');
+ map.easeTo({ pitch: 60, bearing: -15 });
+ btn.innerHTML = '<span class="icon">👁️</span> View: 3D';
+ if (timePanel) {
+ timePanel.style.display = 'flex';
+ setTimeout(() => { timePanel.style.opacity = '1'; }, 10);
+ }
+ } else {
+ // Switch to 2D View: Hide pillars and reset camera pitch
+ if(map.getLayer('stations-3d-pillars')) map.setLayoutProperty('stations-3d-pillars', 'visibility', 'none');
+ map.easeTo({ pitch: 0, bearing: 0 });
+ btn.innerHTML = '<span class="icon">🗺️</span> View: 2D';
+ if (timePanel) {
+ timePanel.style.display = 'none';
+ timePanel.style.opacity = '0';
+ }
+ // Stop timelapse playback when entering 2D mode
+ const playBtn = document.getElementById('play-btn');
+ if (playBtn && playBtn.innerText === '⏸') playBtn.click();
+ }
+ };
+ }
+
+ function setupDataToggle(map) {
+ const btn = document.getElementById('data-toggle');
+ const layers = ['stations-3d-pillars', 'stations-2d-dots', 'stations-heatmap', 'stations-hitbox'];
+ let isVisible = true;
+ if(btn) btn.onclick = () => {
+ isVisible = !isVisible;
+ const val = isVisible ? 'visible' : 'none';
+ layers.forEach(id => { if(map.getLayer(id)) map.setLayoutProperty(id, 'visibility', val); });
+ btn.innerHTML = isVisible ? '<span class="icon">📡</span> Toggle Data' : '<span class="icon">🚫</span> Toggle Data';
+ btn.style.opacity = isVisible ? '1' : '0.6';
+ };
+ }
+
+ function setupFilterMenu(map, statsColor) {
+ const btn = document.getElementById('filter-btn');
+ const menu = document.getElementById('filter-menu');
+ if (!btn || !menu) return;
+
+ // Define stability levels based on Standard Deviation thresholds
+ const levels = [
+ { label: "Level 5: Highly Unstable", color: "#e84393", filter: ['>=', 'load_std', statsColor.t4] },
+ { label: "Level 4: Volatile", color: "#fd79a8", filter: ['all', ['>=', 'load_std', statsColor.t3], ['<', 'load_std', statsColor.t4]] },
+ { label: "Level 3: Normal", color: "#00cec9", filter: ['all', ['>=', 'load_std', statsColor.t2], ['<', 'load_std', statsColor.t3]] },
+ { label: "Level 2: Stable", color: "#0984e3", filter: ['all', ['>=', 'load_std', statsColor.t1], ['<', 'load_std', statsColor.t2]] },
+ { label: "Level 1: Highly Stable", color: "#1e1e2e", filter: ['<', 'load_std', statsColor.t1] }
+ ];
+
+ menu.innerHTML = '';
+ levels.forEach((lvl) => {
+ const item = document.createElement('div');
+ item.className = 'filter-item';
+ item.innerHTML = `<div class="color-box" style="background:${lvl.color}; box-shadow: 0 0 5px ${lvl.color};"></div><span>${lvl.label}</span>`;
+ item.onclick = (e) => {
+ e.stopPropagation();
+ if (item.classList.contains('selected')) {
+ item.classList.remove('selected');
+ applyFilter(map, null);
+ } else {
+ document.querySelectorAll('.filter-item').forEach(el => el.classList.remove('selected'));
+ item.classList.add('selected');
+ applyFilter(map, lvl.filter);
+ }
+ };
+ menu.appendChild(item);
+ });
+
+ // Toggle menu visibility
+ btn.onclick = (e) => { e.stopPropagation(); menu.classList.toggle('active'); };
+ document.addEventListener('click', (e) => { if (!menu.contains(e.target) && !btn.contains(e.target)) menu.classList.remove('active'); });
+ }
+
+ function applyFilter(map, filterExpression) {
+ const targetLayers = ['stations-3d-pillars', 'stations-2d-dots', 'stations-heatmap', 'stations-hitbox'];
+ targetLayers.forEach(layerId => { if (map.getLayer(layerId)) map.setFilter(layerId, filterExpression); });
+ }
+
+ function setupSearch(map, globalData) {
+ const input = document.getElementById('search-input');
+ const btn = document.getElementById('search-btn');
+ const clearBtn = document.getElementById('clear-search-btn');
+ const keepCheck = document.getElementById('keep-markers-check');
+
+ if (!input || !btn) return;
+
+ let searchMarkers = [];
+
+ const clearAllMarkers = () => {
+ searchMarkers.forEach(marker => marker.remove());
+ searchMarkers = [];
+ };
+
+ const performSearch = async () => {
+ const queryId = input.value.trim();
+ if (!queryId) return;
+
+ const target = globalData.find(s => String(s.id) === String(queryId));
+
+ if (target) {
+ if (!keepCheck.checked) {
+ clearAllMarkers();
+ }
+
+ // Fly to searched station and switch to high-detail view
+ map.flyTo({
+ center: target.loc,
+ zoom: 16,
+ pitch: 60,
+ essential: true
+ });
+
+ document.getElementById('selected-id').innerText = target.id;
+ try {
+ const detailData = await fetchStationDetail(target.id);
+ if (detailData) {
+ const stats = detailData.stats || {avg:0, std:0};
+ document.getElementById('station-details').innerHTML =
+ `<div style="margin-top:10px;">
+ <p><strong>Longitude:</strong> ${detailData.loc[0].toFixed(4)}</p>
+ <p><strong>Latitude:</strong> ${detailData.loc[1].toFixed(4)}</p>
+ <hr style="border:0; border-top:1px solid #444; margin:5px 0;">
+ <p><strong>Avg Load:</strong> <span style="color:#00cec9">${stats.avg.toFixed(4)}</span></p>
+ <p><strong>Stability:</strong> <span style="color:#fd79a8">${stats.std.toFixed(4)}</span></p>
+ </div>`;
+
+ if (detailData.bs_record) renderChart(detailData.bs_record);
+ }
+ } catch (e) {
+ console.error("Fetch details failed", e);
+ }
+
+ // Create red highlight marker for searched target
+ const marker = new mapboxgl.Marker({ color: '#ff0000', scale: 0.8 })
+ .setLngLat(target.loc)
+ .setPopup(new mapboxgl.Popup({ offset: 25 }).setText(`Station ID: ${target.id}`))
+ .addTo(map);
+
+ searchMarkers.push(marker);
+
+ } else {
+ alert("Station ID not found!");
+ }
+ };
+
+ btn.onclick = performSearch;
+
+ input.addEventListener('keypress', (e) => {
+ if (e.key === 'Enter') performSearch();
+ });
+
+ if (clearBtn) {
+ clearBtn.onclick = () => {
+ clearAllMarkers();
+ input.value = '';
+ };
+ }
+ }
+
+ // Sidebar & Panel Toggle Logic
+ function setupPanelToggles(map) {
+ const leftSidebar = document.querySelector('.sidebar');
+ const leftToggleBtn = document.getElementById('toggle-left-btn');
+
+ if (leftToggleBtn && leftSidebar) {
+ leftToggleBtn.addEventListener('click', () => {
+ leftSidebar.classList.toggle('collapsed');
+ leftToggleBtn.classList.toggle('collapsed');
+ leftToggleBtn.innerText = leftSidebar.classList.contains('collapsed') ? '▶' : '◀';
+ setTimeout(() => map.resize(), 300);
+ });
+ }
+
+ const rightSidebar = document.getElementById('prediction-panel');
+ const rightToggleBtn = document.getElementById('toggle-right-btn');
+
+ if (rightToggleBtn && rightSidebar) {
+ rightToggleBtn.addEventListener('click', () => {
+ rightSidebar.classList.toggle('collapsed');
+ rightToggleBtn.classList.toggle('collapsed');
+ rightToggleBtn.innerText = rightSidebar.classList.contains('collapsed') ? '◀' : '▶';
+ setTimeout(() => map.resize(), 300);
+ });
+ }
+ }
+
+ // ==========================================
+ // 9. Main Entry Point
+ // ==========================================
+ window.onload = async () => {
+ const map = initMap();
+
+ map.on('load', async () => {
+ setupMapEnvironment(map);
+
+ try {
+ // Load initial station metadata
+ const data = await fetchLocations();
+ globalStationData = data.stations;
+ document.getElementById('total-stations').innerText = globalStationData.length;
+
+ // Initialize Map Layers with empty data initially
+ addStationLayers(map,
+ {points: {type:'FeatureCollection', features:[]}, polys: {type:'FeatureCollection', features:[]} },
+ data.stats_height, data.stats_color);
+
+ // Immediately load data for T=0 (initial state)
+ updateGeoJSONData(map, globalStationData, 'time', 0);
+ updateChartCursor(0);
+
+ // Start Time Lapse
+ setupTimeLapse(map, globalStationData);
+
+ // Bind Interactions
+ setupPredictionMode(map); // Initialize AI Prediction events
+ setupInteraction(map); // Initialize standard map clicks/popups
+ setupModeToggle(map); // 2D/3D View switch
+ setupDataToggle(map); // Layer visibility switch
+ setupFilterMenu(map, data.stats_color); // Load-stability filters
+ setupSearch(map, globalStationData); // Search bar logic
+
+ // Initialize sidebar collapse/expand controls
+ setupPanelToggles(map);
+
+ // Remove Loading Screen
+ document.getElementById('loading').style.display = 'none';
+ } catch (e) {
+ console.error(e);
+ alert('System Initialization Failed. Check Console.');
+ document.getElementById('loading').innerHTML = '<h2>Error Loading Data</h2>';
+ }
+ });
+ };
server.py ADDED
@@ -0,0 +1,269 @@
+ import os
+ import json
+ import numpy as np
+ import math
+ from flask import Flask, jsonify, send_from_directory, request
+ from flask_cors import CORS
+
+ # Import the custom prediction backend module
+ try:
+ from prediction_backend import TrafficPredictor
+ except ImportError:
+ print("Warning: prediction_backend.py not found. Prediction features will be disabled.")
+ TrafficPredictor = None
+ except Exception as e:
+ print(f"Warning: Failed to import prediction_backend: {e}")
+ TrafficPredictor = None
+
+ # ==========================================
+ # Flask Server
+ # ==========================================
+ app = Flask(__name__, static_folder='.')
+ CORS(app)
+
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ # Data directory path
+ DATA_DIR = os.path.abspath(os.path.join(BASE_DIR, 'data'))
+
+ # File path configurations
+ JSON_PATH = os.path.join(DATA_DIR, 'base2info.json')
+ TRAFFIC_PATH = os.path.join(DATA_DIR, 'bs_record_energy_normalized_sampled.npz')
+ SPATIAL_PATH = os.path.join(DATA_DIR, 'spatial_features.npz')
+ MODEL_PATH = os.path.join(BASE_DIR, 'best_corr_model.pt')
+
+ # ==========================================
+ # Utility Functions
+ # ==========================================
+ def calculate_std_dev(records, avg):
+ """Calculates standard deviation for a given set of records and their average."""
+ if not records or len(records) < 2:
+ return 0
+ variance = sum((x - avg) ** 2 for x in records) / len(records)
+ return math.sqrt(variance)
+
+ def calculate_stats(data_list):
+ """Calculate global statistics for frontend normalization"""
+ print("Calculating statistical distribution (Avg & Std)...")
+ avgs = []
+ stds = []
+
+ for item in data_list:
+ records = item.get('bs_record', [])
+ if records:
+ avg = sum(records) / len(records)
+ std = calculate_std_dev(records, avg)
+ else:
+ avg = 0
+ std = 0
+ avgs.append(avg)
+ stds.append(std)
+
+ def get_percentiles(values):
+ """Calculates percentiles to create data brackets for visualization."""
+ values.sort()
+ n = len(values)
+ if n == 0: return {k:0 for k in ['min','max','t1','t2','t3','t4']}
+ return {
+ "min": values[0],
+ "max": values[-1],
+ "t1": values[int(n * 0.2)],
+ "t2": values[int(n * 0.4)],
+ "t3": values[int(n * 0.6)],
+ "t4": values[int(n * 0.8)]
+ }
+
+ stats_h = get_percentiles(avgs) # Statistics for pillar heights
+ stats_c = get_percentiles(stds) # Statistics for pillar colors (stability)
+ return stats_h, stats_c
+
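The `get_percentiles` brackets above split the sorted values at the 20/40/60/80 percent marks, which is what the frontend's stepped color thresholds (`t1`..`t4`) consume. A standalone check of that bracketing (note this sketch sorts a copy, whereas the server version sorts in place):

```python
def get_percentiles(values):
    # Same bracketing as the server: min, max, and 20/40/60/80% cut points.
    values = sorted(values)
    n = len(values)
    if n == 0:
        return {k: 0 for k in ['min', 'max', 't1', 't2', 't3', 't4']}
    return {"min": values[0], "max": values[-1],
            "t1": values[int(n * 0.2)], "t2": values[int(n * 0.4)],
            "t3": values[int(n * 0.6)], "t4": values[int(n * 0.8)]}

stats = get_percentiles(list(range(100)))  # 0..99
print(stats)  # → {'min': 0, 'max': 99, 't1': 20, 't2': 40, 't3': 60, 't4': 80}
```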
+ def _convert_numpy_type(val):
+ if isinstance(val, np.ndarray): return val.tolist()
+ elif isinstance(val, (np.integer, np.int64, np.int32, np.int16)): return int(val)
+ elif isinstance(val, (np.floating, np.float64, np.float32)): return float(val)
+ elif isinstance(val, bytes): return val.decode('utf-8')
+ else: return val
+
+ def load_and_process_data(json_path, npz_path):
+ print(f"[DataLoader] Loading basic data...")
+ print(f" - JSON: {json_path}")
+ print(f" - Traffic NPZ : {npz_path}")
+
+ if not os.path.exists(json_path) or not os.path.exists(npz_path):
+ print("[DataLoader] Error: Input files not found.")
+ return []
+
+ try:
+ npz_data = np.load(npz_path)
+ with open(json_path, 'r', encoding='utf-8') as f:
+ json_map = json.load(f)
+ except Exception as e:
+ print(f"[DataLoader] Read error: {e}")
+ return []
+
+ # Handle binary strings if present in NPZ
+ raw_bs_ids = npz_data['bs_id']
+ bs_ids = [x.decode('utf-8') if isinstance(x, bytes) else str(x) for x in raw_bs_ids]
+ num_stations = len(bs_ids)
+
+ # Identify available time-series attributes in NPZ
+ station_attributes = []
+ for key in npz_data.files:
+ if key == 'bs_id': continue
+ if npz_data[key].shape[0] == num_stations:
+ station_attributes.append(key)
+
+ merged_data = []
+ match_count = 0
+
+ for i in range(num_stations):
+ current_id = bs_ids[i]
+ json_key = f"Base_{current_id}"
+
+ if json_key in json_map:
+ match_count += 1
+ entry = {
+ "id": current_id,
+ "npz_index": i, # Store original index for prediction lookups
+ "loc": json_map[json_key]["loc"]
+ }
+ for attr in station_attributes:
+ val = npz_data[attr][i]
+ entry[attr] = _convert_numpy_type(val)
+ merged_data.append(entry)
+
+ print(f"[DataLoader] Merge complete! Matched: {match_count}/{num_stations}")
+ return merged_data
+
+ # ==========================================
138
+ # Initialization Sequence
139
+ # ==========================================
140
+
141
+ print("Server Initializing...")
142
+
143
+ # 1. Load basic station data for frontend display
144
+ ALL_DATA = load_and_process_data(JSON_PATH, TRAFFIC_PATH)
145
+
146
+ STATS_HEIGHT = {}
147
+ STATS_COLOR = {}
148
+
149
+ if ALL_DATA:
150
+ STATS_HEIGHT, STATS_COLOR = calculate_stats(ALL_DATA)
151
+ else:
152
+ print("⚠️ CRITICAL WARNING: Data list is empty!")
153
+
154
+ # 2. Initialize AI Predictor with Spatial Features
155
+ predictor = None
156
+ if TrafficPredictor:
157
+ try:
158
+ print(f"[AI] Initializing Predictor with model: {MODEL_PATH}")
159
+ # Initialize the predictor using the model and spatial feature files
160
+ predictor = TrafficPredictor(
161
+ model_path=MODEL_PATH,
162
+ spatial_path=SPATIAL_PATH,
163
+ traffic_path=TRAFFIC_PATH
164
+ )
165
+ print("[AI] Predictor loaded successfully.")
166
+ except Exception as e:
167
+ print(f"[AI] Failed to load predictor: {e}")
168
+
169
+ # ==========================================
170
+ # API Routes
171
+ # ==========================================
172
+
173
+ @app.route('/')
174
+ def index():
175
+ """Serves the main dashboard page."""
176
+ return send_from_directory('.', 'index.html')
177
+
178
+ @app.route('/<path:path>')
179
+ def serve_static(path):
180
+ """Serves static assets (JS, CSS, Images)."""
181
+ return send_from_directory('.', path)
182
+
183
+ @app.route('/api/stations/locations')
184
+ def get_station_locations():
185
+ """Returns a lightweight list of station coordinates and statistical summaries."""
186
+ lightweight_data = []
187
+ for item in ALL_DATA:
188
+ records = item.get('bs_record', [])
189
+ if records:
190
+ avg = sum(records) / len(records)
191
+ std = calculate_std_dev(records, avg)
192
+ else:
193
+ avg = 0
194
+ std = 0
195
+
196
+ lightweight_data.append({
197
+ "id": item['id'],
198
+ "loc": item['loc'],
199
+ "val_h": avg,
200
+ "val_c": std,
201
+ "vals": records
202
+ })
203
+
204
+ return jsonify({
205
+ "stats_height": STATS_HEIGHT,
206
+ "stats_color": STATS_COLOR,
207
+ "stations": lightweight_data
208
+ })
209
+
210
+ @app.route('/api/stations/detail/<station_id>')
211
+ def get_station_detail(station_id):
212
+ """Returns detailed metadata and stats for a specific station."""
213
+ for item in ALL_DATA:
214
+ if str(item['id']) == str(station_id):
215
+ records = item.get('bs_record', [])
216
+ avg = sum(records)/len(records) if records else 0
217
+ std = calculate_std_dev(records, avg) if records else 0
218
+
219
+ response = item.copy()
220
+ response['stats'] = {"avg": avg, "std": std}
221
+ return jsonify(response)
222
+
223
+ return jsonify({"error": "Station not found"}), 404
224
+
225
+ @app.route('/api/predict/<station_id>')
226
+ def predict_traffic(station_id):
227
+ """Triggers the ML model to predict future traffic for a specific station."""
228
+ if not predictor:
229
+ return jsonify({"error": "Prediction service not available"}), 503
230
+
231
+ try:
232
+ target_idx = -1
233
+
234
+ # Map Station ID to its internal index in the NPZ file
235
+ for item in ALL_DATA:
236
+ if str(item['id']) == str(station_id):
237
+ target_idx = item.get('npz_index', -1)
238
+ break
239
+
240
+ if target_idx == -1:
241
+ # Fallback: Check if the ID provided is directly a numerical index
242
+ if str(station_id).isdigit():
243
+ target_idx = int(station_id)
244
+ else:
245
+ return jsonify({"error": "Station ID not found in mapping"}), 404
246
+
247
+ # Execute prediction through the ML backend
248
+ result = predictor.predict(target_idx)
249
+
250
+ if "error" in result:
251
+ return jsonify(result), 500
252
+
253
+ return jsonify(result)
254
+
255
+ except Exception as e:
256
+ print(f"Prediction Error: {e}")
257
+ return jsonify({"error": str(e)}), 500
258
+
259
+ # Local development server
260
+ # if __name__ == '__main__':
261
+ # print(f"Monitoring Data Directory: {DATA_DIR}")
262
+ # print("Server running on http://127.0.0.1:5000")
263
+ # app.run(debug=True, port=5000)
264
+
265
+ # FOR ONLINE
266
+ if __name__ == '__main__':
267
+ print(f"Monitoring Data Directory: {DATA_DIR}")
268
+ print("Server running on port 7860...")
269
+ app.run(host='0.0.0.0', port=7860)  # For deployment: drop debug=True, bind to all interfaces on port 7860
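The statistics helpers in `server.py` reduce each station's `bs_record` series to a mean (pillar height) and a standard deviation (pillar color), then bucket those values into percentile brackets for the frontend. A minimal standalone sketch of that pipeline, using hypothetical sample data and mirroring the logic of `calculate_std_dev` and `get_percentiles` above:

```python
import math

def std_dev(records, avg):
    # Population standard deviation, as computed in calculate_std_dev
    variance = sum((x - avg) ** 2 for x in records) / len(records)
    return math.sqrt(variance)

def percentiles(values):
    # 20/40/60/80 percentile brackets, as in get_percentiles
    values = sorted(values)  # sorted copy avoids mutating the caller's list
    n = len(values)
    if n == 0:
        return {k: 0 for k in ['min', 'max', 't1', 't2', 't3', 't4']}
    return {
        "min": values[0],
        "max": values[-1],
        "t1": values[int(n * 0.2)],
        "t2": values[int(n * 0.4)],
        "t3": values[int(n * 0.6)],
        "t4": values[int(n * 0.8)],
    }

# Hypothetical traffic records for three stations
stations = [[1.0, 2.0, 3.0], [4.0, 4.0, 4.0], [0.0, 10.0]]
avgs = [sum(r) / len(r) for r in stations]
stds = [std_dev(r, a) for r, a in zip(stations, avgs)]
print(percentiles(avgs))
```

Note that the server's `get_percentiles` sorts its argument in place; that is harmless there because `avgs` and `stds` are not reused afterwards, but the copy above is the safer pattern if the lists are needed again.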
style.css ADDED
@@ -0,0 +1,617 @@
1
+ /* =========================================
2
+ 1. Base Reset & Layout
3
+ ========================================= */
4
+ * {
5
+ margin: 0;
6
+ padding: 0;
7
+ box-sizing: border-box;
8
+ /* Clean sans-serif stack for readability in high-tech interfaces */
9
+ font-family: 'Segoe UI', 'Roboto', Helvetica, Arial, sans-serif;
10
+ }
11
+
12
+ body {
13
+ background: #0f0c29; /* Deep space navy */
14
+ color: #e0e0e0;
15
+ height: 100vh;
16
+ overflow: hidden; /* Prevents browser scrollbars during panel transitions */
17
+ display: flex; /* Sidebars and main content align horizontally */
18
+ }
19
+
20
+ /* =========================================
21
+ 2. Loading Overlay
22
+ ========================================= */
23
+ .loading-overlay {
24
+ position: absolute;
25
+ top: 0;
26
+ left: 0;
27
+ width: 100%;
28
+ height: 100%;
29
+ background: #0f0c29;
30
+ z-index: 9999; /* Ensure it stays above map and sidebars */
31
+ display: flex;
32
+ flex-direction: column;
33
+ justify-content: center;
34
+ align-items: center;
35
+ transition: opacity 0.8s ease-out;
36
+ }
37
+
38
+ .spinner {
39
+ width: 50px;
40
+ height: 50px;
41
+ border: 3px solid rgba(0, 206, 201, 0.1);
42
+ border-top: 3px solid #00cec9; /* Neon teal accent */
43
+ border-radius: 50%;
44
+ animation: spin 1s infinite cubic-bezier(0.55, 0.15, 0.45, 0.85);
45
+ margin-bottom: 20px;
46
+ }
47
+
48
+ @keyframes spin {
49
+ 0% { transform: rotate(0deg); }
50
+ 100% { transform: rotate(360deg); }
51
+ }
52
+
53
+ /* =========================================
54
+ 3. Left Sidebar (Main Controls)
55
+ ========================================= */
56
+ .sidebar {
57
+ width: 380px;
58
+ background: rgba(15, 23, 42, 0.85);
59
+ backdrop-filter: blur(20px) saturate(180%);
60
+ /* Glassmorphism: blur effect creates depth against the map */
61
+ -webkit-backdrop-filter: blur(20px) saturate(180%);
62
+ padding: 25px;
63
+ display: flex;
64
+ flex-direction: column;
65
+ gap: 20px;
66
+ box-shadow: 10px 0 30px rgba(0,0,0,0.5);
67
+ z-index: 10;
68
+ border-right: 1px solid rgba(255,255,255,0.05);
69
+ overflow-y: auto;
70
+ }
71
+
72
+ .header h1 {
73
+ font-size: 24px;
74
+ font-weight: 300;
75
+ letter-spacing: 1px;
76
+ /* Gradient text for futuristic branding */
77
+ background: linear-gradient(to right, #00cec9, #a29bfe);
78
+ -webkit-background-clip: text;
79
+ -webkit-text-fill-color: transparent;
80
+ }
81
+
82
+ /* Search Section in Left Sidebar */
83
+ .search-section {
84
+ margin-top: 15px;
85
+ margin-bottom: 10px;
86
+ display: flex;
87
+ flex-direction: column;
88
+ gap: 8px;
89
+ }
90
+
91
+ .search-container {
92
+ display: flex;
93
+ gap: 8px;
94
+ }
95
+
96
+ #search-input {
97
+ flex: 1;
98
+ background: rgba(15, 23, 42, 0.6);
99
+ border: 1px solid rgba(0, 206, 201, 0.4);
100
+ color: #fff;
101
+ padding: 6px 10px;
102
+ border-radius: 4px;
103
+ outline: none;
104
+ font-family: 'Courier New', monospace; /* "Code" feel for data input */
105
+ font-size: 14px;
106
+ }
107
+
108
+ #search-input:focus {
109
+ border-color: #00cec9;
110
+ box-shadow: 0 0 8px rgba(0, 206, 201, 0.2);
111
+ }
112
+
113
+ .search-mode {
114
+ display: flex;
115
+ align-items: center;
116
+ gap: 8px;
117
+ font-size: 11px;
118
+ color: #94a3b8;
119
+ padding-left: 2px;
120
+ }
121
+
122
+ .search-mode input { cursor: pointer; accent-color: #00cec9; }
123
+ .search-mode label { cursor: pointer; }
124
+
125
+ /* Small Utility Buttons */
126
+ .cyber-btn-small {
127
+ background: rgba(0, 206, 201, 0.1);
128
+ color: #00cec9;
129
+ border: 1px solid #00cec9;
130
+ padding: 0 12px;
131
+ border-radius: 4px;
132
+ cursor: pointer;
133
+ font-weight: bold;
134
+ transition: all 0.2s;
135
+ display: flex;
136
+ align-items: center;
137
+ justify-content: center;
138
+ }
139
+
140
+ .cyber-btn-small:hover {
141
+ background: #00cec9;
142
+ color: #000;
143
+ }
144
+
145
+ #clear-search-btn {
146
+ border-color: #fd79a8;
147
+ color: #fd79a8;
148
+ background: rgba(253, 121, 168, 0.1);
149
+ }
150
+
151
+ #clear-search-btn:hover {
152
+ background: #fd79a8;
153
+ color: #fff;
154
+ }
155
+
156
+ /* Cards & Charts */
157
+ .card {
158
+ background: rgba(255, 255, 255, 0.03);
159
+ padding: 20px;
160
+ border-radius: 12px;
161
+ border: 1px solid rgba(255,255,255,0.02);
162
+ transition: transform 0.2s, background 0.2s;
163
+ }
164
+
165
+ .card:hover {
166
+ background: rgba(255, 255, 255, 0.05);
167
+ border-color: rgba(0, 206, 201, 0.2);
168
+ }
169
+
170
+ .card h2 {
171
+ font-size: 12px;
172
+ color: #94a3b8;
173
+ margin-bottom: 15px;
174
+ border-bottom: 1px solid rgba(255,255,255,0.05);
175
+ padding-bottom: 8px;
176
+ text-transform: uppercase;
177
+ font-weight: 600;
178
+ }
179
+
180
+ .chart-container {
181
+ height: 140px;
182
+ width: 100%;
183
+ position: relative;
184
+ }
185
+
186
+ .stat-row {
187
+ display: flex;
188
+ justify-content: space-between;
189
+ margin-bottom: 5px;
190
+ }
191
+
192
+ .value {
193
+ font-size: 18px;
194
+ font-weight: 500;
195
+ color: #f1f5f9;
196
+ font-family: 'Courier New', monospace;
197
+ }
198
+
199
+ .value.highlight { color: #00cec9; }
200
+
201
+ .details-content p {
202
+ font-size: 13px;
203
+ line-height: 1.8;
204
+ color: #cbd5e1;
205
+ }
206
+
207
+ /* =========================================
208
+ 4. Main Content & Map Area
209
+ ========================================= */
210
+ .main-content {
211
+ flex: 1;
212
+ height: 100%;
213
+ position: relative;
214
+ overflow: hidden;
215
+ }
216
+
217
+ #map { width: 100%; height: 100%; }
218
+
219
+ /* Mapbox Popup Overrides */
220
+ .mapboxgl-popup-content {
221
+ background: rgba(15, 23, 42, 0.95) !important;
222
+ border: 1px solid rgba(0, 206, 201, 0.5);
223
+ box-shadow: 0 0 15px rgba(0, 206, 201, 0.2);
224
+ padding: 8px 12px !important;
225
+ border-radius: 6px !important;
226
+ color: #e0e0e0;
227
+ min-width: 120px;
228
+ }
229
+ .mapboxgl-popup-tip {
230
+ border-top-color: rgba(0, 206, 201, 0.5) !important;
231
+ margin-bottom: -1px;
232
+ }
233
+
234
+ /* =========================================
235
+ 5. Top-Left Controls
236
+ ========================================= */
237
+ .controls-container {
238
+ position: absolute;
239
+ top: 20px;
240
+ left: 20px;
241
+ z-index: 100;
242
+ display: flex;
243
+ gap: 10px;
244
+ }
245
+
246
+ .cyber-btn {
247
+ background: rgba(15, 23, 42, 0.9);
248
+ color: #00cec9;
249
+ border: 1px solid #00cec9;
250
+ padding: 10px 20px;
251
+ font-size: 14px;
252
+ font-weight: 600;
253
+ border-radius: 4px;
254
+ cursor: pointer;
255
+ box-shadow: 0 0 10px rgba(0, 206, 201, 0.2);
256
+ transition: all 0.3s ease;
257
+ display: flex;
258
+ align-items: center;
259
+ gap: 8px;
260
+ text-transform: uppercase;
261
+ letter-spacing: 1px;
262
+ white-space: nowrap;
263
+ }
264
+
265
+ .cyber-btn:hover {
266
+ background: rgba(0, 206, 201, 0.15);
267
+ box-shadow: 0 0 20px rgba(0, 206, 201, 0.4);
268
+ transform: translateY(-1px);
269
+ }
270
+
271
+ /* Filter Menu */
272
+ .filter-wrapper { position: relative; display: inline-block; }
273
+
274
+ .filter-menu {
275
+ position: absolute;
276
+ top: 50px;
277
+ left: 0;
278
+ width: 170px;
279
+ background: rgba(15, 23, 42, 0.95);
280
+ border: 1px solid #00cec9;
281
+ border-radius: 4px;
282
+ padding: 8px;
283
+ display: flex;
284
+ flex-direction: column;
285
+ gap: 6px;
286
+ opacity: 0;
287
+ visibility: hidden;
288
+ transform: translateY(-10px);
289
+ transition: all 0.3s ease;
290
+ box-shadow: 0 5px 20px rgba(0,0,0,0.5);
291
+ }
292
+
293
+ .filter-menu.active {
294
+ opacity: 1;
295
+ visibility: visible;
296
+ transform: translateY(0);
297
+ }
298
+
299
+ .filter-item {
300
+ display: flex;
301
+ align-items: center;
302
+ gap: 10px;
303
+ padding: 8px 10px;
304
+ border-radius: 4px;
305
+ cursor: pointer;
306
+ transition: background 0.2s;
307
+ font-size: 12px;
308
+ color: #ccc;
309
+ border: 1px solid transparent;
310
+ }
311
+
312
+ .filter-item:hover { background: rgba(255, 255, 255, 0.1); }
313
+
314
+ .filter-item.selected {
315
+ background: rgba(0, 206, 201, 0.2);
316
+ border-color: rgba(0, 206, 201, 0.5);
317
+ color: #fff; font-weight: bold;
318
+ }
319
+
320
+ .color-box { width: 12px; height: 12px; border-radius: 2px; }
321
+
322
+ /* =========================================
323
+ 6. Bottom Time Panel
324
+ ========================================= */
325
+ .time-panel {
326
+ position: absolute;
327
+ bottom: 40px;
328
+ left: 50%;
329
+ transform: translateX(-50%);
330
+ width: 60%;
331
+ min-width: 500px;
332
+ height: 70px;
333
+ background: rgba(15, 23, 42, 0.9);
334
+ border: 1px solid #00cec9;
335
+ box-shadow: 0 0 20px rgba(0, 206, 201, 0.15), inset 0 0 50px rgba(0,0,0,0.6);
336
+ backdrop-filter: blur(10px);
337
+ border-radius: 50px;
338
+ z-index: 100;
339
+ display: flex;
340
+ align-items: center;
341
+ padding: 0 25px;
342
+ gap: 20px;
343
+ }
344
+
345
+ .play-control {
346
+ min-width: 45px;
347
+ height: 45px;
348
+ border-radius: 50%;
349
+ justify-content: center;
350
+ padding: 0;
351
+ font-size: 18px;
352
+ border-width: 2px;
353
+ }
354
+
355
+ .digital-clock {
356
+ font-family: 'Courier New', monospace;
357
+ font-size: 20px;
358
+ font-weight: bold;
359
+ color: #00cec9;
360
+ text-shadow: 0 0 8px rgba(0, 206, 201, 0.8);
361
+ background: rgba(0, 0, 0, 0.4);
362
+ padding: 5px 12px;
363
+ border-radius: 6px;
364
+ border: 1px solid rgba(0, 206, 201, 0.2);
365
+ min-width: 80px;
366
+ text-align: center;
367
+ }
368
+
369
+ .slider-wrapper {
370
+ flex: 1;
371
+ display: flex;
372
+ flex-direction: column;
373
+ justify-content: center;
374
+ position: relative;
375
+ margin-top: -2px;
376
+ }
377
+
378
+ .slider-ticks {
379
+ display: flex;
380
+ justify-content: space-between;
381
+ margin-top: 8px;
382
+ font-size: 10px;
383
+ color: #64748b;
384
+ font-family: monospace;
385
+ padding: 0 2px;
386
+ }
387
+
388
+ #time-slider {
389
+ -webkit-appearance: none;
390
+ width: 100%;
391
+ height: 4px;
392
+ background: rgba(255, 255, 255, 0.1);
393
+ border-radius: 2px;
394
+ outline: none;
395
+ cursor: pointer;
396
+ transition: background 0.3s;
397
+ }
398
+
399
+ #time-slider:hover { background: rgba(255, 255, 255, 0.2); }
400
+
401
+ #time-slider::-webkit-slider-thumb {
402
+ -webkit-appearance: none;
403
+ width: 22px;
404
+ height: 22px;
405
+ border-radius: 50%;
406
+ background: #0f172a;
407
+ border: 2px solid #00cec9;
408
+ box-shadow: 0 0 10px #00cec9;
409
+ margin-top: 0px;
410
+ transition: transform 0.1s;
411
+ }
412
+
413
+ #time-slider::-webkit-slider-thumb:hover {
414
+ transform: scale(1.2);
415
+ background: #00cec9;
416
+ }
417
+
418
+ /* =========================================
419
+ 7. Right Sidebar (Prediction Mode)
420
+ ========================================= */
421
+ /* Container for the sliding panel */
422
+ .sidebar-right {
423
+ position: fixed;
424
+ top: 0;
425
+ right: -450px; /* Hidden by default */
426
+ width: 400px;
427
+ height: 100vh;
428
+ background: rgba(10, 10, 30, 0.95);
429
+ border-left: 1px solid #f39c12;
430
+ backdrop-filter: blur(10px);
431
+ transition: right 0.3s cubic-bezier(0.4, 0, 0.2, 1);
432
+ z-index: 1000;
433
+ padding: 25px;
434
+ color: #fff;
435
+ box-shadow: -10px 0 30px rgba(0,0,0,0.5);
436
+ display: flex;
437
+ flex-direction: column;
438
+ overflow-y: auto;
439
+ }
440
+ /* Active state (Slid in) */
441
+ .sidebar-right.active {
442
+ right: 0;
443
+ }
444
+
445
+ /* Custom Cyber-Scrollbar for Prediction Sidebar */
446
+ .sidebar-right::-webkit-scrollbar {
447
+ width: 6px;
448
+ }
449
+ .sidebar-right::-webkit-scrollbar-track {
450
+ background: rgba(0, 0, 0, 0.3);
451
+ }
452
+ .sidebar-right::-webkit-scrollbar-thumb {
453
+ background: #f39c12;
454
+ border-radius: 3px;
455
+ }
456
+ .sidebar-right::-webkit-scrollbar-thumb:hover {
457
+ background: #e67e22;
458
+ }
459
+
460
+ /* Header within right sidebar */
461
+ .sidebar-right .header {
462
+ display: flex;
463
+ justify-content: space-between;
464
+ align-items: center;
465
+ border-bottom: 1px solid rgba(243, 156, 18, 0.3);
466
+ padding-bottom: 15px;
467
+ margin-bottom: 20px;
468
+ }
469
+
470
+ .sidebar-right h1 {
471
+ font-size: 20px;
472
+ background: linear-gradient(to right, #f39c12, #f1c40f);
473
+ -webkit-background-clip: text;
474
+ -webkit-text-fill-color: transparent;
475
+ text-transform: uppercase;
476
+ letter-spacing: 1px;
477
+ }
478
+
479
+ /* Close button for right sidebar */
480
+ #close-pred-btn {
481
+ border-color: #666;
482
+ color: #aaa;
483
+ background: transparent;
484
+ }
485
+ #close-pred-btn:hover {
486
+ border-color: #fff;
487
+ color: #fff;
488
+ background: rgba(255,255,255,0.1);
489
+ }
490
+
491
+ /* Prediction specific button states */
492
+ #predict-toggle.predict-on {
493
+ background: rgba(243, 156, 18, 0.2);
494
+ box-shadow: 0 0 20px rgba(243, 156, 18, 0.4);
495
+ border-color: #f39c12;
496
+ color: #f39c12;
497
+ }
498
+
499
+ /* Legend items in prediction panel */
500
+ .legend-box {
501
+ margin-top: auto; /* Pushes the legend to the bottom of the flex column when extra space exists */
502
+ border: 1px solid rgba(255,255,255,0.1);
503
+ }
504
+
505
+ .dot {
506
+ width: 10px;
507
+ height: 10px;
508
+ display: inline-block;
509
+ border-radius: 50%;
510
+ margin-right: 8px;
511
+ }
512
+
513
+
514
+ /* =========================================
515
+ 8. Map Custom Pins & AI Log Visuals
516
+ ========================================= */
517
+ .optimal-pulse-pin {
518
+ width: 20px;
519
+ height: 20px;
520
+ background-color: #2ecc71;
521
+ border-radius: 50%;
522
+ border: 3px solid #ffffff;
523
+ box-shadow: 0 0 15px #2ecc71, 0 0 30px #2ecc71;
524
+ animation: optimal-pulse 1.5s infinite cubic-bezier(0.66, 0, 0, 1);
525
+ cursor: pointer;
526
+ }
527
+
528
+ @keyframes optimal-pulse {
529
+ to {
530
+ box-shadow: 0 0 0 20px rgba(46, 204, 113, 0);
531
+ background-color: rgba(46, 204, 113, 0.8);
532
+ }
533
+ }
534
+
535
+ .cyber-explanation {
536
+ margin-top: 12px;
537
+ padding: 12px;
538
+ background: rgba(0, 206, 201, 0.05);
539
+ border-left: 3px solid #00cec9;
540
+ border-radius: 0 4px 4px 0;
541
+ font-size: 11px;
542
+ color: #a29bfe;
543
+ font-family: 'Courier New', Courier, monospace;
544
+ line-height: 1.6;
545
+ box-shadow: inset 0 0 10px rgba(0, 206, 201, 0.05);
546
+ }
547
+
548
+ .cyber-explanation strong {
549
+ color: #00cec9;
550
+ font-weight: bold;
551
+ }
552
+
553
+ /* =========================================
554
+ 9. Panel Toggle Buttons & Navigation UI
555
+ ========================================= */
556
+ .sidebar {
557
+ transition: margin-left 0.3s cubic-bezier(0.4, 0, 0.2, 1);
558
+ position: relative;
559
+ }
560
+ .sidebar.collapsed {
561
+ margin-left: -380px;
562
+ }
563
+ .sidebar-right.active.collapsed {
564
+ right: -400px;
565
+ }
566
+
567
+ /* Sticky toggle buttons on panel edges */
568
+ .panel-toggle-btn {
569
+ position: absolute;
570
+ top: 50%;
571
+ transform: translateY(-50%);
572
+ width: 22px;
573
+ height: 60px;
574
+ background: rgba(15, 23, 42, 0.9);
575
+ border: 1px solid #00cec9;
576
+ color: #00cec9;
577
+ cursor: pointer;
578
+ z-index: 1000;
579
+ display: flex;
580
+ align-items: center;
581
+ justify-content: center;
582
+ font-size: 10px;
583
+ box-shadow: 0 0 10px rgba(0, 206, 201, 0.2);
584
+ backdrop-filter: blur(5px);
585
+ transition: all 0.2s ease;
586
+ outline: none;
587
+ }
588
+
589
+ .panel-toggle-btn:hover {
590
+ background: #00cec9;
591
+ color: #000;
592
+ box-shadow: 0 0 15px rgba(0, 206, 201, 0.5);
593
+ }
594
+
595
+ /* Positioning logic for sidebars toggles */
596
+ .left-toggle {
597
+ left: 380px; /* Aligned with sidebar width */
598
+ border-left: none;
599
+ border-radius: 0 6px 6px 0;
600
+ transition: left 0.3s cubic-bezier(0.4, 0, 0.2, 1);
601
+ }
602
+ .left-toggle.collapsed {
603
+ left: 0;
604
+ }
605
+
606
+ .right-toggle {
607
+ right: -22px; /* Starts hidden */
608
+ border-right: none;
609
+ border-radius: 6px 0 0 6px;
610
+ transition: right 0.3s cubic-bezier(0.4, 0, 0.2, 1);
611
+ }
612
+ .right-toggle.active {
613
+ right: 400px;
614
+ }
615
+ .right-toggle.collapsed {
616
+ right: 0;
617
+ }