Sadjad Alikhani commited on
Commit
265d187
·
1 Parent(s): 80a230c

Add data dir and docs

Browse files
.gitignore CHANGED
@@ -38,7 +38,6 @@ MANIFEST
38
  *.h5
39
  *.hdf5
40
  cache/
41
- data/
42
  !examples/data/
43
  !examples/data/*.p
44
  !examples/data/README.md
 
38
  *.h5
39
  *.hdf5
40
  cache/
 
41
  !examples/data/
42
  !examples/data/*.p
43
  !examples/data/README.md
LWMTemporal/data/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data utilities for the LWM foundation package."""
2
+
3
+ from .angle_delay import AngleDelayConfig, AngleDelayProcessor
4
+ from .datasets import AngleDelaySequenceDataset, AngleDelayDatasetConfig, load_adseq_dataset
5
+ from .scenario_generation import (
6
+ AntennaArrayConfig,
7
+ DynamicScenarioGenerator,
8
+ GridConfig,
9
+ ScenarioGenerationConfig,
10
+ ScenarioGenerationResult,
11
+ ScenarioSamplingConfig,
12
+ TrafficConfig,
13
+ generate_dynamic_scenario,
14
+ generate_dynamic_scenario_dataset,
15
+ )
16
+
17
+ __all__ = [
18
+ "AngleDelayConfig",
19
+ "AngleDelayProcessor",
20
+ "AngleDelaySequenceDataset",
21
+ "AngleDelayDatasetConfig",
22
+ "load_adseq_dataset",
23
+ "AntennaArrayConfig",
24
+ "TrafficConfig",
25
+ "GridConfig",
26
+ "ScenarioSamplingConfig",
27
+ "ScenarioGenerationConfig",
28
+ "ScenarioGenerationResult",
29
+ "DynamicScenarioGenerator",
30
+ "generate_dynamic_scenario_dataset",
31
+ "generate_dynamic_scenario",
32
+ ]
LWMTemporal/data/angle_delay.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import math
5
+ from pathlib import Path
6
+ from typing import Iterable, List, Optional, Sequence
7
+
8
+ import imageio.v2 as imageio
9
+ import matplotlib.animation as animation
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+
14
+
15
+ @dataclasses.dataclass
16
+ class AngleDelayConfig:
17
+ """Configuration options for angle-delay processing."""
18
+
19
+ angle_range: tuple[float, float] = (-math.pi / 2, math.pi / 2)
20
+ delay_range: tuple[float, float] = (0.0, 100.0)
21
+ keep_percentage: float = 0.25
22
+ fps: int = 4
23
+ dpi: int = 120
24
+ num_bins: int = 6
25
+ output_dir: Path = Path("figs")
26
+
27
+ def validate(self) -> None:
28
+ if not 0.0 < self.keep_percentage <= 1.0:
29
+ raise ValueError("keep_percentage must be in (0, 1]")
30
+ if self.fps <= 0:
31
+ raise ValueError("fps must be positive")
32
+ if self.dpi <= 0:
33
+ raise ValueError("dpi must be positive")
34
+ if self.num_bins <= 0:
35
+ raise ValueError("num_bins must be positive")
36
+
37
+
38
+ class AngleDelayProcessor:
39
+ """Project complex channels into the angle-delay domain and visualise them."""
40
+
41
+ def __init__(self, config: AngleDelayConfig | None = None) -> None:
42
+ self.config = config or AngleDelayConfig()
43
+ self.config.validate()
44
+
45
+ # ------------------------------------------------------------------
46
+ # Core transforms
47
+ # ------------------------------------------------------------------
48
+ @staticmethod
49
+ def _ensure_complex(tensor: torch.Tensor) -> torch.Tensor:
50
+ if not torch.is_complex(tensor):
51
+ raise TypeError("expected complex tensor")
52
+ return tensor
53
+
54
+ def forward(self, channel: torch.Tensor) -> torch.Tensor:
55
+ channel = self._ensure_complex(channel)
56
+ angle_domain = torch.fft.fft(channel, dim=1, norm="ortho")
57
+ delay_domain = torch.fft.ifft(angle_domain, dim=2, norm="ortho")
58
+ return delay_domain
59
+
60
+ def inverse(self, angle_delay: torch.Tensor) -> torch.Tensor:
61
+ angle_delay = self._ensure_complex(angle_delay)
62
+ subcarrier = torch.fft.fft(angle_delay, dim=2, norm="ortho")
63
+ antenna = torch.fft.ifft(subcarrier, dim=1, norm="ortho")
64
+ return antenna
65
+
66
+ # ------------------------------------------------------------------
67
+ # Truncation helpers and metrics
68
+ # ------------------------------------------------------------------
69
+ def truncate_delay_bins(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
70
+ tensor = self._ensure_complex(tensor)
71
+ if tensor.ndim != 3:
72
+ raise ValueError("angle-delay tensor must have shape (T, N, M)")
73
+ keep = max(1, int(round(tensor.size(-1) * self.config.keep_percentage)))
74
+ truncated = tensor[..., :keep]
75
+ padded = torch.zeros_like(tensor)
76
+ padded[..., :keep] = truncated
77
+ return truncated, padded
78
+
79
+ @staticmethod
80
+ def nmse(reference: torch.Tensor, reconstruction: torch.Tensor) -> float:
81
+ reference = AngleDelayProcessor._ensure_complex(reference)
82
+ reconstruction = AngleDelayProcessor._ensure_complex(reconstruction)
83
+ mse = torch.mean(torch.abs(reference - reconstruction) ** 2)
84
+ power = torch.mean(torch.abs(reference) ** 2).clamp_min(1e-12)
85
+ return float(10.0 * torch.log10(mse / power))
86
+
87
+ def reconstruction_nmse(self, channel: torch.Tensor) -> tuple[float, float]:
88
+ ad_full = self.forward(channel)
89
+ recon_full = self.inverse(ad_full)
90
+ nmse_full = self.nmse(channel, recon_full)
91
+ truncated, padded = self.truncate_delay_bins(ad_full)
92
+ recon_trunc = self.inverse(padded)
93
+ nmse_trunc = self.nmse(channel, recon_trunc)
94
+ return nmse_full, nmse_trunc
95
+
96
+ # ------------------------------------------------------------------
97
+ # Visualisation helpers
98
+ # ------------------------------------------------------------------
99
+ def save_angle_delay_gif(self, tensor: torch.Tensor, output_path: Path, fps: Optional[int] = None) -> None:
100
+ tensor = self._ensure_complex(tensor)
101
+ output_path = Path(output_path)
102
+ output_path.parent.mkdir(parents=True, exist_ok=True)
103
+
104
+ magnitude = tensor.abs().cpu()
105
+ vmin, vmax = float(magnitude.min()), float(magnitude.max())
106
+ frames: List[np.ndarray] = []
107
+ for frame_idx in range(magnitude.size(0)):
108
+ fig, ax = plt.subplots(figsize=(8, 6))
109
+ im = ax.imshow(
110
+ magnitude[frame_idx].numpy(),
111
+ cmap="gray_r",
112
+ origin="lower",
113
+ aspect="auto",
114
+ extent=[*self.config.delay_range, *self.config.angle_range],
115
+ vmin=vmin,
116
+ vmax=vmax,
117
+ )
118
+ ax.set_xlabel("Delay (samples)")
119
+ ax.set_ylabel("Angle (radians)")
120
+ ax.set_title(f"Angle-Delay Map – Frame {frame_idx}")
121
+ fig.colorbar(im, ax=ax, label="Power (linear)")
122
+ fig.canvas.draw()
123
+ frames.append(np.asarray(fig.canvas.buffer_rgba()))
124
+ plt.close(fig)
125
+ imageio.mimsave(output_path, frames, fps=fps or self.config.fps)
126
+
127
+ def _save_animation(self, fig: plt.Figure, animate_fn, output_path: Path, fps: Optional[int] = None, dpi: Optional[int] = None) -> None:
128
+ anim = animation.FuncAnimation(fig, animate_fn)
129
+ output_path = Path(output_path)
130
+ output_path.parent.mkdir(parents=True, exist_ok=True)
131
+ anim.save(output_path, writer="pillow", fps=fps or self.config.fps, dpi=dpi or self.config.dpi)
132
+ plt.close(fig)
133
+
134
+ def save_channel_animation(self, channel: torch.Tensor, output_path: Path) -> None:
135
+ channel = self._ensure_complex(channel)
136
+ magnitude = channel.abs().cpu()
137
+ phase = torch.angle(channel).cpu()
138
+ vmin, vmax = float(magnitude.min()), float(magnitude.max())
139
+ temporal_mag = magnitude.mean(dim=(1, 2))
140
+ temporal_phase = np.unwrap(phase.mean(dim=(1, 2)).numpy())
141
+
142
+ fig, (ax_mag, ax_mag_line, ax_phase_line) = plt.subplots(1, 3, figsize=(20, 6))
143
+ mag_img = ax_mag.imshow(magnitude[0].numpy(), cmap="gray_r", origin="upper", aspect="auto", vmin=vmin, vmax=vmax)
144
+ ax_mag.set_xlabel("Subcarrier")
145
+ ax_mag.set_ylabel("Antenna")
146
+ fig.colorbar(mag_img, ax=ax_mag, label="Magnitude")
147
+
148
+ mag_line, = ax_mag_line.plot([], [], "b-o", linewidth=2, markersize=4)
149
+ ax_mag_line.set_xlim(0, magnitude.size(0) - 1)
150
+ ax_mag_line.set_ylim(float(temporal_mag.min()) * 0.9, float(temporal_mag.max()) * 1.1)
151
+ ax_mag_line.grid(True, alpha=0.3)
152
+ ax_mag_line.set_xlabel("Frame")
153
+ ax_mag_line.set_ylabel("Average Magnitude")
154
+
155
+ phase_line, = ax_phase_line.plot([], [], "r-s", linewidth=2, markersize=4)
156
+ ax_phase_line.set_xlim(0, magnitude.size(0) - 1)
157
+ ax_phase_line.set_ylim(float(temporal_phase.min()) - 0.1, float(temporal_phase.max()) + 0.1)
158
+ ax_phase_line.grid(True, alpha=0.3)
159
+ ax_phase_line.set_xlabel("Frame")
160
+ ax_phase_line.set_ylabel("Average Phase (rad)")
161
+
162
+ def animate(idx: int):
163
+ mag_img.set_array(magnitude[idx].numpy())
164
+ ax_mag.set_title(f"Magnitude – Frame {idx}")
165
+ xs = np.arange(idx + 1)
166
+ mag_line.set_data(xs, temporal_mag[: idx + 1].numpy())
167
+ phase_line.set_data(xs, temporal_phase[: idx + 1])
168
+ return mag_img, mag_line, phase_line
169
+
170
+ self._save_animation(fig, animate, output_path)
171
+
172
+ def save_angle_delay_animation(self, tensor: torch.Tensor, output_path: Path, keep_percentage: Optional[float] = None) -> None:
173
+ tensor = self._ensure_complex(tensor)
174
+ magnitude = tensor.abs().cpu()
175
+ phase = torch.angle(tensor).cpu()
176
+ keep_suffix = "" if keep_percentage is None else f" (keep={keep_percentage * 100:.0f}%)"
177
+
178
+ fig, axes = plt.subplots(2, 2, figsize=(18, 10))
179
+ mag_ax, phase_ax, mag_line_ax, phase_line_ax = axes.flat
180
+ mag_img = mag_ax.imshow(magnitude[0].numpy(), cmap="gray_r", origin="upper", aspect="auto")
181
+ mag_ax.set_xlabel("Delay Bin")
182
+ mag_ax.set_ylabel("Angle Bin")
183
+ fig.colorbar(mag_img, ax=mag_ax, label="Magnitude")
184
+
185
+ phase_img = phase_ax.imshow(phase[0].numpy(), cmap="twilight", origin="upper", aspect="auto", vmin=-math.pi, vmax=math.pi)
186
+ phase_ax.set_xlabel("Delay Bin")
187
+ phase_ax.set_ylabel("Angle Bin")
188
+ fig.colorbar(phase_img, ax=phase_ax, label="Phase (rad)")
189
+
190
+ temporal_mag = magnitude.mean(dim=(1, 2))
191
+ temporal_phase = np.unwrap(phase.mean(dim=(1, 2)).numpy())
192
+ mag_line, = mag_line_ax.plot([], [], "r-o", linewidth=2)
193
+ phase_line, = phase_line_ax.plot([], [], "b-s", linewidth=2)
194
+
195
+ for axis, label in ((mag_line_ax, "Average Magnitude"), (phase_line_ax, "Average Phase (rad)")):
196
+ axis.set_xlabel("Frame")
197
+ axis.set_ylabel(label)
198
+ axis.set_xlim(0, tensor.size(0) - 1)
199
+ axis.grid(True, alpha=0.3)
200
+
201
+ def animate(idx: int):
202
+ mag_img.set_array(magnitude[idx].numpy())
203
+ phase_img.set_array(phase[idx].numpy())
204
+ mag_ax.set_title(f"AD Magnitude – Frame {idx}{keep_suffix}")
205
+ phase_ax.set_title(f"AD Phase – Frame {idx}{keep_suffix}")
206
+ xs = np.arange(idx + 1)
207
+ mag_line.set_data(xs, temporal_mag[: idx + 1].numpy())
208
+ phase_line.set_data(xs, temporal_phase[: idx + 1])
209
+ return mag_img, phase_img, mag_line, phase_line
210
+
211
+ self._save_animation(fig, animate, output_path)
212
+
213
+ def save_dominant_bin_animation(self, tensor: torch.Tensor, output_path: Path, threshold_ratio: float = 0.05) -> None:
214
+ tensor = self._ensure_complex(tensor)
215
+ magnitude = tensor.abs().cpu()
216
+ threshold = float(magnitude.max()) * threshold_ratio
217
+ dominant_counts = (magnitude > threshold).sum(dim=(1, 2)).numpy()
218
+
219
+ fig, (heat_ax, line_ax) = plt.subplots(1, 2, figsize=(16, 6))
220
+ heat_img = heat_ax.imshow(magnitude[0].numpy(), cmap="gray_r", origin="upper", aspect="auto")
221
+ heat_ax.set_xlabel("Delay Bin")
222
+ heat_ax.set_ylabel("Angle Bin")
223
+ fig.colorbar(heat_img, ax=heat_ax, label="Magnitude")
224
+
225
+ count_line, = line_ax.plot([], [], "r-s", linewidth=2)
226
+ line_ax.set_xlabel("Frame")
227
+ line_ax.set_ylabel("Dominant Bin Count")
228
+ line_ax.set_xlim(0, tensor.size(0) - 1)
229
+ line_ax.set_ylim(0, dominant_counts.max() * 1.1)
230
+ line_ax.grid(True, alpha=0.3)
231
+
232
+ def animate(idx: int):
233
+ heat_img.set_array(magnitude[idx].numpy())
234
+ heat_ax.set_title(f"Magnitude – Frame {idx}")
235
+ xs = np.arange(idx + 1)
236
+ count_line.set_data(xs, dominant_counts[: idx + 1])
237
+ return heat_img, count_line
238
+
239
+ self._save_animation(fig, animate, output_path)
240
+
241
+ def save_bin_evolution_plot(self, tensor: torch.Tensor, output_path: Path) -> None:
242
+ tensor = self._ensure_complex(tensor)
243
+ magnitude = tensor.abs()
244
+ avg_mag = magnitude.mean(dim=0)
245
+ flat_mag = avg_mag.flatten()
246
+ k = min(self.config.num_bins, flat_mag.numel())
247
+ _, indices = torch.topk(flat_mag, k)
248
+ angle_indices = (indices // tensor.size(-1)).tolist()
249
+ delay_indices = (indices % tensor.size(-1)).tolist()
250
+
251
+ time_axis = np.arange(tensor.size(0))
252
+ fig, axes = plt.subplots(k, 2, figsize=(12, 3 * k), squeeze=False)
253
+ for row in range(k):
254
+ series = tensor[:, angle_indices[row], delay_indices[row]]
255
+ mag_series = torch.abs(series).cpu().numpy()
256
+ phase_series = np.unwrap(torch.angle(series).cpu().numpy())
257
+ ax_mag, ax_phase = axes[row]
258
+ ax_mag.plot(time_axis, mag_series, "b-", linewidth=2)
259
+ ax_mag.set_title(f"Bin (angle={angle_indices[row]}, delay={delay_indices[row]}) – Magnitude")
260
+ ax_mag.set_xlabel("Frame")
261
+ ax_mag.set_ylabel("Magnitude")
262
+ ax_mag.grid(True, alpha=0.3)
263
+ ax_phase.plot(time_axis, phase_series, "r-", linewidth=2)
264
+ ax_phase.set_title(f"Bin (angle={angle_indices[row]}, delay={delay_indices[row]}) – Phase")
265
+ ax_phase.set_xlabel("Frame")
266
+ ax_phase.set_ylabel("Phase (rad)")
267
+ ax_phase.grid(True, alpha=0.3)
268
+ fig.tight_layout()
269
+ output_path = Path(output_path)
270
+ output_path.parent.mkdir(parents=True, exist_ok=True)
271
+ fig.savefig(output_path, dpi=self.config.dpi, bbox_inches="tight")
272
+ plt.close(fig)
LWMTemporal/data/datasets.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import math
5
+ import pickle
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch.utils.data import Dataset
12
+
13
+ from .angle_delay import AngleDelayConfig, AngleDelayProcessor
14
+ from ..models.lwm import ComplexPatchTokenizer
15
+
16
+
17
+ @dataclasses.dataclass
18
+ class AngleDelayDatasetConfig:
19
+ raw_path: Path
20
+ keep_percentage: float = 0.25
21
+ normalize: str = "global_rms"
22
+ cache_dir: Optional[Path] = Path("cache")
23
+ use_cache: bool = True
24
+ overwrite_cache: bool = False
25
+ snr_db: Optional[float] = None
26
+ noise_seed: Optional[int] = None
27
+ max_time_steps: Optional[int] = None
28
+ patch_size: Tuple[int, int] = (1, 1)
29
+ phase_mode: str = "real_imag"
30
+
31
+
32
+ class AngleDelaySequenceDataset(Dataset):
33
+ """Angle-delay dataset with optional caching and metadata retention."""
34
+
35
+ def __init__(self, config: AngleDelayDatasetConfig, logger: Optional[Any] = None) -> None:
36
+ super().__init__()
37
+ self.config = config
38
+ self.logger = logger
39
+ self.tokenizer = ComplexPatchTokenizer(config.phase_mode)
40
+ self.samples: List[torch.Tensor]
41
+ self.avg_speed_mps: List[Optional[float]]
42
+ cache_hit = False
43
+ cache_path = self._cache_path() if config.use_cache and config.cache_dir is not None else None
44
+ if cache_path and cache_path.exists() and not config.overwrite_cache:
45
+ try:
46
+ payload = torch.load(cache_path, map_location="cpu")
47
+ if isinstance(payload, dict) and "samples" in payload:
48
+ self.samples = payload["samples"]
49
+ self.avg_speed_mps = payload.get("avg_speed_mps", [None] * len(self.samples))
50
+ else:
51
+ self.samples = payload
52
+ self.avg_speed_mps = [None] * len(self.samples)
53
+ cache_hit = True
54
+ except Exception:
55
+ cache_path.unlink(missing_ok=True)
56
+ cache_hit = False
57
+ if not cache_hit:
58
+ self.samples, self.avg_speed_mps = self._build_samples()
59
+ if cache_path is not None:
60
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
61
+ torch.save({"samples": self.samples, "avg_speed_mps": self.avg_speed_mps}, cache_path)
62
+ if self.config.snr_db is not None:
63
+ self._apply_noise()
64
+
65
+ def _cache_path(self) -> Path:
66
+ cfg = self.config
67
+ name = cfg.raw_path.stem
68
+ # Include patch_size and phase_mode in cache name to ensure cache invalidation
69
+ # when these parameters change. Also add 'v2' to invalidate old caches with wrong normalization.
70
+ ph, pw = cfg.patch_size
71
+ cache_name = f"adseq_{name}_keep{int(cfg.keep_percentage * 100)}_{cfg.normalize}_p{ph}x{pw}_{cfg.phase_mode}_v2.pt"
72
+ return cfg.cache_dir / cache_name # type: ignore[operator]
73
+
74
+ def _load_raw(self) -> Any:
75
+ with self.config.raw_path.open("rb") as handle:
76
+ return pickle.load(handle)
77
+
78
+ def _normalize_sample(self, tensor: torch.Tensor) -> torch.Tensor:
79
+ """Normalize a single sample by its own RMS."""
80
+ rms = torch.sqrt((tensor.real.float() ** 2 + tensor.imag.float() ** 2).mean()).clamp_min(1e-8)
81
+ return tensor / rms.to(tensor.dtype)
82
+
83
+ def _estimate_speed(self, pos_meta: Any, dt_meta: Any, index: int) -> Optional[float]:
84
+ if pos_meta is None or dt_meta is None:
85
+ return None
86
+
87
+ def _extract_dt(meta: Any) -> Optional[float]:
88
+ if isinstance(meta, (list, tuple)) and index < len(meta):
89
+ return float(meta[index])
90
+ if isinstance(meta, np.ndarray) and meta.ndim == 1 and index < meta.shape[0]:
91
+ return float(meta[index])
92
+ if isinstance(meta, (int, float)):
93
+ return float(meta)
94
+ return None
95
+
96
+ dt = _extract_dt(dt_meta)
97
+ if dt is None or dt <= 0:
98
+ return None
99
+
100
+ def _extract_positions(meta: Any) -> Optional[List[np.ndarray]]:
101
+ if isinstance(meta, (list, tuple)):
102
+ if index >= len(meta):
103
+ return None
104
+ candidate = meta[index]
105
+ elif isinstance(meta, np.ndarray):
106
+ if meta.ndim < 2 or index >= meta.shape[0]:
107
+ return None
108
+ candidate = meta[index]
109
+ else:
110
+ return None
111
+ if isinstance(candidate, list):
112
+ trajs = [np.asarray(traj) for traj in candidate if isinstance(traj, (list, tuple, np.ndarray))]
113
+ else:
114
+ trajs = [np.asarray(candidate)]
115
+ usable = [traj for traj in trajs if traj.ndim >= 2 and traj.shape[0] >= 2]
116
+ return usable if usable else None
117
+
118
+ trajectories = _extract_positions(pos_meta)
119
+ if not trajectories:
120
+ return None
121
+
122
+ speeds: List[float] = []
123
+ for traj in trajectories:
124
+ diffs = np.linalg.norm(np.diff(traj[:, :2], axis=0), axis=1)
125
+ if diffs.size == 0:
126
+ continue
127
+ velocity = diffs / dt
128
+ finite = velocity[np.isfinite(velocity)]
129
+ if finite.size > 0:
130
+ speeds.append(float(finite.mean()))
131
+ if not speeds:
132
+ return None
133
+ return float(np.mean(speeds))
134
+
135
+ def _build_samples(self) -> tuple[List[torch.Tensor], List[Optional[float]]]:
136
+ payload = self._load_raw()
137
+ pos_meta = payload.get("pos") if isinstance(payload, dict) else None
138
+ dt_meta = payload.get("dt") if isinstance(payload, dict) else None
139
+ channel = payload["channel"] if isinstance(payload, dict) and "channel" in payload else payload
140
+ channel_tensor = torch.as_tensor(channel, dtype=torch.complex64)
141
+ if channel_tensor.ndim == 3:
142
+ channel_tensor = channel_tensor.unsqueeze(0)
143
+ if self.config.max_time_steps is not None and channel_tensor.size(1) > self.config.max_time_steps:
144
+ channel_tensor = channel_tensor[:, : self.config.max_time_steps]
145
+ processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=self.config.keep_percentage))
146
+ samples: List[torch.Tensor] = []
147
+ avg_speeds: List[Optional[float]] = []
148
+ for idx, seq in enumerate(channel_tensor):
149
+ ad = processor.forward(seq)
150
+ truncated, _ = processor.truncate_delay_bins(ad)
151
+ samples.append(truncated)
152
+ avg_speeds.append(self._estimate_speed(pos_meta, dt_meta, idx))
153
+
154
+ # Apply normalization after collecting all samples
155
+ if self.config.normalize == "per_sample_rms":
156
+ samples = [self._normalize_sample(s) for s in samples]
157
+ elif self.config.normalize == "global_rms":
158
+ # Compute global RMS across all samples
159
+ total_sum_sq = 0.0
160
+ total_count = 0
161
+ for s in samples:
162
+ s_real = s.real.float()
163
+ s_imag = s.imag.float()
164
+ total_sum_sq += (s_real ** 2 + s_imag ** 2).sum().item()
165
+ total_count += s_real.numel()
166
+ if total_count > 0:
167
+ global_rms = math.sqrt(total_sum_sq / total_count)
168
+ global_rms = max(global_rms, 1e-8)
169
+ samples = [s / torch.tensor(global_rms, dtype=torch.float32).to(s.dtype) for s in samples]
170
+
171
+ return samples, avg_speeds
172
+
173
+ def _apply_noise(self) -> None:
174
+ if self.config.noise_seed is not None:
175
+ torch.manual_seed(int(self.config.noise_seed))
176
+ noisy: List[torch.Tensor] = []
177
+ snr_lin = 10.0 ** (float(self.config.snr_db) / 10.0)
178
+ for sample in self.samples:
179
+ real = sample.real.float()
180
+ imag = sample.imag.float()
181
+ power = (real.square() + imag.square()).mean().item()
182
+ if power <= 0:
183
+ noisy.append(sample)
184
+ continue
185
+ noise_var = power / snr_lin
186
+ std = math.sqrt(noise_var / 2.0)
187
+ noise_real = torch.randn_like(real) * std
188
+ noise_imag = torch.randn_like(imag) * std
189
+ noise = torch.complex(noise_real.to(sample.dtype), noise_imag.to(sample.dtype))
190
+ noisy.append((sample + noise).to(sample.dtype))
191
+ self.samples = noisy
192
+
193
+ def __len__(self) -> int:
194
+ return len(self.samples)
195
+
196
+ def __getitem__(self, index: int) -> Dict[str, Any]:
197
+ sample = self.samples[index]
198
+ tokens, base_mask = self.tokenizer(sample.unsqueeze(0), self.config.patch_size)
199
+ tokens = tokens.squeeze(0)
200
+ base_mask = base_mask.squeeze(0)
201
+ T, N, M = sample.shape
202
+ ph, pw = self.config.patch_size
203
+ H = N // ph
204
+ W = M // pw
205
+ shape = torch.tensor([T, H, W], dtype=torch.long)
206
+ avg_speed = self.avg_speed_mps[index] if index < len(self.avg_speed_mps) else None
207
+ payload: Dict[str, Any] = {
208
+ "sequence": sample,
209
+ "tokens": tokens,
210
+ "base_mask": base_mask,
211
+ "shape": shape,
212
+ }
213
+ if avg_speed is not None:
214
+ payload["avg_speed"] = torch.tensor(avg_speed, dtype=torch.float32)
215
+ return payload
216
+
217
+
218
+ def load_adseq_dataset(
219
+ data_path: str | Path,
220
+ keep_percentage: float = 0.25,
221
+ normalize: str = "global_rms",
222
+ cache_dir: Optional[str | Path] = "cache",
223
+ use_cache: bool = True,
224
+ overwrite_cache: bool = False,
225
+ logger: Optional[Any] = None,
226
+ snr_db: Optional[float] = None,
227
+ noise_seed: Optional[int] = None,
228
+ max_time_steps: Optional[int] = None,
229
+ ) -> AngleDelaySequenceDataset:
230
+ cfg = AngleDelayDatasetConfig(
231
+ raw_path=Path(data_path),
232
+ keep_percentage=keep_percentage,
233
+ normalize=normalize,
234
+ cache_dir=None if cache_dir is None else Path(cache_dir),
235
+ use_cache=use_cache,
236
+ overwrite_cache=overwrite_cache,
237
+ snr_db=snr_db,
238
+ noise_seed=noise_seed,
239
+ max_time_steps=max_time_steps,
240
+ )
241
+ return AngleDelaySequenceDataset(cfg, logger=logger)
LWMTemporal/data/deepmimo_adapter.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
7
+
8
+ import numpy as np # type: ignore[import]
9
+
10
+ _LOGGER = logging.getLogger(__name__)
11
+
12
+ try: # pragma: no cover - optional dependency
13
+ import deepmimo # type: ignore[import]
14
+ from deepmimo import config as deepmimo_config # type: ignore[import]
15
+
16
+ _HAS_DEEPMIMO_V4 = True
17
+ except Exception: # pragma: no cover - DeepMIMO v4 not installed
18
+ deepmimo = None # type: ignore[assignment]
19
+ deepmimo_config = None # type: ignore[assignment]
20
+ _HAS_DEEPMIMO_V4 = False
21
+
22
+ try: # pragma: no cover - legacy fallback
23
+ from input_preprocess import DeepMIMO_data_gen as _legacy_data_gen # type: ignore[import]
24
+ except Exception: # pragma: no cover - legacy loader unavailable
25
+ _legacy_data_gen = None
26
+
27
+ ArrayLike = Union[np.ndarray, "np.typing.NDArray[np.floating[Any]]"]
28
+
29
+
30
+ @dataclass
31
+ class _PathTable:
32
+ power: np.ndarray
33
+ phase: np.ndarray
34
+ delay: np.ndarray
35
+ aoa_az: np.ndarray
36
+ aoa_el: np.ndarray
37
+ aod_az: np.ndarray
38
+ aod_el: np.ndarray
39
+ interactions: np.ndarray
40
+ num_paths: np.ndarray
41
+ los_user: np.ndarray
42
+ locations: np.ndarray
43
+
44
+
45
+ class _LazyPathAccessor:
46
+ """Lazy view over per-user path dictionaries compatible with v3 interface."""
47
+
48
+ def __init__(self, data: _PathTable) -> None:
49
+ self._data = data
50
+
51
+ def __len__(self) -> int:
52
+ return int(self._data.num_paths.shape[0])
53
+
54
+ def __getitem__(self, index: Union[int, slice, Sequence[int]]) -> Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]]:
55
+ if isinstance(index, slice):
56
+ return [self[i] for i in range(*index.indices(len(self)))]
57
+ if isinstance(index, Sequence) and not isinstance(index, (str, bytes)):
58
+ return [self[int(i)] for i in index]
59
+ idx = int(index)
60
+ count = int(self._data.num_paths[idx])
61
+ if count <= 0:
62
+ empty = np.empty((0,), dtype=np.float32)
63
+ return {
64
+ "num_paths": 0,
65
+ "DoD_theta": empty,
66
+ "DoD_phi": empty,
67
+ "DoA_theta": empty,
68
+ "DoA_phi": empty,
69
+ "phase": empty,
70
+ "ToA": empty,
71
+ "power": empty,
72
+ "LoS": np.empty((0,), dtype=np.int32),
73
+ }
74
+ sl = slice(0, count)
75
+ interactions = np.asarray(self._data.interactions[idx, sl])
76
+ los_per_path = np.where(np.isnan(interactions), 0, (interactions == 0).astype(np.int32))
77
+ return {
78
+ "num_paths": count,
79
+ "DoD_theta": np.asarray(self._data.aod_el[idx, sl]),
80
+ "DoD_phi": np.asarray(self._data.aod_az[idx, sl]),
81
+ "DoA_theta": np.asarray(self._data.aoa_el[idx, sl]),
82
+ "DoA_phi": np.asarray(self._data.aoa_az[idx, sl]),
83
+ "phase": np.asarray(self._data.phase[idx, sl]),
84
+ "ToA": np.asarray(self._data.delay[idx, sl]),
85
+ "power": np.asarray(self._data.power[idx, sl]),
86
+ "LoS": los_per_path.astype(np.int32),
87
+ }
88
+
89
+
90
+ def _cast(array: ArrayLike, dtype: np.dtype[Any]) -> np.ndarray:
91
+ arr = np.asarray(array)
92
+ if arr.dtype == dtype:
93
+ return arr
94
+ return arr.astype(dtype, copy=True)
95
+
96
+
97
+ def _load_v4_dataset(
98
+ scenario: str,
99
+ *,
100
+ scenarios_dir: Optional[Path],
101
+ load_params: Optional[Dict[str, Any]],
102
+ max_paths: Optional[int],
103
+ array_dtype: np.dtype[Any],
104
+ logger: Optional[logging.Logger],
105
+ ) -> Dict[str, Any]:
106
+ if not _HAS_DEEPMIMO_V4:
107
+ raise RuntimeError("DeepMIMO v4 package is not available in the current environment")
108
+
109
+ if scenarios_dir is not None:
110
+ deepmimo_config.set("scenarios_folder", str(scenarios_dir)) # type: ignore[attr-defined]
111
+
112
+ params = dict(load_params or {})
113
+ if max_paths is not None:
114
+ params.setdefault("max_paths", int(max_paths))
115
+
116
+ dataset = deepmimo.load(scenario, **params) # type: ignore[call-arg]
117
+ logger = logger or _LOGGER
118
+ logger.info(
119
+ "Loaded DeepMIMO v4 scenario '%s' with %s users and %s max paths", # pragma: no cover - logging
120
+ scenario,
121
+ getattr(dataset, "n_ue", "unknown"),
122
+ params.get("max_paths", "default"),
123
+ )
124
+
125
+ num_paths_raw = np.asarray(dataset.num_paths)
126
+ tx_axis: Optional[int] = None
127
+ if num_paths_raw.ndim > 1 and num_paths_raw.shape[0] > 1:
128
+ axes = tuple(range(1, num_paths_raw.ndim))
129
+ scores = num_paths_raw.sum(axis=axes)
130
+ tx_axis = int(np.argmax(scores))
131
+
132
+ def _select_tx(arr: Any, dtype: Optional[np.dtype[Any]] = None) -> np.ndarray:
133
+ out = np.asarray(arr)
134
+ if out.dtype == object:
135
+ out = np.stack([np.asarray(v) for v in out], axis=0)
136
+ if tx_axis is not None and out.ndim >= 1 and out.shape[0] == num_paths_raw.shape[0]:
137
+ out = out[tx_axis]
138
+ if dtype is not None:
139
+ out = out.astype(dtype, copy=False)
140
+ return out
141
+
142
+ num_paths = _select_tx(num_paths_raw, dtype=np.int32).reshape(-1)
143
+ power = _select_tx(dataset.power, dtype=array_dtype)
144
+ phase = _select_tx(dataset.phase, dtype=array_dtype)
145
+ delay = _select_tx(dataset.delay, dtype=array_dtype)
146
+ aoa_az = _select_tx(dataset.aoa_az, dtype=array_dtype)
147
+ aoa_el = _select_tx(dataset.aoa_el, dtype=array_dtype)
148
+ aod_az = _select_tx(dataset.aod_az, dtype=array_dtype)
149
+ aod_el = _select_tx(dataset.aod_el, dtype=array_dtype)
150
+ interactions = _select_tx(dataset.inter, dtype=array_dtype)
151
+ los_raw = getattr(dataset, "los", None)
152
+ if los_raw is None:
153
+ los_selected = np.zeros_like(num_paths, dtype=np.int8)
154
+ else:
155
+ los_selected = _select_tx(los_raw, dtype=np.int8)
156
+ locations = _select_tx(dataset.rx_pos, dtype=np.float32)
157
+ if locations.ndim == 1:
158
+ if locations.size % 3 == 0:
159
+ locations = locations.reshape(-1, 3)
160
+ else:
161
+ locations = locations.reshape(-1, 1)
162
+ if locations.ndim > 2:
163
+ locations = locations.reshape(locations.shape[0], -1)
164
+
165
+ path_table = _PathTable(
166
+ power=power,
167
+ phase=phase,
168
+ delay=delay,
169
+ aoa_az=aoa_az,
170
+ aoa_el=aoa_el,
171
+ aod_az=aod_az,
172
+ aod_el=aod_el,
173
+ interactions=interactions,
174
+ num_paths=num_paths,
175
+ los_user=los_selected.reshape(-1),
176
+ locations=locations,
177
+ )
178
+
179
+ # Help GC release original dataset arrays early
180
+ del dataset
181
+
182
+ user_payload = {
183
+ "paths": _LazyPathAccessor(path_table),
184
+ "LoS": path_table.los_user,
185
+ "location": path_table.locations,
186
+ }
187
+
188
+ return {
189
+ "user": user_payload,
190
+ "_path_data": path_table,
191
+ "_source": "deepmimo_v4",
192
+ }
193
+
194
+
195
+ def load_deepmimo_user_data(
196
+ scenario: str,
197
+ *,
198
+ scenarios_dir: Optional[Path] = None,
199
+ load_params: Optional[Dict[str, Any]] = None,
200
+ max_paths: Optional[int] = None,
201
+ array_dtype: np.dtype[Any] = np.float32,
202
+ logger: Optional[logging.Logger] = None,
203
+ ) -> Dict[str, Any]:
204
+ """Load DeepMIMO scenario data in a form compatible with legacy utilities.
205
+
206
+ The returned dictionary mimics the structure produced by DeepMIMO v3's
207
+ ``DeepMIMO_data_gen`` so downstream utilities (e.g., dynamic scenario
208
+ generation) can operate without modification. When DeepMIMO v4 is not
209
+ available, the function falls back to the legacy generator if present.
210
+ """
211
+
212
+ if _HAS_DEEPMIMO_V4:
213
+ return _load_v4_dataset(
214
+ scenario,
215
+ scenarios_dir=scenarios_dir,
216
+ load_params=load_params,
217
+ max_paths=max_paths,
218
+ array_dtype=array_dtype,
219
+ logger=logger,
220
+ )
221
+
222
+ if _legacy_data_gen is not None:
223
+ raise RuntimeError(
224
+ "DeepMIMO v4 is not installed. The repository still includes the legacy "
225
+ "DeepMIMO_data_gen interface, but integration parameters must be provided "
226
+ "explicitly. Please migrate to the official DeepMIMO package or invoke "
227
+ "DeepMIMO_data_gen directly from your own tooling."
228
+ )
229
+
230
+ raise RuntimeError(
231
+ "Neither DeepMIMO v4 nor the legacy DeepMIMO_data_gen function is available. "
232
+ "Please install the DeepMIMO package or provide the legacy generator."
233
+ )
234
+
235
+
236
+ __all__ = ["load_deepmimo_user_data"]
LWMTemporal/data/scenario_generation.py ADDED
@@ -0,0 +1,1161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import logging
5
+ import warnings
6
+ from dataclasses import dataclass, field
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
9
+
10
+ import matplotlib.pyplot as plt # type: ignore[import]
11
+ from matplotlib.lines import Line2D # type: ignore[attr-defined]
12
+ import networkx as nx # type: ignore[import]
13
+ import numpy as np # type: ignore[import]
14
+ from scipy.spatial import KDTree # type: ignore[import]
15
+
16
+ from .deepmimo_adapter import load_deepmimo_user_data
17
+
18
+ try: # pragma: no cover - DeepMIMO may not be installed in all environments
19
+ import DeepMIMOv3.consts as c # type: ignore
20
+ except Exception: # pragma: no cover - maintain compatibility when DeepMIMO is missing
21
+ class _C: # noqa: D401 - minimal stub to satisfy typing
22
+ """Fallback constants when DeepMIMOv3 is not available."""
23
+
24
+ c = _C()
25
+ c.PARAMSET_OFDM = "OFDM"
26
+ c.PARAMSET_OFDM_BW = "bandwidth"
27
+ c.PARAMSET_OFDM_BW_MULT = 1e9
28
+ c.PARAMSET_OFDM_SC_SAMP = "selected_subcarriers"
29
+ c.PARAMSET_OFDM_SC_NUM = "subcarriers"
30
+ c.PARAMSET_OFDM_LPF = "LPF"
31
+ c.PARAMSET_FDTD = "FDTD"
32
+ c.PARAMSET_ANT_SHAPE = "shape"
33
+ c.PARAMSET_ANT_SPACING = "spacing"
34
+ c.PARAMSET_ANT_ROTATION = "rotation"
35
+ c.PARAMSET_ANT_FOV = "FoV"
36
+ c.PARAMSET_ANT_RAD_PAT = "radiation_pattern"
37
+ c.OUT_PATH_NUM = "num_paths"
38
+ c.OUT_PATH_DOD_THETA = "DoD_theta"
39
+ c.OUT_PATH_DOD_PHI = "DoD_phi"
40
+ c.OUT_PATH_DOA_THETA = "DoA_theta"
41
+ c.OUT_PATH_DOA_PHI = "DoA_phi"
42
+ c.OUT_PATH_PHASE = "phase"
43
+ c.OUT_PATH_TOA = "ToA"
44
+ c.OUT_PATH_RX_POW = "power"
45
+ c.OUT_PATH_DOP_VEL = "Doppler_vel"
46
+ c.OUT_PATH_DOP_ACC = "Doppler_acc"
47
+ c.PARAMSET_DOPPLER_EN = "Doppler"
48
+ c.PARAMSET_SCENARIO_PARAMS = "scenario_params"
49
+ c.PARAMSET_SCENARIO_PARAMS_DOPPLER_EN = "Doppler_enabled"
50
+ c.PARAMSET_SCENARIO_PARAMS_CF = "carrier_freq"
51
+ c.LIGHTSPEED = 3e8
52
+
53
+
54
+ _LOGGER = logging.getLogger(__name__)
55
+
56
+
57
+ @dataclass
58
+ class AntennaArrayConfig:
59
+ """Configuration describing the transmit and receive array geometry."""
60
+
61
+ tx_horizontal: int = 32
62
+ tx_vertical: int = 1
63
+ rx_horizontal: int = 1
64
+ rx_vertical: int = 1
65
+ subcarriers: int = 32
66
+ spacing: float = 0.5
67
+ tx_rotation: Tuple[float, float, float] = (0.0, 0.0, -135.0)
68
+ rx_rotation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
69
+ field_of_view: Tuple[float, float] = (360.0, 180.0)
70
+
71
+ def total_tx_elements(self) -> int:
72
+ return int(self.tx_horizontal * self.tx_vertical)
73
+
74
+ def tx_shape(self) -> np.ndarray:
75
+ return np.asarray([self.tx_horizontal, self.tx_vertical, 1])
76
+
77
+ def rx_shape(self) -> np.ndarray:
78
+ return np.asarray([self.rx_horizontal, self.rx_vertical, 1])
79
+
80
+
81
+ @dataclass
82
+ class TrafficConfig:
83
+ """Parameters controlling vehicle and pedestrian traffic synthesis."""
84
+
85
+ num_vehicles: int = 50
86
+ num_pedestrians: int = 10
87
+ vehicle_speed_range: Tuple[float, float] = (5 / 3.6, 60 / 3.6)
88
+ pedestrian_speed_range: Tuple[float, float] = (0.5, 2.0)
89
+ turn_probability: float = 0.1
90
+ max_attempts: int = 300
91
+ pedestrian_angle_std: float = 0.1
92
+
93
+
94
+ @dataclass
95
+ class GridConfig:
96
+ """Describe the road grid geometry to be inferred from DeepMIMO positions."""
97
+
98
+ road_width: float = 6.0
99
+ road_center_spacing: float = 25.0
100
+ step_size: float = 1.0
101
+ auto_step_size: bool = True
102
+
103
+
104
+ @dataclass
105
+ class ScenarioSamplingConfig:
106
+ """Temporal sampling configuration for the dynamic scenario."""
107
+
108
+ time_steps: int = 20
109
+ continuous_length: Optional[int] = None
110
+ sample_dt: float = 1e-3
111
+ continuous_mode: bool = True
112
+
113
+
114
+ @dataclass
115
+ class ScenarioGenerationConfig:
116
+ """All inputs required to synthesize a dynamic DeepMIMO scenario."""
117
+
118
+ scenario: str
119
+ antenna: AntennaArrayConfig = field(default_factory=AntennaArrayConfig)
120
+ sampling: ScenarioSamplingConfig = field(default_factory=ScenarioSamplingConfig)
121
+ grid: GridConfig = field(default_factory=GridConfig)
122
+ traffic: TrafficConfig = field(default_factory=TrafficConfig)
123
+ carrier_frequency_hz: float = 3.5e9
124
+ output_dir: Path = Path("examples/data")
125
+ full_output_dir: Path = Path("examples/full_data")
126
+ figures_dir: Optional[Path] = Path("figs")
127
+ export_environment_plot: bool = True
128
+ rng_seed: Optional[int] = None
129
+ scenarios_dir: Optional[Path] = None
130
+ deepmimo_max_paths: Optional[int] = 6
131
+ deepmimo_load_params: Optional[Dict[str, Any]] = None
132
+ deepmimo_array_dtype: np.dtype[Any] = np.float32
133
+
134
+
135
+ @dataclasses.dataclass
136
+ class ScenarioGenerationResult:
137
+ """Container for the generated dynamic scenario payload."""
138
+
139
+ payload: Dict[str, Any]
140
+ output_path: Path
141
+ full_output_path: Path
142
+ generated: bool
143
+
144
+
145
+ # ---------------------------------------------------------------------------
146
+ # Geometry helpers
147
+ # ---------------------------------------------------------------------------
148
+
149
+ def infer_grid_step(positions: np.ndarray, max_samples: int = 20000) -> float:
150
+ """Infer the base grid spacing from the DeepMIMO user positions."""
151
+
152
+ unique_xy = np.unique(positions[:, :2], axis=0)
153
+ if unique_xy.shape[0] < 2:
154
+ return 1.0
155
+ if unique_xy.shape[0] > max_samples:
156
+ rng = np.random.default_rng()
157
+ unique_xy = unique_xy[rng.choice(unique_xy.shape[0], size=max_samples, replace=False)]
158
+ tree = KDTree(unique_xy)
159
+ distances, _ = tree.query(unique_xy, k=2)
160
+ nearest = distances[:, 1]
161
+ nearest = nearest[nearest > 0]
162
+ if nearest.size == 0:
163
+ return 1.0
164
+ min_distance = float(np.min(nearest))
165
+ tolerance = min_distance * 0.1
166
+ count_at_min = np.sum(np.abs(nearest - min_distance) < tolerance)
167
+ if count_at_min >= 0.1 * nearest.size:
168
+ return min_distance
169
+ lo, hi = np.percentile(nearest, [5, 30])
170
+ mask = (nearest >= lo) & (nearest <= hi)
171
+ candidate = nearest[mask]
172
+ return float(np.median(candidate if candidate.size else nearest))
173
+
174
+
175
+ def filter_road_positions(valid_positions: np.ndarray, road_width: float, spacing: float) -> Tuple[np.ndarray, Dict[Tuple[float, float, float], Tuple[Tuple[int, int], str]]]:
176
+ road_positions: List[np.ndarray] = []
177
+ lane_info: Dict[Tuple[float, float, float], Tuple[Tuple[int, int], str]] = {}
178
+ half = road_width / 2.0
179
+ for pos in valid_positions:
180
+ x, y, _ = pos
181
+ cx = round(x / spacing) * spacing
182
+ cy = round(y / spacing) * spacing
183
+ dx, dy = x - cx, y - cy
184
+ on_vertical = abs(dx) < half
185
+ on_horizontal = abs(dy) < half
186
+ key = tuple(pos)
187
+ if on_vertical and not on_horizontal:
188
+ direction = (0, 1) if dx >= 0 else (0, -1)
189
+ lane_info[key] = (direction, "vertical")
190
+ road_positions.append(pos)
191
+ elif on_horizontal and not on_vertical:
192
+ direction = (1, 0) if dy < 0 else (-1, 0)
193
+ lane_info[key] = (direction, "horizontal")
194
+ road_positions.append(pos)
195
+ elif on_vertical and on_horizontal:
196
+ direction = (0, 1) if dx >= 0 else (0, -1)
197
+ lane_info[key] = (direction, "intersection")
198
+ road_positions.append(pos)
199
+ return np.asarray(road_positions), lane_info
200
+
201
+
202
+ def create_grid_road_network(road_positions: np.ndarray, lane_info: Dict[Tuple[float, float, float], Tuple[Tuple[int, int], str]], step_size: float) -> nx.DiGraph:
203
+ graph: nx.DiGraph = nx.DiGraph()
204
+ pos_dict = {tuple(pos): idx for idx, pos in enumerate(road_positions)}
205
+ for pos, idx in pos_dict.items():
206
+ if pos not in lane_info:
207
+ continue
208
+ direction, lane_type = lane_info[pos]
209
+ graph.add_node(idx, pos=np.array(pos), direction=direction, lane_type=lane_type)
210
+ tree = KDTree(road_positions)
211
+ for idx, pos in enumerate(road_positions):
212
+ if idx not in graph.nodes:
213
+ continue
214
+ neighbors = tree.query_ball_point(pos, r=step_size + 0.1)
215
+ for nb in neighbors:
216
+ if nb == idx or nb not in graph.nodes:
217
+ continue
218
+ target = road_positions[nb]
219
+ distance = np.linalg.norm(pos - target)
220
+ if not np.isclose(distance, step_size, atol=0.1):
221
+ continue
222
+ move_dir = (int(np.sign(target[0] - pos[0])), int(np.sign(target[1] - pos[1])))
223
+ lane_type = graph.nodes[idx].get("lane_type", "vertical")
224
+ if lane_type == "intersection":
225
+ if move_dir in [(0, 1), (0, -1), (1, 0), (-1, 0)]:
226
+ graph.add_edge(idx, nb, weight=distance)
227
+ elif move_dir == graph.nodes[idx]["direction"]:
228
+ graph.add_edge(idx, nb, weight=distance)
229
+ return graph
230
+
231
+
232
+ def _fallback_anywalk(road_positions: np.ndarray, step_size: float, length: int, start_idx: Optional[int] = None) -> np.ndarray:
233
+ if start_idx is None:
234
+ start_idx = np.random.randint(0, len(road_positions))
235
+ pos = road_positions[start_idx]
236
+ trajectory = [pos]
237
+ tree = KDTree(road_positions)
238
+ for _ in range(length - 1):
239
+ idxs = tree.query_ball_point(pos, r=step_size * 1.1)
240
+ candidates = []
241
+ for idx in idxs:
242
+ if np.allclose(road_positions[idx], pos):
243
+ continue
244
+ d = np.linalg.norm(road_positions[idx] - pos)
245
+ if np.isclose(d, step_size, atol=0.15 * max(1.0, step_size)):
246
+ candidates.append(idx)
247
+ if not candidates:
248
+ idxs = tree.query_ball_point(pos, r=max(1.5 * step_size, step_size + 0.5))
249
+ if not idxs:
250
+ break
251
+ candidate_idx = min(idxs, key=lambda j: np.linalg.norm(road_positions[j] - pos))
252
+ else:
253
+ candidate_idx = np.random.choice(candidates)
254
+ pos = road_positions[candidate_idx]
255
+ trajectory.append(pos)
256
+ return np.asarray(trajectory)
257
+
258
+
259
+ def generate_smooth_grid_trajectory(graph: nx.DiGraph, road_positions: np.ndarray, turn_probability: float, sequence_length: int = 12, start_node: Optional[int] = None) -> np.ndarray:
260
+ if start_node is None:
261
+ nodes = list(graph.nodes)
262
+ if not nodes:
263
+ return np.empty((0, 3))
264
+ start_node = np.random.choice(nodes)
265
+ trajectory = [road_positions[start_node]]
266
+ current = start_node
267
+ previous: Optional[int] = None
268
+ for _ in range(sequence_length - 1):
269
+ if current not in graph:
270
+ break
271
+ neighbors = list(graph.neighbors(current))
272
+ if previous in neighbors:
273
+ neighbors.remove(previous)
274
+ if not neighbors:
275
+ remaining = sequence_length - len(trajectory)
276
+ if remaining > 0:
277
+ trajectory.extend([road_positions[current]] * remaining)
278
+ break
279
+ node_data = graph.nodes[current]
280
+ lane_type = node_data.get("lane_type", "vertical")
281
+ if lane_type == "intersection" and previous is not None:
282
+ prev_pos = np.asarray(graph.nodes[previous]["pos"])
283
+ curr_pos = np.asarray(node_data["pos"])
284
+ incoming = (
285
+ int(np.sign(curr_pos[0] - prev_pos[0])),
286
+ int(np.sign(curr_pos[1] - prev_pos[1])),
287
+ )
288
+ default_direction = incoming
289
+ else:
290
+ default_direction = node_data.get("direction", None)
291
+ position = np.asarray(node_data["pos"])
292
+ forward_neighbors: List[Tuple[int, float]] = []
293
+ turn_neighbors: List[Tuple[int, float]] = []
294
+ for nb in neighbors:
295
+ neighbor_pos = np.asarray(graph.nodes[nb]["pos"])
296
+ move_dir = (
297
+ int(np.sign(neighbor_pos[0] - position[0])),
298
+ int(np.sign(neighbor_pos[1] - position[1])),
299
+ )
300
+ distance = float(np.linalg.norm(neighbor_pos - position))
301
+ if move_dir == default_direction:
302
+ forward_neighbors.append((nb, distance))
303
+ else:
304
+ turn_neighbors.append((nb, distance))
305
+ if lane_type == "intersection":
306
+ r = np.random.rand()
307
+ if forward_neighbors and r > turn_probability:
308
+ nxt = min(forward_neighbors, key=lambda item: item[1])[0]
309
+ elif turn_neighbors and r < turn_probability:
310
+ nxt = min(turn_neighbors, key=lambda item: item[1])[0]
311
+ elif forward_neighbors:
312
+ nxt = min(forward_neighbors, key=lambda item: item[1])[0]
313
+ elif turn_neighbors:
314
+ nxt = min(turn_neighbors, key=lambda item: item[1])[0]
315
+ else:
316
+ remaining = sequence_length - len(trajectory)
317
+ if remaining > 0:
318
+ trajectory.extend([road_positions[current]] * remaining)
319
+ break
320
+ else:
321
+ if forward_neighbors:
322
+ nxt = min(forward_neighbors, key=lambda item: item[1])[0]
323
+ else:
324
+ remaining = sequence_length - len(trajectory)
325
+ if remaining > 0:
326
+ trajectory.extend([road_positions[current]] * remaining)
327
+ break
328
+ trajectory.append(road_positions[nxt])
329
+ previous, current = current, nxt
330
+ return np.asarray(trajectory)
331
+
332
+
333
+ def generate_n_smooth_grid_trajectories(graph: nx.DiGraph, road_positions: np.ndarray, n: int, sequence_length: int = 12, turn_probability: float = 0.15, max_attempts: int = 2000, step_size: float = 1.0) -> List[np.ndarray]:
334
+ trajectories: List[np.ndarray] = []
335
+ attempts = 0
336
+ hard_cap = n * max_attempts
337
+ tree = KDTree(road_positions)
338
+ min_x, min_y = np.min(road_positions[:, 0]), np.min(road_positions[:, 1])
339
+ max_x, max_y = np.max(road_positions[:, 0]), np.max(road_positions[:, 1])
340
+ while len(trajectories) < n and attempts < hard_cap:
341
+ rand = [np.random.uniform(min_x, max_x), np.random.uniform(min_y, max_y), 0]
342
+ _, start_idx = tree.query(rand)
343
+ traj = generate_smooth_grid_trajectory(graph, road_positions, turn_probability, sequence_length, start_node=start_idx)
344
+ if traj.shape[0] < sequence_length:
345
+ traj = _fallback_anywalk(road_positions, step_size, sequence_length, start_idx=start_idx)
346
+ if traj.shape[0] >= sequence_length:
347
+ trajectories.append(traj[:sequence_length])
348
+ attempts += 1
349
+ if len(trajectories) < n:
350
+ _LOGGER.warning("Only generated %d/%d vehicle trajectories", len(trajectories), n)
351
+ return trajectories
352
+
353
+
354
+ def generate_pedestrian_trajectory(valid_positions: np.ndarray, sequence_length: int = 10, step_size: float = 2.5, angle_std: float = 0.1, start: Optional[np.ndarray] = None) -> np.ndarray:
355
+ tree = KDTree(valid_positions)
356
+ if start is None:
357
+ start = valid_positions[np.random.choice(len(valid_positions))]
358
+ trajectory = [start]
359
+ angle = np.random.uniform(0, 2 * np.pi)
360
+ current = start
361
+ for _ in range(sequence_length - 1):
362
+ angle += np.random.normal(0, angle_std)
363
+ candidate = current + np.array([step_size * np.cos(angle), step_size * np.sin(angle), 0])
364
+ _, idx = tree.query(candidate)
365
+ nxt = valid_positions[idx]
366
+ trajectory.append(nxt)
367
+ current = nxt
368
+ return np.asarray(trajectory)
369
+
370
+
371
+ def generate_n_pedestrian_trajectories(valid_positions: np.ndarray, n: int, sequence_length: int = 10, step_size: float = 2.5, angle_std: float = 0.1) -> List[np.ndarray]:
372
+ return [generate_pedestrian_trajectory(valid_positions, sequence_length, step_size, angle_std) for _ in range(n)]
373
+
374
+
375
+ def get_trajectory_indices(trajectories: Sequence[np.ndarray], pos_total: np.ndarray) -> List[List[int]]:
376
+ def _pos_key(pos: Any) -> Tuple[float, ...]:
377
+ arr = np.asarray(pos)
378
+ if arr.ndim == 0:
379
+ return (float(arr),)
380
+ flat = arr.reshape(-1)
381
+ return tuple(float(x) for x in flat[:3])
382
+
383
+ pos_to_idx = {_pos_key(pos): i for i, pos in enumerate(pos_total)}
384
+ indices: List[List[int]] = []
385
+ for traj in trajectories:
386
+ indices.append([pos_to_idx.get(_pos_key(point), -1) for point in traj])
387
+ return indices
388
+
389
+
390
+ def sample_continuous_along_polyline(
391
+ traj_pos: np.ndarray,
392
+ idxs: Sequence[int],
393
+ speed: float,
394
+ dt: float,
395
+ n_samples: int,
396
+ speed_profile: Optional[np.ndarray] = None,
397
+ ) -> Tuple[np.ndarray, List[Tuple[int, int]], np.ndarray, np.ndarray]:
398
+ traj_pos = np.asarray(traj_pos, float)
399
+ if traj_pos.shape[0] < 2:
400
+ p = np.repeat(traj_pos[:1], n_samples, axis=0)
401
+ pairs = [(idxs[0], idxs[0])] * n_samples
402
+ alphas = np.zeros(n_samples, float)
403
+ vdirs = np.zeros((n_samples, 2), float)
404
+ return p.astype(np.float32), pairs, alphas.astype(np.float32), vdirs.astype(np.float32)
405
+ segment_vectors = traj_pos[1:] - traj_pos[:-1]
406
+ segment_lengths = np.linalg.norm(segment_vectors[:, :2], axis=1)
407
+ cumulative = np.zeros(len(segment_lengths) + 1, float)
408
+ cumulative[1:] = np.cumsum(segment_lengths)
409
+ if cumulative[-1] <= 1e-12:
410
+ p = np.repeat(traj_pos[:1], n_samples, axis=0)
411
+ pairs = [(idxs[0], idxs[0])] * n_samples
412
+ alphas = np.zeros(n_samples, float)
413
+ vdirs = np.zeros((n_samples, 2), float)
414
+ return p.astype(np.float32), pairs, alphas.astype(np.float32), vdirs.astype(np.float32)
415
+ if speed_profile is not None:
416
+ speeds = np.asarray(speed_profile, float)
417
+ if speeds.shape[0] != n_samples:
418
+ raise ValueError("speed_profile must match number of samples")
419
+ ds = np.zeros(n_samples, float)
420
+ for k in range(1, n_samples):
421
+ ds[k] = ds[k - 1] + speeds[k - 1] * dt
422
+ ds = np.clip(ds, 0.0, max(cumulative[-1] - 1e-12, 0.0))
423
+ else:
424
+ ds = speed * dt * np.arange(n_samples, dtype=float)
425
+ ds = np.clip(ds, 0.0, max(cumulative[-1] - 1e-12, 0.0))
426
+ pos_c = np.zeros((n_samples, 3), float)
427
+ alphas = np.zeros(n_samples, float)
428
+ vdirs = np.zeros((n_samples, 2), float)
429
+ pairs: List[Tuple[int, int]] = []
430
+ for idx, distance in enumerate(ds):
431
+ seg_idx = int(np.searchsorted(cumulative, distance, side="right") - 1)
432
+ seg_idx = min(max(seg_idx, 0), len(segment_lengths) - 1)
433
+ seg_start = cumulative[seg_idx]
434
+ seg_length = segment_lengths[seg_idx]
435
+ alpha = (distance - seg_start) / (seg_length if seg_length > 1e-12 else 1.0)
436
+ alpha = min(max(alpha, 0.0), 1.0 - 1e-9)
437
+ p0 = traj_pos[seg_idx]
438
+ p1 = traj_pos[seg_idx + 1]
439
+ position = p0 + alpha * (p1 - p0)
440
+ direction = segment_vectors[seg_idx, :2] / (seg_length if seg_length > 1e-12 else 1.0)
441
+ pos_c[idx] = position
442
+ alphas[idx] = alpha
443
+ vdirs[idx] = direction
444
+ pairs.append((idxs[seg_idx], idxs[seg_idx + 1]))
445
+ return pos_c.astype(np.float32), pairs, alphas.astype(np.float32), vdirs.astype(np.float32)
446
+
447
+
448
+ # ---------------------------------------------------------------------------
449
+ # Channel construction helpers
450
+ # ---------------------------------------------------------------------------
451
+
452
+ def array_response_phase(theta: np.ndarray, phi: np.ndarray, kd: float) -> np.ndarray:
453
+ gamma_x = 1j * kd * np.sin(theta) * np.cos(phi)
454
+ gamma_y = 1j * kd * np.sin(theta) * np.sin(phi)
455
+ gamma_z = 1j * kd * np.cos(theta)
456
+ return np.vstack([gamma_x, gamma_y, gamma_z]).T
457
+
458
+
459
+ def array_response(ant_indices: np.ndarray, theta: np.ndarray, phi: np.ndarray, kd: float) -> np.ndarray:
460
+ gamma = array_response_phase(theta, phi, kd)
461
+ return np.exp(ant_indices @ gamma.T)
462
+
463
+
464
+ def ant_indices(panel_size: np.ndarray) -> np.ndarray:
465
+ gx = np.tile(np.arange(1), int(panel_size[0] * panel_size[1]))
466
+ gy = np.tile(np.repeat(np.arange(int(panel_size[0])), 1), int(panel_size[1]))
467
+ gz = np.repeat(np.arange(int(panel_size[1])), int(panel_size[0]))
468
+ return np.vstack([gx, gy, gz]).T
469
+
470
+
471
+ def apply_fov(fov: Sequence[float], theta: np.ndarray, phi: np.ndarray) -> np.ndarray:
472
+ theta = np.mod(theta, 2 * np.pi)
473
+ phi = np.mod(phi, 2 * np.pi)
474
+ fov = np.deg2rad(fov)
475
+ include_phi = np.logical_or(phi <= 0 + fov[0] / 2, phi >= 2 * np.pi - fov[0] / 2)
476
+ include_theta = np.logical_and(theta <= np.pi / 2 + fov[1] / 2, theta >= np.pi / 2 - fov[1] / 2)
477
+ return np.logical_and(include_phi, include_theta)
478
+
479
+
480
+ def rotate_angles(rotation: Optional[np.ndarray], theta: np.ndarray, phi: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
481
+ theta = np.deg2rad(theta)
482
+ phi = np.deg2rad(phi)
483
+ if rotation is not None:
484
+ R = np.deg2rad(rotation)
485
+ sa = np.sin(phi - R[2])
486
+ sb = np.sin(R[1])
487
+ sg = np.sin(R[0])
488
+ ca = np.cos(phi - R[2])
489
+ cb = np.cos(R[1])
490
+ cg = np.cos(R[0])
491
+ st, ct = np.sin(theta), np.cos(theta)
492
+ theta = np.arccos(cb * cg * ct + st * (sb * cg * ca - sg * sa))
493
+ phi = np.angle(cb * st * ca - sb * ct + 1j * (cb * sg * ct + st * (sb * sg * ca + cg * sa)))
494
+ return theta, phi
495
+
496
+
497
+ class OFDMPathGenerator:
498
+ def __init__(self, params: Dict[str, Any], subcarriers: np.ndarray) -> None:
499
+ self.params = params
500
+ self.ofdm = params[c.PARAMSET_OFDM]
501
+ self.subcarriers = subcarriers
502
+ self.total_subc = self.ofdm[c.PARAMSET_OFDM_SC_NUM]
503
+ self.delay_bins = np.arange(self.ofdm["subcarriers"])
504
+ self.delay_to_ofdm = np.exp(-1j * 2 * np.pi / self.total_subc * np.outer(self.delay_bins, self.subcarriers))
505
+
506
+ def _doppler_phase(self, raydata: Dict[str, Any]) -> Optional[np.ndarray]:
507
+ if not (
508
+ self.params[c.PARAMSET_DOPPLER_EN]
509
+ and self.params[c.PARAMSET_SCENARIO_PARAMS][c.PARAMSET_SCENARIO_PARAMS_DOPPLER_EN]
510
+ ):
511
+ return None
512
+ fc = self.params[c.PARAMSET_SCENARIO_PARAMS][c.PARAMSET_SCENARIO_PARAMS_CF]
513
+ velocities = np.asarray(raydata.get(c.OUT_PATH_DOP_VEL, 0.0)).reshape(-1, 1)
514
+ elapsed = np.asarray(raydata.get("elapsed_time", 0.0)).reshape(-1, 1)
515
+ return np.exp(-1j * 2 * np.pi * (fc / c.LIGHTSPEED) * (velocities * elapsed))
516
+
517
+ def generate(self, raydata: Dict[str, Any], ts: float) -> np.ndarray:
518
+ if self.ofdm[c.PARAMSET_OFDM_LPF] == 0:
519
+ return self.no_lpf(raydata, ts)
520
+ return self.with_lpf(raydata, ts)
521
+
522
+ def no_lpf(self, raydata: Dict[str, Any], ts: float) -> np.ndarray:
523
+ power = raydata[c.OUT_PATH_RX_POW].reshape(-1, 1)
524
+ delay_n = (raydata[c.OUT_PATH_TOA] / ts).reshape(-1, 1)
525
+ phase = raydata[c.OUT_PATH_PHASE].reshape(-1, 1)
526
+ over = delay_n >= self.ofdm["subcarriers"]
527
+ power[over] = 0
528
+ delay_n[over] = self.ofdm["subcarriers"]
529
+ path_const = np.sqrt(power / self.total_subc) * np.exp(
530
+ 1j * (np.deg2rad(phase) - (2 * np.pi / self.total_subc) * np.outer(delay_n, self.subcarriers))
531
+ )
532
+ doppler = self._doppler_phase(raydata)
533
+ if doppler is not None:
534
+ path_const *= doppler
535
+ return path_const
536
+
537
+ def with_lpf(self, raydata: Dict[str, Any], ts: float) -> np.ndarray:
538
+ power = raydata[c.OUT_PATH_RX_POW].reshape(-1, 1)
539
+ delay_n = (raydata[c.OUT_PATH_TOA] / ts).reshape(-1, 1)
540
+ phase = raydata[c.OUT_PATH_PHASE].reshape(-1, 1)
541
+ over = delay_n >= self.ofdm["subcarriers"]
542
+ power[over] = 0
543
+ delay_n[over] = self.ofdm["subcarriers"]
544
+ pulse = np.sinc(self.delay_bins - delay_n) * np.sqrt(power / self.total_subc) * np.exp(1j * np.deg2rad(phase))
545
+ doppler = self._doppler_phase(raydata)
546
+ if doppler is not None:
547
+ pulse *= doppler
548
+ return pulse
549
+
550
+
551
+ def generate_mimo_channel(raydata: Sequence[Dict[str, Any]], params: Dict[str, Any], tx_params: Dict[str, Any], rx_params: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
552
+ bw = params[c.PARAMSET_OFDM][c.PARAMSET_OFDM_BW] * c.PARAMSET_OFDM_BW_MULT
553
+ kd_tx = 2 * np.pi * tx_params[c.PARAMSET_ANT_SPACING]
554
+ kd_rx = 2 * np.pi * rx_params[c.PARAMSET_ANT_SPACING]
555
+ ts = 1 / bw
556
+ subcarriers = params[c.PARAMSET_OFDM][c.PARAMSET_OFDM_SC_SAMP]
557
+ generator = OFDMPathGenerator(params, subcarriers)
558
+ m_tx = int(np.prod(tx_params[c.PARAMSET_ANT_SHAPE]))
559
+ tx_indices = ant_indices(tx_params[c.PARAMSET_ANT_SHAPE])
560
+ m_rx = int(np.prod(rx_params[c.PARAMSET_ANT_SHAPE]))
561
+ rx_indices = ant_indices(rx_params[c.PARAMSET_ANT_SHAPE])
562
+ channel = np.zeros((len(raydata), m_rx, m_tx, len(subcarriers)), dtype=np.csingle)
563
+ los = np.zeros((len(raydata)), dtype=np.int8) - 2
564
+ for idx, item in enumerate(raydata):
565
+ if item[c.OUT_PATH_NUM] == 0:
566
+ los[idx] = -1
567
+ continue
568
+ dod_theta, dod_phi = rotate_angles(tx_params[c.PARAMSET_ANT_ROTATION], item[c.OUT_PATH_DOD_THETA], item[c.OUT_PATH_DOD_PHI])
569
+ doa_theta, doa_phi = rotate_angles(rx_params[c.PARAMSET_ANT_ROTATION], item[c.OUT_PATH_DOA_THETA], item[c.OUT_PATH_DOA_PHI])
570
+ include_tx = apply_fov(tx_params[c.PARAMSET_ANT_FOV], dod_theta, dod_phi)
571
+ include_rx = apply_fov(rx_params[c.PARAMSET_ANT_FOV], doa_theta, doa_phi)
572
+ include = np.logical_and(include_tx, include_rx)
573
+ dod_theta, dod_phi, doa_theta, doa_phi = dod_theta[include], dod_phi[include], doa_theta[include], doa_phi[include]
574
+ for key in list(item.keys()):
575
+ if key == c.OUT_PATH_NUM:
576
+ item[key] = include.sum()
577
+ elif isinstance(item[key], np.ndarray) and item[key].shape[0] == include.shape[0]:
578
+ item[key] = item[key][include]
579
+ if item[c.OUT_PATH_NUM] == 0:
580
+ los[idx] = -1
581
+ continue
582
+ los[idx] = int(np.sum(item.get("LoS", np.zeros(int(item[c.OUT_PATH_NUM])))))
583
+ tx_response = array_response(tx_indices, dod_theta, dod_phi, kd_tx)
584
+ rx_response = array_response(rx_indices, doa_theta, doa_phi, kd_rx)
585
+ factors = generator.generate(item, ts)
586
+ if params[c.PARAMSET_OFDM][c.PARAMSET_OFDM_LPF] == 0:
587
+ channel[idx] = np.sum(
588
+ rx_response[:, None, None, :] * tx_response[None, :, None, :] * factors.T[None, None, :, :],
589
+ axis=3,
590
+ )
591
+ else:
592
+ channel[idx] = (
593
+ np.sum(
594
+ rx_response[:, None, None, :] * tx_response[None, :, None, :] * factors.T[None, None, :, :],
595
+ axis=3,
596
+ )
597
+ ) @ generator.delay_to_ofdm
598
+ return channel, los
599
+
600
+
601
+ def generate_channel_from_interpolated_ray(ray_interp: Dict[str, Any], antenna_cfg: AntennaArrayConfig, carrier_frequency_hz: float) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
602
+ raydata = [
603
+ {
604
+ c.OUT_PATH_NUM: int(ray_interp["num_paths"]),
605
+ c.OUT_PATH_DOD_THETA: np.asarray(ray_interp["DoD_theta"]),
606
+ c.OUT_PATH_DOD_PHI: np.asarray(ray_interp["DoD_phi"]),
607
+ c.OUT_PATH_DOA_THETA: np.asarray(ray_interp["DoA_theta"]),
608
+ c.OUT_PATH_DOA_PHI: np.asarray(ray_interp["DoA_phi"]),
609
+ c.OUT_PATH_PHASE: np.asarray(ray_interp["phase"]),
610
+ c.OUT_PATH_TOA: np.asarray(ray_interp["ToA"]),
611
+ c.OUT_PATH_RX_POW: np.asarray(ray_interp["power"]),
612
+ "LoS": np.asarray(ray_interp.get("LoS", np.zeros(int(ray_interp["num_paths"])))),
613
+ c.OUT_PATH_DOP_VEL: np.asarray(ray_interp.get("Doppler_vel", np.zeros(int(ray_interp["num_paths"])))),
614
+ "elapsed_time": np.asarray(ray_interp.get("elapsed_time", np.zeros(int(ray_interp["num_paths"])))),
615
+ }
616
+ ]
617
+ params = {
618
+ c.PARAMSET_OFDM: {
619
+ c.PARAMSET_OFDM_BW: 1.92e-3,
620
+ c.PARAMSET_OFDM_SC_SAMP: np.arange(antenna_cfg.subcarriers),
621
+ c.PARAMSET_OFDM_SC_NUM: antenna_cfg.subcarriers,
622
+ c.PARAMSET_OFDM_LPF: 0,
623
+ "subcarriers": antenna_cfg.subcarriers,
624
+ },
625
+ c.PARAMSET_FDTD: True,
626
+ c.PARAMSET_DOPPLER_EN: True,
627
+ c.PARAMSET_SCENARIO_PARAMS: {
628
+ c.PARAMSET_SCENARIO_PARAMS_DOPPLER_EN: True,
629
+ c.PARAMSET_SCENARIO_PARAMS_CF: float(carrier_frequency_hz),
630
+ },
631
+ }
632
+ tx = {
633
+ c.PARAMSET_ANT_SHAPE: antenna_cfg.tx_shape(),
634
+ c.PARAMSET_ANT_SPACING: antenna_cfg.spacing,
635
+ c.PARAMSET_ANT_ROTATION: np.asarray(antenna_cfg.tx_rotation),
636
+ c.PARAMSET_ANT_FOV: list(antenna_cfg.field_of_view),
637
+ c.PARAMSET_ANT_RAD_PAT: "isotropic",
638
+ }
639
+ rx = {
640
+ c.PARAMSET_ANT_SHAPE: antenna_cfg.rx_shape(),
641
+ c.PARAMSET_ANT_SPACING: antenna_cfg.spacing,
642
+ c.PARAMSET_ANT_ROTATION: np.asarray(antenna_cfg.rx_rotation),
643
+ c.PARAMSET_ANT_FOV: list(antenna_cfg.field_of_view),
644
+ c.PARAMSET_ANT_RAD_PAT: "isotropic",
645
+ }
646
+ channel, los = generate_mimo_channel(raydata, params, tx, rx)
647
+ return channel, None, los
648
+
649
+
650
+ def unwrap_angle_deg(a0: np.ndarray, a1: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
651
+ delta = ((a1 - a0 + 180.0) % 360.0) - 180.0
652
+ return a0, a0 + delta
653
+
654
+
655
+ def interpolate_ray_params(deepmimo_data: Dict[str, Any], idx0: int, idx1: int, alpha: float) -> Dict[str, Any]:
656
+ p0 = deepmimo_data["user"]["paths"][idx0]
657
+ p1 = deepmimo_data["user"]["paths"][idx1]
658
+ l0, l1 = int(p0["num_paths"]), int(p1["num_paths"])
659
+ if l0 == 0 or l1 == 0:
660
+ return {
661
+ "num_paths": 0,
662
+ "DoD_theta": np.array([]),
663
+ "DoD_phi": np.array([]),
664
+ "DoA_theta": np.array([]),
665
+ "DoA_phi": np.array([]),
666
+ "phase": np.array([]),
667
+ "ToA": np.array([]),
668
+ "power": np.array([]),
669
+ "LoS": np.array([], int),
670
+ }
671
+
672
+ def _select_paths(path: Dict[str, Any], count: int) -> np.ndarray:
673
+ scores = np.asarray(path["power"]).flatten()
674
+ order = np.argsort(-scores)
675
+ return order[:count]
676
+
677
+ count = min(l0, l1)
678
+ idxs0 = _select_paths(p0, count)
679
+ idxs1 = _select_paths(p1, count)
680
+
681
+ def _g(source: Dict[str, Any], key: str, selection: np.ndarray) -> np.ndarray:
682
+ return np.asarray(source[key]).flatten()[selection].astype(float)
683
+
684
+ power0, power1 = _g(p0, "power", idxs0), _g(p1, "power", idxs1)
685
+ power = (1 - alpha) * power0 + alpha * power1
686
+ toa0, toa1 = _g(p0, "ToA", idxs0), _g(p1, "ToA", idxs1)
687
+ toa = (1 - alpha) * toa0 + alpha * toa1
688
+
689
+ def _ainterp(key: str) -> np.ndarray:
690
+ a0, a1 = _g(p0, key, idxs0), _g(p1, key, idxs1)
691
+ out = np.zeros_like(a0)
692
+ for n in range(count):
693
+ u0, u1 = unwrap_angle_deg(a0[n], a1[n])
694
+ out[n] = (1 - alpha) * u0 + alpha * u1
695
+ return out
696
+
697
+ dod_theta, dod_phi = _ainterp("DoD_theta"), _ainterp("DoD_phi")
698
+ doa_theta, doa_phi = _ainterp("DoA_theta"), _ainterp("DoA_phi")
699
+ phase0 = np.deg2rad(_g(p0, "phase", idxs0))
700
+ phase1 = np.deg2rad(_g(p1, "phase", idxs1))
701
+ dphi = np.angle(np.exp(1j * (phase1 - phase0)))
702
+ phase = np.rad2deg(phase0 + alpha * dphi)
703
+ los = (
704
+ _g(p0, "LoS", idxs0).astype(int) + _g(p1, "LoS", idxs1).astype(int) > 0
705
+ ).astype(int)
706
+
707
+ return {
708
+ "num_paths": count,
709
+ "DoD_theta": dod_theta,
710
+ "DoD_phi": dod_phi,
711
+ "DoA_theta": doa_theta,
712
+ "DoA_phi": doa_phi,
713
+ "phase": phase,
714
+ "ToA": toa,
715
+ "power": power,
716
+ "LoS": los,
717
+ }
718
+
719
+
720
+ # ---------------------------------------------------------------------------
721
+ # Scenario generator
722
+ # ---------------------------------------------------------------------------
723
+
724
+
725
+ class DynamicScenarioGenerator:
726
+ """High-level orchestrator for dynamic DeepMIMO scenario synthesis."""
727
+
728
+ def __init__(self, config: ScenarioGenerationConfig, logger: Optional[logging.Logger] = None) -> None:
729
+ self.config = config
730
+ self.logger = logger or _LOGGER
731
+ if config.rng_seed is not None:
732
+ np.random.seed(config.rng_seed)
733
+
734
+ # ------------------------------------------------------------------
735
+ # public API
736
+ # ------------------------------------------------------------------
737
+ def generate(self, overwrite: bool = False) -> ScenarioGenerationResult:
738
+ cfg = self.config
739
+ cfg.output_dir.mkdir(parents=True, exist_ok=True)
740
+ cfg.full_output_dir.mkdir(parents=True, exist_ok=True)
741
+ filename = self._build_filename()
742
+ data_path = cfg.output_dir / filename
743
+ full_path = cfg.full_output_dir / filename.replace(".p", "_full.p")
744
+
745
+ if data_path.exists() and full_path.exists() and not overwrite:
746
+ payload = self._load_pickle(data_path)
747
+ self.logger.info("Loaded cached scenario from %s", data_path)
748
+ return ScenarioGenerationResult(payload=payload, output_path=data_path, full_output_path=full_path, generated=False)
749
+
750
+ deepmimo_data = self._load_or_build_deepmimo()
751
+ path_exist = np.asarray(deepmimo_data["user"]["LoS"])
752
+ pos_total = np.asarray(deepmimo_data["user"]["location"])
753
+
754
+ grid_step = self._infer_grid_step(pos_total)
755
+ road_positions, lane_info = filter_road_positions(
756
+ pos_total[(path_exist == 0) | (path_exist == 1)], cfg.grid.road_width, cfg.grid.road_center_spacing
757
+ )
758
+ road_graph = create_grid_road_network(road_positions, lane_info, grid_step)
759
+ self.logger.info("Road graph has %d nodes", len(road_graph.nodes))
760
+
761
+ vehicle_trajs = generate_n_smooth_grid_trajectories(
762
+ road_graph,
763
+ road_positions,
764
+ cfg.traffic.num_vehicles,
765
+ sequence_length=cfg.sampling.time_steps,
766
+ turn_probability=cfg.traffic.turn_probability,
767
+ max_attempts=cfg.traffic.max_attempts,
768
+ step_size=grid_step,
769
+ )
770
+ pedestrian_trajs = generate_n_pedestrian_trajectories(
771
+ pos_total[(path_exist == 0) | (path_exist == 1)],
772
+ cfg.traffic.num_pedestrians,
773
+ sequence_length=cfg.sampling.time_steps,
774
+ step_size=grid_step,
775
+ angle_std=cfg.traffic.pedestrian_angle_std,
776
+ )
777
+
778
+ veh_idx = get_trajectory_indices(vehicle_trajs, pos_total)
779
+ ped_idx = get_trajectory_indices(pedestrian_trajs, pos_total)
780
+
781
+ car_tracks = self._build_tracks(vehicle_trajs, veh_idx, cfg.traffic.vehicle_speed_range, cfg.sampling)
782
+ ped_tracks = self._build_tracks(pedestrian_trajs, ped_idx, cfg.traffic.pedestrian_speed_range, cfg.sampling)
783
+
784
+ if cfg.export_environment_plot and cfg.figures_dir is not None:
785
+ self._save_environment_plot(
786
+ cfg.figures_dir,
787
+ pos_total,
788
+ path_exist,
789
+ road_positions=road_positions,
790
+ vehicle_trajs=[track["pos"] for track in car_tracks],
791
+ ped_trajs=[track["pos"] for track in ped_tracks],
792
+ )
793
+
794
+ channel_payload = self._channels_for_tracks(
795
+ deepmimo_data,
796
+ car_tracks + ped_tracks,
797
+ cfg.antenna,
798
+ cfg.carrier_frequency_hz,
799
+ cfg.sampling,
800
+ )
801
+
802
+ discrete_channels = np.zeros(
803
+ (
804
+ cfg.traffic.num_vehicles + cfg.traffic.num_pedestrians,
805
+ cfg.sampling.time_steps,
806
+ cfg.antenna.total_tx_elements(),
807
+ cfg.antenna.subcarriers,
808
+ ),
809
+ np.complex128,
810
+ )
811
+
812
+ payload: Dict[str, Any] = {
813
+ "scenario": cfg.scenario,
814
+ "index_discrete": veh_idx + ped_idx,
815
+ "grid_step": grid_step,
816
+ "sample_dt": cfg.sampling.sample_dt,
817
+ "car_speed_range": cfg.traffic.vehicle_speed_range,
818
+ "ped_speed_range": cfg.traffic.pedestrian_speed_range,
819
+ "los": path_exist,
820
+ "channel_cont": channel_payload["channel"],
821
+ "pos_cont": channel_payload["pos"],
822
+ "vel_cont": channel_payload["vel"],
823
+ "acc_cont": channel_payload["acc"],
824
+ "doppler_vel_cont": channel_payload["doppler"],
825
+ "angle_cont": channel_payload["angle"],
826
+ "delay_cont": channel_payload["delay"],
827
+ "pos_step_xy": channel_payload["step_xy"],
828
+ "channel_discrete": discrete_channels,
829
+ "pos_discrete": vehicle_trajs + pedestrian_trajs,
830
+ }
831
+
832
+ if cfg.sampling.continuous_mode:
833
+ payload.update(
834
+ {
835
+ "channel": channel_payload["channel"],
836
+ "pos": channel_payload["pos"],
837
+ "vel": channel_payload["vel"],
838
+ "acc": channel_payload["acc"],
839
+ "doppler_vel": channel_payload["doppler"],
840
+ "angle": channel_payload["angle"],
841
+ "delay": channel_payload["delay"],
842
+ }
843
+ )
844
+ else:
845
+ payload["channel"] = discrete_channels
846
+
847
+ self._dump_pickle(data_path, payload["channel"])
848
+ self._dump_pickle(full_path, payload)
849
+ self.logger.info("Generated scenario saved to %s", data_path)
850
+ return ScenarioGenerationResult(payload=payload, output_path=data_path, full_output_path=full_path, generated=True)
851
+
852
+ # ------------------------------------------------------------------
853
+ # internal helpers
854
+ # ------------------------------------------------------------------
855
+ def _load_or_build_deepmimo(self) -> Dict[str, Any]:
856
+ cfg = self.config
857
+ return load_deepmimo_user_data(
858
+ cfg.scenario,
859
+ scenarios_dir=cfg.scenarios_dir,
860
+ load_params=cfg.deepmimo_load_params,
861
+ max_paths=cfg.deepmimo_max_paths,
862
+ array_dtype=cfg.deepmimo_array_dtype,
863
+ logger=self.logger,
864
+ )
865
+
866
+ def _save_environment_plot(
867
+ self,
868
+ figures_dir: Path,
869
+ positions: np.ndarray,
870
+ los: np.ndarray,
871
+ road_positions: Optional[np.ndarray] = None,
872
+ vehicle_trajs: Optional[Sequence[np.ndarray]] = None,
873
+ ped_trajs: Optional[Sequence[np.ndarray]] = None,
874
+ ) -> None:
875
+ figures_dir.mkdir(parents=True, exist_ok=True)
876
+ x, y = positions[:, 0], positions[:, 1]
877
+ los_array = np.asarray(los)
878
+ if los_array.ndim > 1:
879
+ N = x.shape[0]
880
+ shape = list(los_array.shape)
881
+ if N in shape:
882
+ axis_user = shape.index(N)
883
+ los_matrix = np.moveaxis(los_array, axis_user, -1).reshape(-1, N)
884
+ los_per_user = np.full(N, -1, np.int8)
885
+ any_los = (los_matrix == 1).any(axis=0)
886
+ los_per_user[any_los] = 1
887
+ any_nlos = (los_matrix == 0).any(axis=0) & ~any_los
888
+ los_per_user[any_nlos] = 0
889
+ else:
890
+ los_per_user = np.full(x.shape[0], -1, np.int8)
891
+ else:
892
+ los_per_user = los_array.astype(np.int8, copy=False)
893
+
894
+ colors = {
895
+ 1: "#2ca8c2", # teal for LoS
896
+ 0: "#f39c12", # amber for NLoS
897
+ -1: "#5d6d7e", # muted slate for inactive users
898
+ }
899
+ plt.figure(figsize=(8, 8), dpi=500)
900
+ legend_handles = []
901
+ legend_labels = []
902
+ for value, color in colors.items():
903
+ mask = los_per_user == value
904
+ if np.any(mask):
905
+ handle = plt.scatter(x[mask], y[mask], s=3, c=color, alpha=0.6)
906
+ legend_handles.append(handle)
907
+ legend_labels.append(f"LoS={value}")
908
+
909
+ if road_positions is not None and len(road_positions):
910
+ rp = np.asarray(road_positions)
911
+ handle = plt.scatter(rp[:, 0], rp[:, 1], s=1, c="#a0a0a0", alpha=0.25)
912
+ legend_handles.append(handle)
913
+ legend_labels.append("Road nodes")
914
+
915
+ if vehicle_trajs:
916
+ cmap_vehicle = plt.cm.get_cmap("tab20", max(len(vehicle_trajs), 1))
917
+ for idx, traj in enumerate(vehicle_trajs):
918
+ arr = np.asarray(traj, dtype=float)
919
+ arr = arr.reshape(arr.shape[0], -1)
920
+ if arr.size == 0:
921
+ continue
922
+ color = cmap_vehicle(idx % cmap_vehicle.N)
923
+ plt.plot(arr[:, 0], arr[:, 1], color=color, alpha=0.85, linewidth=1.2)
924
+ plt.scatter(
925
+ arr[0:1, 0],
926
+ arr[0:1, 1],
927
+ s=20,
928
+ marker="o",
929
+ facecolor="white",
930
+ edgecolor=color,
931
+ linewidths=1.2,
932
+ zorder=3,
933
+ )
934
+ legend_handles.append(Line2D([0], [0], color=cmap_vehicle(0), linewidth=1.2))
935
+ legend_labels.append("Vehicle trajectories")
936
+ legend_handles.append(Line2D([0], [0], marker="o", color="white", markeredgecolor=cmap_vehicle(0), markerfacecolor="white", linestyle="None", markersize=5, markeredgewidth=1.2))
937
+ legend_labels.append("Vehicle start")
938
+
939
+ if ped_trajs:
940
+ cmap_ped = plt.cm.get_cmap("Set2", max(len(ped_trajs), 1))
941
+ for idx, traj in enumerate(ped_trajs):
942
+ arr = np.asarray(traj, dtype=float)
943
+ arr = arr.reshape(arr.shape[0], -1)
944
+ if arr.size == 0:
945
+ continue
946
+ color = cmap_ped(idx % cmap_ped.N)
947
+ plt.plot(arr[:, 0], arr[:, 1], color=color, alpha=0.85, linewidth=1.2, linestyle="--")
948
+ plt.scatter(
949
+ arr[0:1, 0],
950
+ arr[0:1, 1],
951
+ s=20,
952
+ marker="o",
953
+ facecolor="white",
954
+ edgecolor=color,
955
+ linewidths=1.2,
956
+ zorder=3,
957
+ )
958
+ legend_handles.append(Line2D([0], [0], color=cmap_ped(0), linewidth=1.2, linestyle="--"))
959
+ legend_labels.append("Pedestrian trajectories")
960
+ legend_handles.append(Line2D([0], [0], marker="o", color="white", markeredgecolor=cmap_ped(0), markerfacecolor="white", linestyle="None", markersize=5, markeredgewidth=1.2))
961
+ legend_labels.append("Pedestrian start")
962
+
963
+ if legend_handles:
964
+ plt.legend(legend_handles, legend_labels, loc="best", fontsize=8)
965
+ # plt.grid(True)
966
+ plt.gca().set_aspect("equal", adjustable="box")
967
+ plt.savefig(figures_dir / "environment.png", bbox_inches="tight", pad_inches=0.1)
968
+ plt.close()
969
+
970
+ def _infer_grid_step(self, positions: np.ndarray) -> float:
971
+ cfg = self.config
972
+ if cfg.grid.auto_step_size:
973
+ inferred = infer_grid_step(positions)
974
+ self.logger.info("Inferred road step size %.6f m", inferred)
975
+ return inferred
976
+ return float(cfg.grid.step_size)
977
+
978
+ def _build_tracks(
979
+ self,
980
+ trajectories: Sequence[np.ndarray],
981
+ indices: Sequence[Sequence[int]],
982
+ speed_range: Tuple[float, float],
983
+ sampling_cfg: ScenarioSamplingConfig,
984
+ ) -> List[Dict[str, Any]]:
985
+ tracks: List[Dict[str, Any]] = []
986
+ horizon = sampling_cfg.continuous_length or sampling_cfg.time_steps
987
+ for traj, idxs in zip(trajectories, indices):
988
+ speed = float(np.random.uniform(*speed_range))
989
+ idxs_list = list(idxs)
990
+ speed_profile = np.full(horizon, speed, float)
991
+ last_change = -1
992
+ for k in range(len(idxs_list) - 1):
993
+ if idxs_list[k] != idxs_list[k + 1]:
994
+ last_change = k
995
+ if last_change == -1:
996
+ speed_profile[:] = 0.0
997
+ else:
998
+ last_move_sample = last_change + 1
999
+ if last_move_sample < horizon - 1:
1000
+ ramp = min(5, last_move_sample + 1)
1001
+ start_idx = max(0, last_move_sample - ramp + 1)
1002
+ speed_profile[:start_idx] = speed
1003
+ for k in range(start_idx, last_move_sample + 1):
1004
+ factor = float(last_move_sample - k + 1) / float(ramp)
1005
+ speed_profile[k] = speed * factor
1006
+ speed_profile[last_move_sample + 1 :] = 0.0
1007
+ pos_c, pairs, alpha, vdir = sample_continuous_along_polyline(
1008
+ traj,
1009
+ idxs,
1010
+ speed,
1011
+ sampling_cfg.sample_dt,
1012
+ horizon,
1013
+ speed_profile=speed_profile,
1014
+ )
1015
+ tracks.append(
1016
+ {
1017
+ "speed": speed,
1018
+ "speed_profile": speed_profile,
1019
+ "pos": pos_c,
1020
+ "pairs": pairs,
1021
+ "alpha": alpha,
1022
+ "vdir": vdir,
1023
+ }
1024
+ )
1025
+ return tracks
1026
+
1027
+ def _channels_for_tracks(
1028
+ self,
1029
+ deepmimo_data: Dict[str, Any],
1030
+ tracks: Sequence[Dict[str, Any]],
1031
+ antenna_cfg: AntennaArrayConfig,
1032
+ carrier_frequency_hz: float,
1033
+ sampling_cfg: ScenarioSamplingConfig,
1034
+ ) -> Dict[str, Any]:
1035
+ channels: List[np.ndarray] = []
1036
+ positions: List[np.ndarray] = []
1037
+ velocities: List[np.ndarray] = []
1038
+ accelerations: List[np.ndarray] = []
1039
+ dopplers: List[List[np.ndarray]] = []
1040
+ angles: List[List[np.ndarray]] = []
1041
+ delays: List[List[np.ndarray]] = []
1042
+ steps_xy: List[np.ndarray] = []
1043
+ for track in tracks:
1044
+ base_speed = track["speed"]
1045
+ speed_profile = track.get("speed_profile")
1046
+ pos = track["pos"]
1047
+ pairs = track["pairs"]
1048
+ alpha = track["alpha"]
1049
+ vdir = track["vdir"]
1050
+ horizon = len(alpha)
1051
+ channel_sequence: List[np.ndarray] = []
1052
+ doppler_list: List[np.ndarray] = []
1053
+ angle_list: List[np.ndarray] = []
1054
+ delay_list: List[np.ndarray] = []
1055
+ step_xy = np.zeros(horizon, float)
1056
+ step_xy[1:] = np.linalg.norm(pos[1:, :2] - pos[:-1, :2], axis=1)
1057
+ vel_series = np.zeros(horizon, float)
1058
+ if horizon > 1:
1059
+ vel_series[1:] = step_xy[1:] / sampling_cfg.sample_dt
1060
+ if vel_series[0] == 0.0:
1061
+ if speed_profile is not None:
1062
+ vel_series[0] = float(speed_profile[0])
1063
+ else:
1064
+ vel_series[0] = base_speed
1065
+ acc_series = np.zeros_like(vel_series)
1066
+ if horizon > 1:
1067
+ acc_series[1:] = (vel_series[1:] - vel_series[:-1]) / sampling_cfg.sample_dt
1068
+ acc_series[0] = acc_series[1] if horizon > 1 else 0.0
1069
+ for idx, (pair, a, vd) in enumerate(zip(pairs, alpha, vdir)):
1070
+ i0, i1 = pair
1071
+ ray_interp = interpolate_ray_params(deepmimo_data, i0, i1, float(a))
1072
+ if ray_interp["num_paths"] == 0:
1073
+ channel_sequence.append(np.zeros((antenna_cfg.total_tx_elements(), antenna_cfg.subcarriers), np.complex128))
1074
+ doppler_list.append(np.zeros((0,), np.float32))
1075
+ angle_list.append(np.zeros((0,), np.float32))
1076
+ delay_list.append(np.zeros((0,), np.float32))
1077
+ continue
1078
+ doa_phi = np.deg2rad(ray_interp["DoA_phi"])
1079
+ aoa_unit = np.stack([np.cos(doa_phi), np.sin(doa_phi)], axis=1)
1080
+ speed_inst = vel_series[idx] if idx < vel_series.shape[0] else base_speed
1081
+ v_proj = -np.sum(aoa_unit * vd[None, :], axis=1) * speed_inst
1082
+ ray_interp["Doppler_vel"] = v_proj.astype(np.float32)
1083
+ ray_interp["elapsed_time"] = np.ones_like(ray_interp["power"], dtype=np.float32) * (idx * sampling_cfg.sample_dt)
1084
+ pred, _, _ = generate_channel_from_interpolated_ray(ray_interp, antenna_cfg, carrier_frequency_hz)
1085
+ channel_sequence.append(np.asarray(pred[0]).squeeze(0))
1086
+ doppler_list.append(v_proj.astype(np.float32))
1087
+ angle_list.append(ray_interp["DoA_phi"].astype(np.float32))
1088
+ delay_list.append(ray_interp["ToA"].astype(np.float32))
1089
+ channels.append(np.stack(channel_sequence, axis=0))
1090
+ positions.append(pos.astype(np.float32))
1091
+ velocities.append(vel_series.astype(np.float32))
1092
+ accelerations.append(acc_series.astype(np.float32))
1093
+ dopplers.append(doppler_list)
1094
+ angles.append(angle_list)
1095
+ delays.append(delay_list)
1096
+ steps_xy.append(step_xy.astype(np.float32))
1097
+ return {
1098
+ "channel": np.stack(channels, axis=0),
1099
+ "pos": np.stack(positions, axis=0),
1100
+ "vel": np.stack(velocities, axis=0),
1101
+ "acc": np.stack(accelerations, axis=0),
1102
+ "doppler": dopplers,
1103
+ "angle": angles,
1104
+ "delay": delays,
1105
+ "step_xy": np.stack(steps_xy, axis=0),
1106
+ }
1107
+
1108
+ def _build_filename(self) -> str:
1109
+ cfg = self.config
1110
+ return f"{cfg.scenario}_{cfg.sampling.time_steps}_{cfg.antenna.total_tx_elements()}_{cfg.antenna.subcarriers}.p"
1111
+
1112
+ @staticmethod
1113
+ def _load_pickle(path: Path) -> Any:
1114
+ with path.open("rb") as handle:
1115
+ import pickle
1116
+
1117
+ return pickle.load(handle)
1118
+
1119
+ @staticmethod
1120
+ def _dump_pickle(path: Path, obj: Any) -> None:
1121
+ with path.open("wb") as handle:
1122
+ import pickle
1123
+
1124
+ pickle.dump(obj, handle)
1125
+
1126
+
1127
+ # Backwards compatibility helper -------------------------------------------------
1128
+
1129
+ def generate_dynamic_scenario_dataset(
1130
+ config: ScenarioGenerationConfig,
1131
+ overwrite: bool = False,
1132
+ logger: Optional[logging.Logger] = None,
1133
+ ) -> ScenarioGenerationResult:
1134
+ generator = DynamicScenarioGenerator(config, logger=logger)
1135
+ return generator.generate(overwrite=overwrite)
1136
+
1137
+
1138
+ def generate_dynamic_scenario(
1139
+ config: ScenarioGenerationConfig,
1140
+ overwrite: bool = False,
1141
+ logger: Optional[logging.Logger] = None,
1142
+ ) -> ScenarioGenerationResult:
1143
+ warnings.warn(
1144
+ "`generate_dynamic_scenario` is deprecated. Use `generate_dynamic_scenario_dataset` instead.",
1145
+ DeprecationWarning,
1146
+ stacklevel=2,
1147
+ )
1148
+ return generate_dynamic_scenario_dataset(config, overwrite=overwrite, logger=logger)
1149
+
1150
+
1151
+ __all__ = [
1152
+ "AntennaArrayConfig",
1153
+ "TrafficConfig",
1154
+ "GridConfig",
1155
+ "ScenarioSamplingConfig",
1156
+ "ScenarioGenerationConfig",
1157
+ "ScenarioGenerationResult",
1158
+ "DynamicScenarioGenerator",
1159
+ "generate_dynamic_scenario_dataset",
1160
+ "generate_dynamic_scenario",
1161
+ ]
docs/dynamic_scenario_pipeline.md CHANGED
@@ -214,14 +214,14 @@ from LWMTemporal.data.scenario_generation import (
214
  config = ScenarioGenerationConfig(
215
  scenario="city_0_newyork_3p5",
216
  antenna=AntennaArrayConfig(tx_horizontal=32, tx_vertical=1, subcarriers=32),
217
- sampling=ScenarioSamplingConfig(time_steps=40, sample_dt=0.1),
218
- traffic=TrafficConfig(num_vehicles=120, num_pedestrians=20, turn_probability=0.08),
219
- grid=GridConfig(road_width=6.0, road_center_spacing=25.0),
220
  output_dir=Path("examples/data"),
221
  full_output_dir=Path("examples/full_data"),
222
  figures_dir=Path("examples/figs/newyork"),
223
  scenarios_dir=Path("deepmimo_scenarios"),
224
- deepmimo_max_paths=6,
225
  )
226
 
227
  generator = DynamicScenarioGenerator(config)
 
214
  config = ScenarioGenerationConfig(
215
  scenario="city_0_newyork_3p5",
216
  antenna=AntennaArrayConfig(tx_horizontal=32, tx_vertical=1, subcarriers=32),
217
+ sampling=ScenarioSamplingConfig(time_steps=20, sample_dt=0.1),
218
+ traffic=TrafficConfig(num_vehicles=120, num_pedestrians=20, turn_probability=0.1),
219
+ grid=GridConfig(road_width=2.0, road_center_spacing=8.0),
220
  output_dir=Path("examples/data"),
221
  full_output_dir=Path("examples/full_data"),
222
  figures_dir=Path("examples/figs/newyork"),
223
  scenarios_dir=Path("deepmimo_scenarios"),
224
+ deepmimo_max_paths=25,
225
  )
226
 
227
  generator = DynamicScenarioGenerator(config)