Sadjad Alikhani committed on
Commit ·
265d187
1
Parent(s): 80a230c
Add data dir and docs
Browse files- .gitignore +0 -1
- LWMTemporal/data/__init__.py +32 -0
- LWMTemporal/data/angle_delay.py +272 -0
- LWMTemporal/data/datasets.py +241 -0
- LWMTemporal/data/deepmimo_adapter.py +236 -0
- LWMTemporal/data/scenario_generation.py +1161 -0
- docs/dynamic_scenario_pipeline.md +4 -4
.gitignore
CHANGED
|
@@ -38,7 +38,6 @@ MANIFEST
|
|
| 38 |
*.h5
|
| 39 |
*.hdf5
|
| 40 |
cache/
|
| 41 |
-
data/
|
| 42 |
!examples/data/
|
| 43 |
!examples/data/*.p
|
| 44 |
!examples/data/README.md
|
|
|
|
| 38 |
*.h5
|
| 39 |
*.hdf5
|
| 40 |
cache/
|
|
|
|
| 41 |
!examples/data/
|
| 42 |
!examples/data/*.p
|
| 43 |
!examples/data/README.md
|
LWMTemporal/data/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data utilities for the LWM foundation package."""
|
| 2 |
+
|
| 3 |
+
from .angle_delay import AngleDelayConfig, AngleDelayProcessor
|
| 4 |
+
from .datasets import AngleDelaySequenceDataset, AngleDelayDatasetConfig, load_adseq_dataset
|
| 5 |
+
from .scenario_generation import (
|
| 6 |
+
AntennaArrayConfig,
|
| 7 |
+
DynamicScenarioGenerator,
|
| 8 |
+
GridConfig,
|
| 9 |
+
ScenarioGenerationConfig,
|
| 10 |
+
ScenarioGenerationResult,
|
| 11 |
+
ScenarioSamplingConfig,
|
| 12 |
+
TrafficConfig,
|
| 13 |
+
generate_dynamic_scenario,
|
| 14 |
+
generate_dynamic_scenario_dataset,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"AngleDelayConfig",
|
| 19 |
+
"AngleDelayProcessor",
|
| 20 |
+
"AngleDelaySequenceDataset",
|
| 21 |
+
"AngleDelayDatasetConfig",
|
| 22 |
+
"load_adseq_dataset",
|
| 23 |
+
"AntennaArrayConfig",
|
| 24 |
+
"TrafficConfig",
|
| 25 |
+
"GridConfig",
|
| 26 |
+
"ScenarioSamplingConfig",
|
| 27 |
+
"ScenarioGenerationConfig",
|
| 28 |
+
"ScenarioGenerationResult",
|
| 29 |
+
"DynamicScenarioGenerator",
|
| 30 |
+
"generate_dynamic_scenario_dataset",
|
| 31 |
+
"generate_dynamic_scenario",
|
| 32 |
+
]
|
LWMTemporal/data/angle_delay.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
import math
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Iterable, List, Optional, Sequence
|
| 7 |
+
|
| 8 |
+
import imageio.v2 as imageio
|
| 9 |
+
import matplotlib.animation as animation
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclasses.dataclass
class AngleDelayConfig:
    """Configuration options for angle-delay processing."""

    # Axis extents used only for plot labelling, not for the transform itself.
    angle_range: tuple[float, float] = (-math.pi / 2, math.pi / 2)
    delay_range: tuple[float, float] = (0.0, 100.0)
    # Fraction of delay bins retained by truncation; must lie in (0, 1].
    keep_percentage: float = 0.25
    fps: int = 4
    dpi: int = 120
    num_bins: int = 6
    output_dir: Path = Path("figs")

    def validate(self) -> None:
        """Raise ``ValueError`` for the first option that is out of range."""
        if not 0.0 < self.keep_percentage <= 1.0:
            raise ValueError("keep_percentage must be in (0, 1]")
        # Strictly-positive options are checked in declaration order so the
        # first failing field reported is deterministic.
        positives = (
            (self.fps, "fps must be positive"),
            (self.dpi, "dpi must be positive"),
            (self.num_bins, "num_bins must be positive"),
        )
        for value, message in positives:
            if value <= 0:
                raise ValueError(message)
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AngleDelayProcessor:
    """Project complex channels into the angle-delay domain and visualise them.

    ``forward`` maps a complex ``(T, N, M)`` tensor — presumably (time,
    antenna, subcarrier); TODO confirm with callers — into the angle-delay
    domain via an orthonormal FFT over dim 1 and inverse FFT over dim 2;
    ``inverse`` is an exact round-trip. The remaining methods render
    GIF / PNG visualisations of channel and angle-delay tensors.
    """

    def __init__(self, config: AngleDelayConfig | None = None) -> None:
        """Store *config* (defaults to ``AngleDelayConfig()``) after validating it."""
        self.config = config or AngleDelayConfig()
        self.config.validate()

    # ------------------------------------------------------------------
    # Core transforms
    # ------------------------------------------------------------------
    @staticmethod
    def _ensure_complex(tensor: torch.Tensor) -> torch.Tensor:
        """Return *tensor* unchanged, raising ``TypeError`` if it is not complex."""
        if not torch.is_complex(tensor):
            raise TypeError("expected complex tensor")
        return tensor

    def forward(self, channel: torch.Tensor) -> torch.Tensor:
        """Transform a complex channel into the angle-delay domain.

        Orthonormal FFT along dim 1 (antenna -> angle bins) followed by an
        orthonormal inverse FFT along dim 2 (subcarrier -> delay bins);
        ``norm="ortho"`` makes :meth:`inverse` an exact inverse.
        """
        channel = self._ensure_complex(channel)
        angle_domain = torch.fft.fft(channel, dim=1, norm="ortho")
        delay_domain = torch.fft.ifft(angle_domain, dim=2, norm="ortho")
        return delay_domain

    def inverse(self, angle_delay: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`forward`, recovering the antenna/subcarrier channel."""
        angle_delay = self._ensure_complex(angle_delay)
        subcarrier = torch.fft.fft(angle_delay, dim=2, norm="ortho")
        antenna = torch.fft.ifft(subcarrier, dim=1, norm="ortho")
        return antenna

    # ------------------------------------------------------------------
    # Truncation helpers and metrics
    # ------------------------------------------------------------------
    def truncate_delay_bins(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Keep only the leading ``keep_percentage`` fraction of delay bins.

        Returns ``(truncated, padded)``: *truncated* has the reduced last
        dimension; *padded* is zero-filled back to the original shape so it
        can be passed through :meth:`inverse`.
        """
        tensor = self._ensure_complex(tensor)
        if tensor.ndim != 3:
            raise ValueError("angle-delay tensor must have shape (T, N, M)")
        # At least one bin is always retained, even for tiny keep percentages.
        keep = max(1, int(round(tensor.size(-1) * self.config.keep_percentage)))
        truncated = tensor[..., :keep]
        padded = torch.zeros_like(tensor)
        padded[..., :keep] = truncated
        return truncated, padded

    @staticmethod
    def nmse(reference: torch.Tensor, reconstruction: torch.Tensor) -> float:
        """Normalised mean-squared error in dB between two complex tensors."""
        reference = AngleDelayProcessor._ensure_complex(reference)
        reconstruction = AngleDelayProcessor._ensure_complex(reconstruction)
        mse = torch.mean(torch.abs(reference - reconstruction) ** 2)
        # Clamp reference power so an all-zero reference cannot divide by zero.
        power = torch.mean(torch.abs(reference) ** 2).clamp_min(1e-12)
        return float(10.0 * torch.log10(mse / power))

    def reconstruction_nmse(self, channel: torch.Tensor) -> tuple[float, float]:
        """Return ``(nmse_full, nmse_truncated)`` in dB for round-trips of *channel*."""
        ad_full = self.forward(channel)
        recon_full = self.inverse(ad_full)
        nmse_full = self.nmse(channel, recon_full)
        truncated, padded = self.truncate_delay_bins(ad_full)
        recon_trunc = self.inverse(padded)
        nmse_trunc = self.nmse(channel, recon_trunc)
        return nmse_full, nmse_trunc

    # ------------------------------------------------------------------
    # Visualisation helpers
    # ------------------------------------------------------------------
    def save_angle_delay_gif(self, tensor: torch.Tensor, output_path: Path, fps: Optional[int] = None) -> None:
        """Render each time frame of *tensor* as a greyscale heatmap and save a GIF."""
        tensor = self._ensure_complex(tensor)
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        magnitude = tensor.abs().cpu()
        # Shared colour limits keep the greyscale comparable across frames.
        vmin, vmax = float(magnitude.min()), float(magnitude.max())
        frames: List[np.ndarray] = []
        for frame_idx in range(magnitude.size(0)):
            fig, ax = plt.subplots(figsize=(8, 6))
            im = ax.imshow(
                magnitude[frame_idx].numpy(),
                cmap="gray_r",
                origin="lower",
                aspect="auto",
                extent=[*self.config.delay_range, *self.config.angle_range],
                vmin=vmin,
                vmax=vmax,
            )
            ax.set_xlabel("Delay (samples)")
            ax.set_ylabel("Angle (radians)")
            ax.set_title(f"Angle-Delay Map – Frame {frame_idx}")
            fig.colorbar(im, ax=ax, label="Power (linear)")
            fig.canvas.draw()
            frames.append(np.asarray(fig.canvas.buffer_rgba()))
            plt.close(fig)
        # NOTE(review): newer imageio releases deprecate ``fps`` for GIF in
        # favour of ``duration`` — confirm against the pinned imageio version.
        imageio.mimsave(output_path, frames, fps=fps or self.config.fps)

    def _save_animation(self, fig: plt.Figure, animate_fn, output_path: Path, num_frames: Optional[int] = None, fps: Optional[int] = None, dpi: Optional[int] = None) -> None:
        """Run *animate_fn* over ``num_frames`` frames and save the animation as GIF.

        BUG FIX: the previous version omitted ``frames=``, so ``FuncAnimation``
        fell back to matplotlib's default ``save_count`` (100 frames) when
        saving; *animate_fn* was then called with indices past the end of the
        data (IndexError) or the output was padded, for any sequence whose
        length differed from 100. All internal callers now pass the true count.
        """
        anim = animation.FuncAnimation(fig, animate_fn, frames=num_frames)
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        anim.save(output_path, writer="pillow", fps=fps or self.config.fps, dpi=dpi or self.config.dpi)
        plt.close(fig)

    def save_channel_animation(self, channel: torch.Tensor, output_path: Path) -> None:
        """Animate per-frame channel magnitude plus running mean magnitude/phase curves."""
        channel = self._ensure_complex(channel)
        magnitude = channel.abs().cpu()
        phase = torch.angle(channel).cpu()
        vmin, vmax = float(magnitude.min()), float(magnitude.max())
        temporal_mag = magnitude.mean(dim=(1, 2))
        # Unwrap so the averaged phase curve does not jump at +/- pi.
        temporal_phase = np.unwrap(phase.mean(dim=(1, 2)).numpy())

        fig, (ax_mag, ax_mag_line, ax_phase_line) = plt.subplots(1, 3, figsize=(20, 6))
        mag_img = ax_mag.imshow(magnitude[0].numpy(), cmap="gray_r", origin="upper", aspect="auto", vmin=vmin, vmax=vmax)
        ax_mag.set_xlabel("Subcarrier")
        ax_mag.set_ylabel("Antenna")
        fig.colorbar(mag_img, ax=ax_mag, label="Magnitude")

        mag_line, = ax_mag_line.plot([], [], "b-o", linewidth=2, markersize=4)
        ax_mag_line.set_xlim(0, magnitude.size(0) - 1)
        ax_mag_line.set_ylim(float(temporal_mag.min()) * 0.9, float(temporal_mag.max()) * 1.1)
        ax_mag_line.grid(True, alpha=0.3)
        ax_mag_line.set_xlabel("Frame")
        ax_mag_line.set_ylabel("Average Magnitude")

        phase_line, = ax_phase_line.plot([], [], "r-s", linewidth=2, markersize=4)
        ax_phase_line.set_xlim(0, magnitude.size(0) - 1)
        ax_phase_line.set_ylim(float(temporal_phase.min()) - 0.1, float(temporal_phase.max()) + 0.1)
        ax_phase_line.grid(True, alpha=0.3)
        ax_phase_line.set_xlabel("Frame")
        ax_phase_line.set_ylabel("Average Phase (rad)")

        def animate(idx: int):
            mag_img.set_array(magnitude[idx].numpy())
            ax_mag.set_title(f"Magnitude – Frame {idx}")
            xs = np.arange(idx + 1)
            mag_line.set_data(xs, temporal_mag[: idx + 1].numpy())
            phase_line.set_data(xs, temporal_phase[: idx + 1])
            return mag_img, mag_line, phase_line

        self._save_animation(fig, animate, output_path, magnitude.size(0))

    def save_angle_delay_animation(self, tensor: torch.Tensor, output_path: Path, keep_percentage: Optional[float] = None) -> None:
        """Animate angle-delay magnitude and phase maps plus their temporal means.

        *keep_percentage* is only used to annotate the plot titles; it does not
        truncate *tensor* here.
        """
        tensor = self._ensure_complex(tensor)
        magnitude = tensor.abs().cpu()
        phase = torch.angle(tensor).cpu()
        keep_suffix = "" if keep_percentage is None else f" (keep={keep_percentage * 100:.0f}%)"

        fig, axes = plt.subplots(2, 2, figsize=(18, 10))
        mag_ax, phase_ax, mag_line_ax, phase_line_ax = axes.flat
        mag_img = mag_ax.imshow(magnitude[0].numpy(), cmap="gray_r", origin="upper", aspect="auto")
        mag_ax.set_xlabel("Delay Bin")
        mag_ax.set_ylabel("Angle Bin")
        fig.colorbar(mag_img, ax=mag_ax, label="Magnitude")

        phase_img = phase_ax.imshow(phase[0].numpy(), cmap="twilight", origin="upper", aspect="auto", vmin=-math.pi, vmax=math.pi)
        phase_ax.set_xlabel("Delay Bin")
        phase_ax.set_ylabel("Angle Bin")
        fig.colorbar(phase_img, ax=phase_ax, label="Phase (rad)")

        temporal_mag = magnitude.mean(dim=(1, 2))
        temporal_phase = np.unwrap(phase.mean(dim=(1, 2)).numpy())
        mag_line, = mag_line_ax.plot([], [], "r-o", linewidth=2)
        phase_line, = phase_line_ax.plot([], [], "b-s", linewidth=2)

        for axis, label in ((mag_line_ax, "Average Magnitude"), (phase_line_ax, "Average Phase (rad)")):
            axis.set_xlabel("Frame")
            axis.set_ylabel(label)
            axis.set_xlim(0, tensor.size(0) - 1)
            axis.grid(True, alpha=0.3)

        def animate(idx: int):
            mag_img.set_array(magnitude[idx].numpy())
            phase_img.set_array(phase[idx].numpy())
            mag_ax.set_title(f"AD Magnitude – Frame {idx}{keep_suffix}")
            phase_ax.set_title(f"AD Phase – Frame {idx}{keep_suffix}")
            xs = np.arange(idx + 1)
            mag_line.set_data(xs, temporal_mag[: idx + 1].numpy())
            phase_line.set_data(xs, temporal_phase[: idx + 1])
            return mag_img, phase_img, mag_line, phase_line

        self._save_animation(fig, animate, output_path, tensor.size(0))

    def save_dominant_bin_animation(self, tensor: torch.Tensor, output_path: Path, threshold_ratio: float = 0.05) -> None:
        """Animate the magnitude map next to the count of bins above a power threshold.

        A bin is "dominant" when its magnitude exceeds ``threshold_ratio`` times
        the global maximum magnitude of *tensor*.
        """
        tensor = self._ensure_complex(tensor)
        magnitude = tensor.abs().cpu()
        threshold = float(magnitude.max()) * threshold_ratio
        dominant_counts = (magnitude > threshold).sum(dim=(1, 2)).numpy()

        fig, (heat_ax, line_ax) = plt.subplots(1, 2, figsize=(16, 6))
        heat_img = heat_ax.imshow(magnitude[0].numpy(), cmap="gray_r", origin="upper", aspect="auto")
        heat_ax.set_xlabel("Delay Bin")
        heat_ax.set_ylabel("Angle Bin")
        fig.colorbar(heat_img, ax=heat_ax, label="Magnitude")

        count_line, = line_ax.plot([], [], "r-s", linewidth=2)
        line_ax.set_xlabel("Frame")
        line_ax.set_ylabel("Dominant Bin Count")
        line_ax.set_xlim(0, tensor.size(0) - 1)
        line_ax.set_ylim(0, dominant_counts.max() * 1.1)
        line_ax.grid(True, alpha=0.3)

        def animate(idx: int):
            heat_img.set_array(magnitude[idx].numpy())
            heat_ax.set_title(f"Magnitude – Frame {idx}")
            xs = np.arange(idx + 1)
            count_line.set_data(xs, dominant_counts[: idx + 1])
            return heat_img, count_line

        self._save_animation(fig, animate, output_path, tensor.size(0))

    def save_bin_evolution_plot(self, tensor: torch.Tensor, output_path: Path) -> None:
        """Plot magnitude and unwrapped phase over time for the top-k strongest bins.

        The k = ``config.num_bins`` bins are selected by time-averaged magnitude.
        """
        tensor = self._ensure_complex(tensor)
        magnitude = tensor.abs()
        avg_mag = magnitude.mean(dim=0)
        flat_mag = avg_mag.flatten()
        # Guard against requesting more bins than exist in the map.
        k = min(self.config.num_bins, flat_mag.numel())
        _, indices = torch.topk(flat_mag, k)
        # Recover 2-D (angle, delay) coordinates from the flat top-k indices.
        angle_indices = (indices // tensor.size(-1)).tolist()
        delay_indices = (indices % tensor.size(-1)).tolist()

        time_axis = np.arange(tensor.size(0))
        fig, axes = plt.subplots(k, 2, figsize=(12, 3 * k), squeeze=False)
        for row in range(k):
            series = tensor[:, angle_indices[row], delay_indices[row]]
            mag_series = torch.abs(series).cpu().numpy()
            phase_series = np.unwrap(torch.angle(series).cpu().numpy())
            ax_mag, ax_phase = axes[row]
            ax_mag.plot(time_axis, mag_series, "b-", linewidth=2)
            ax_mag.set_title(f"Bin (angle={angle_indices[row]}, delay={delay_indices[row]}) – Magnitude")
            ax_mag.set_xlabel("Frame")
            ax_mag.set_ylabel("Magnitude")
            ax_mag.grid(True, alpha=0.3)
            ax_phase.plot(time_axis, phase_series, "r-", linewidth=2)
            ax_phase.set_title(f"Bin (angle={angle_indices[row]}, delay={delay_indices[row]}) – Phase")
            ax_phase.set_xlabel("Frame")
            ax_phase.set_ylabel("Phase (rad)")
            ax_phase.grid(True, alpha=0.3)
        fig.tight_layout()
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=self.config.dpi, bbox_inches="tight")
        plt.close(fig)
|
LWMTemporal/data/datasets.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
import math
|
| 5 |
+
import pickle
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from torch.utils.data import Dataset
|
| 12 |
+
|
| 13 |
+
from .angle_delay import AngleDelayConfig, AngleDelayProcessor
|
| 14 |
+
from ..models.lwm import ComplexPatchTokenizer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclasses.dataclass
class AngleDelayDatasetConfig:
    """Options controlling how an :class:`AngleDelaySequenceDataset` is built."""

    # Pickle file holding the raw payload (dict with "channel"/"pos"/"dt" keys,
    # or a bare channel array — see _build_samples).
    raw_path: Path
    # Fraction of delay bins kept after angle-delay truncation, in (0, 1].
    keep_percentage: float = 0.25
    # One of "per_sample_rms" or "global_rms"; other values skip normalization.
    normalize: str = "global_rms"
    # Directory for processed-sample caches; None disables caching entirely.
    cache_dir: Optional[Path] = Path("cache")
    use_cache: bool = True
    # When True, rebuild samples even if a cache file exists.
    overwrite_cache: bool = False
    # If set, complex AWGN at this SNR (dB) is added after loading/building.
    snr_db: Optional[float] = None
    # Seeds torch's RNG before noise injection for reproducibility.
    noise_seed: Optional[int] = None
    # If set, sequences longer than this are clipped along the time axis.
    max_time_steps: Optional[int] = None
    # (height, width) patch shape passed to the tokenizer; N and M are assumed
    # divisible by these — TODO confirm, __getitem__ uses integer division.
    patch_size: Tuple[int, int] = (1, 1)
    # Forwarded to ComplexPatchTokenizer; semantics defined by that class.
    phase_mode: str = "real_imag"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AngleDelaySequenceDataset(Dataset):
    """Angle-delay dataset with optional caching and metadata retention.

    Loads a pickled channel payload, projects each sequence into the truncated
    angle-delay domain, optionally normalizes and noise-corrupts the samples,
    and serves tokenized items via ``__getitem__``. Processed samples can be
    cached as a ``.pt`` file keyed on the relevant config fields.
    """

    def __init__(self, config: AngleDelayDatasetConfig, logger: Optional[Any] = None) -> None:
        """Build (or load from cache) all samples eagerly at construction time.

        Noise injection (``config.snr_db``) is applied *after* cache load/save,
        so cached samples are always clean.
        """
        super().__init__()
        self.config = config
        self.logger = logger
        self.tokenizer = ComplexPatchTokenizer(config.phase_mode)
        # Declared (not assigned) here; populated from cache or _build_samples.
        self.samples: List[torch.Tensor]
        self.avg_speed_mps: List[Optional[float]]
        cache_hit = False
        cache_path = self._cache_path() if config.use_cache and config.cache_dir is not None else None
        if cache_path and cache_path.exists() and not config.overwrite_cache:
            try:
                # NOTE(review): torch.load unpickles arbitrary objects — only
                # safe because the cache is written by this class; confirm the
                # cache dir is not attacker-writable.
                payload = torch.load(cache_path, map_location="cpu")
                if isinstance(payload, dict) and "samples" in payload:
                    self.samples = payload["samples"]
                    # Older caches may lack speed metadata; default to None per sample.
                    self.avg_speed_mps = payload.get("avg_speed_mps", [None] * len(self.samples))
                else:
                    # Legacy cache format: a bare list of sample tensors.
                    self.samples = payload
                    self.avg_speed_mps = [None] * len(self.samples)
                cache_hit = True
            except Exception:
                # Corrupt/incompatible cache: discard it and rebuild from raw data.
                cache_path.unlink(missing_ok=True)
                cache_hit = False
        if not cache_hit:
            self.samples, self.avg_speed_mps = self._build_samples()
            if cache_path is not None:
                cache_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({"samples": self.samples, "avg_speed_mps": self.avg_speed_mps}, cache_path)
        if self.config.snr_db is not None:
            self._apply_noise()

    def _cache_path(self) -> Path:
        """Return the cache file path encoding all sample-affecting config fields."""
        cfg = self.config
        name = cfg.raw_path.stem
        # Include patch_size and phase_mode in cache name to ensure cache invalidation
        # when these parameters change. Also add 'v2' to invalidate old caches with wrong normalization.
        ph, pw = cfg.patch_size
        cache_name = f"adseq_{name}_keep{int(cfg.keep_percentage * 100)}_{cfg.normalize}_p{ph}x{pw}_{cfg.phase_mode}_v2.pt"
        return cfg.cache_dir / cache_name  # type: ignore[operator]

    def _load_raw(self) -> Any:
        """Unpickle and return the raw payload from ``config.raw_path``.

        NOTE(review): pickle.load executes arbitrary code from the file —
        only use with trusted data files.
        """
        with self.config.raw_path.open("rb") as handle:
            return pickle.load(handle)

    def _normalize_sample(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalize a single sample by its own RMS."""
        # RMS is computed in float32 over real and imaginary parts jointly,
        # clamped to avoid dividing an all-zero sample by zero.
        rms = torch.sqrt((tensor.real.float() ** 2 + tensor.imag.float() ** 2).mean()).clamp_min(1e-8)
        return tensor / rms.to(tensor.dtype)

    def _estimate_speed(self, pos_meta: Any, dt_meta: Any, index: int) -> Optional[float]:
        """Estimate the mean user speed (distance units per dt) for sample *index*.

        Returns None whenever the metadata is missing or not in a usable shape.
        Speeds are averaged over the xy-plane displacement of each trajectory;
        units depend on the payload's position/dt conventions — TODO confirm
        these are metres and seconds as the attribute name suggests.
        """
        if pos_meta is None or dt_meta is None:
            return None

        def _extract_dt(meta: Any) -> Optional[float]:
            # dt may be per-sample (list/1-D array) or a single shared scalar.
            if isinstance(meta, (list, tuple)) and index < len(meta):
                return float(meta[index])
            if isinstance(meta, np.ndarray) and meta.ndim == 1 and index < meta.shape[0]:
                return float(meta[index])
            if isinstance(meta, (int, float)):
                return float(meta)
            return None

        dt = _extract_dt(dt_meta)
        if dt is None or dt <= 0:
            return None

        def _extract_positions(meta: Any) -> Optional[List[np.ndarray]]:
            # Accept either a per-sample list of trajectories or a stacked array;
            # keep only trajectories with >= 2 points and >= 2 coordinates.
            if isinstance(meta, (list, tuple)):
                if index >= len(meta):
                    return None
                candidate = meta[index]
            elif isinstance(meta, np.ndarray):
                if meta.ndim < 2 or index >= meta.shape[0]:
                    return None
                candidate = meta[index]
            else:
                return None
            if isinstance(candidate, list):
                trajs = [np.asarray(traj) for traj in candidate if isinstance(traj, (list, tuple, np.ndarray))]
            else:
                trajs = [np.asarray(candidate)]
            usable = [traj for traj in trajs if traj.ndim >= 2 and traj.shape[0] >= 2]
            return usable if usable else None

        trajectories = _extract_positions(pos_meta)
        if not trajectories:
            return None

        speeds: List[float] = []
        for traj in trajectories:
            # Per-step displacement in the xy-plane only (first two columns).
            diffs = np.linalg.norm(np.diff(traj[:, :2], axis=0), axis=1)
            if diffs.size == 0:
                continue
            velocity = diffs / dt
            # Drop NaN/inf steps before averaging.
            finite = velocity[np.isfinite(velocity)]
            if finite.size > 0:
                speeds.append(float(finite.mean()))
        if not speeds:
            return None
        return float(np.mean(speeds))

    def _build_samples(self) -> tuple[List[torch.Tensor], List[Optional[float]]]:
        """Build truncated angle-delay samples (and speed metadata) from raw data."""
        payload = self._load_raw()
        pos_meta = payload.get("pos") if isinstance(payload, dict) else None
        dt_meta = payload.get("dt") if isinstance(payload, dict) else None
        channel = payload["channel"] if isinstance(payload, dict) and "channel" in payload else payload
        channel_tensor = torch.as_tensor(channel, dtype=torch.complex64)
        # A single (T, N, M) sequence is promoted to a batch of one.
        if channel_tensor.ndim == 3:
            channel_tensor = channel_tensor.unsqueeze(0)
        if self.config.max_time_steps is not None and channel_tensor.size(1) > self.config.max_time_steps:
            channel_tensor = channel_tensor[:, : self.config.max_time_steps]
        processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=self.config.keep_percentage))
        samples: List[torch.Tensor] = []
        avg_speeds: List[Optional[float]] = []
        for idx, seq in enumerate(channel_tensor):
            ad = processor.forward(seq)
            truncated, _ = processor.truncate_delay_bins(ad)
            samples.append(truncated)
            avg_speeds.append(self._estimate_speed(pos_meta, dt_meta, idx))

        # Apply normalization after collecting all samples
        if self.config.normalize == "per_sample_rms":
            samples = [self._normalize_sample(s) for s in samples]
        elif self.config.normalize == "global_rms":
            # Compute global RMS across all samples
            total_sum_sq = 0.0
            total_count = 0
            for s in samples:
                s_real = s.real.float()
                s_imag = s.imag.float()
                total_sum_sq += (s_real ** 2 + s_imag ** 2).sum().item()
                total_count += s_real.numel()
            if total_count > 0:
                global_rms = math.sqrt(total_sum_sq / total_count)
                global_rms = max(global_rms, 1e-8)
                samples = [s / torch.tensor(global_rms, dtype=torch.float32).to(s.dtype) for s in samples]

        return samples, avg_speeds

    def _apply_noise(self) -> None:
        """Replace ``self.samples`` with AWGN-corrupted copies at ``config.snr_db``."""
        if self.config.noise_seed is not None:
            torch.manual_seed(int(self.config.noise_seed))
        noisy: List[torch.Tensor] = []
        snr_lin = 10.0 ** (float(self.config.snr_db) / 10.0)
        for sample in self.samples:
            real = sample.real.float()
            imag = sample.imag.float()
            power = (real.square() + imag.square()).mean().item()
            if power <= 0:
                # All-zero sample: noise power is undefined, keep as-is.
                noisy.append(sample)
                continue
            noise_var = power / snr_lin
            # Split the noise variance equally between real and imaginary parts.
            std = math.sqrt(noise_var / 2.0)
            noise_real = torch.randn_like(real) * std
            noise_imag = torch.randn_like(imag) * std
            noise = torch.complex(noise_real.to(sample.dtype), noise_imag.to(sample.dtype))
            noisy.append((sample + noise).to(sample.dtype))
        self.samples = noisy

    def __len__(self) -> int:
        """Number of sequences in the dataset."""
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the raw sequence, its tokenized form, mask, grid shape, and speed.

        ``shape`` is ``[T, H, W]`` where H = N // patch_h and W = M // patch_w;
        "avg_speed" is included only when speed metadata exists for the sample.
        """
        sample = self.samples[index]
        # Tokenizer expects a batch dimension; squeeze it back out afterwards.
        tokens, base_mask = self.tokenizer(sample.unsqueeze(0), self.config.patch_size)
        tokens = tokens.squeeze(0)
        base_mask = base_mask.squeeze(0)
        T, N, M = sample.shape
        ph, pw = self.config.patch_size
        H = N // ph
        W = M // pw
        shape = torch.tensor([T, H, W], dtype=torch.long)
        avg_speed = self.avg_speed_mps[index] if index < len(self.avg_speed_mps) else None
        payload: Dict[str, Any] = {
            "sequence": sample,
            "tokens": tokens,
            "base_mask": base_mask,
            "shape": shape,
        }
        if avg_speed is not None:
            payload["avg_speed"] = torch.tensor(avg_speed, dtype=torch.float32)
        return payload
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def load_adseq_dataset(
    data_path: str | Path,
    keep_percentage: float = 0.25,
    normalize: str = "global_rms",
    cache_dir: Optional[str | Path] = "cache",
    use_cache: bool = True,
    overwrite_cache: bool = False,
    logger: Optional[Any] = None,
    snr_db: Optional[float] = None,
    noise_seed: Optional[int] = None,
    max_time_steps: Optional[int] = None,
    patch_size: Tuple[int, int] = (1, 1),
    phase_mode: str = "real_imag",
) -> AngleDelaySequenceDataset:
    """Convenience constructor for :class:`AngleDelaySequenceDataset`.

    Builds an :class:`AngleDelayDatasetConfig` from keyword arguments and
    returns the constructed dataset.

    Args:
        data_path: Pickle file with the raw channel payload.
        keep_percentage: Fraction of delay bins kept after truncation.
        normalize: Normalization mode ("global_rms" or "per_sample_rms").
        cache_dir: Cache directory, or None to disable caching.
        use_cache / overwrite_cache: Cache read/rebuild behaviour.
        logger: Optional logger forwarded to the dataset.
        snr_db: Optional SNR (dB) for additive complex noise.
        noise_seed: Optional torch seed for reproducible noise.
        max_time_steps: Optional clip length along the time axis.
        patch_size: Tokenizer patch shape. Previously this loader could not
            set it even though the config supports it; exposed here with the
            config's default so existing calls are unaffected.
        phase_mode: Tokenizer phase mode, likewise newly exposed with the
            config's default.

    Returns:
        A fully built :class:`AngleDelaySequenceDataset`.
    """
    cfg = AngleDelayDatasetConfig(
        raw_path=Path(data_path),
        keep_percentage=keep_percentage,
        normalize=normalize,
        cache_dir=None if cache_dir is None else Path(cache_dir),
        use_cache=use_cache,
        overwrite_cache=overwrite_cache,
        snr_db=snr_db,
        noise_seed=noise_seed,
        max_time_steps=max_time_steps,
        patch_size=patch_size,
        phase_mode=phase_mode,
    )
    return AngleDelaySequenceDataset(cfg, logger=logger)
|
LWMTemporal/data/deepmimo_adapter.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np # type: ignore[import]
|
| 9 |
+
|
| 10 |
+
_LOGGER = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
try: # pragma: no cover - optional dependency
|
| 13 |
+
import deepmimo # type: ignore[import]
|
| 14 |
+
from deepmimo import config as deepmimo_config # type: ignore[import]
|
| 15 |
+
|
| 16 |
+
_HAS_DEEPMIMO_V4 = True
|
| 17 |
+
except Exception: # pragma: no cover - DeepMIMO v4 not installed
|
| 18 |
+
deepmimo = None # type: ignore[assignment]
|
| 19 |
+
deepmimo_config = None # type: ignore[assignment]
|
| 20 |
+
_HAS_DEEPMIMO_V4 = False
|
| 21 |
+
|
| 22 |
+
try: # pragma: no cover - legacy fallback
|
| 23 |
+
from input_preprocess import DeepMIMO_data_gen as _legacy_data_gen # type: ignore[import]
|
| 24 |
+
except Exception: # pragma: no cover - legacy loader unavailable
|
| 25 |
+
_legacy_data_gen = None
|
| 26 |
+
|
| 27 |
+
ArrayLike = Union[np.ndarray, "np.typing.NDArray[np.floating[Any]]"]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class _PathTable:
    """Columnar storage of per-user DeepMIMO path parameters.

    Path arrays are indexed as ``[user, path]``; ``num_paths`` and
    ``los_user`` are per-user vectors and ``locations`` holds per-user
    coordinates (reshaped to 2-D by the loader).
    """

    power: np.ndarray  # per-path received power
    phase: np.ndarray  # per-path phase
    delay: np.ndarray  # per-path time of arrival
    aoa_az: np.ndarray  # angle of arrival, azimuth
    aoa_el: np.ndarray  # angle of arrival, elevation
    aod_az: np.ndarray  # angle of departure, azimuth
    aod_el: np.ndarray  # angle of departure, elevation
    interactions: np.ndarray  # interaction codes per path (0 == LoS; NaN == padding)
    num_paths: np.ndarray  # valid path count per user
    los_user: np.ndarray  # per-user LoS indicator
    locations: np.ndarray  # per-user receiver positions
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class _LazyPathAccessor:
    """Lazy view over per-user path dictionaries compatible with v3 interface."""

    def __init__(self, data: _PathTable) -> None:
        # Keep only a reference; per-user dicts are materialized on demand.
        self._data = data

    def __len__(self) -> int:
        # One entry per user (first axis of the per-user path-count vector).
        return int(self._data.num_paths.shape[0])

    def __getitem__(self, index: Union[int, slice, Sequence[int]]) -> Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]]:
        # Slices and index sequences recurse into single-user lookups.
        if isinstance(index, slice):
            return [self[i] for i in range(*index.indices(len(self)))]
        if isinstance(index, Sequence) and not isinstance(index, (str, bytes)):
            return [self[int(i)] for i in index]
        idx = int(index)
        count = int(self._data.num_paths[idx])
        if count <= 0:
            # No valid paths: return the same key set with empty arrays so
            # downstream consumers need no special-casing.
            empty = np.empty((0,), dtype=np.float32)
            return {
                "num_paths": 0,
                "DoD_theta": empty,
                "DoD_phi": empty,
                "DoA_theta": empty,
                "DoA_phi": empty,
                "phase": empty,
                "ToA": empty,
                "power": empty,
                "LoS": np.empty((0,), dtype=np.int32),
            }
        sl = slice(0, count)
        interactions = np.asarray(self._data.interactions[idx, sl])
        # Interaction code 0 marks a LoS path; NaN entries (padding) map to 0.
        los_per_path = np.where(np.isnan(interactions), 0, (interactions == 0).astype(np.int32))
        # Key names follow the DeepMIMO v3 per-user dictionary convention
        # (theta == elevation fields, phi == azimuth fields).
        return {
            "num_paths": count,
            "DoD_theta": np.asarray(self._data.aod_el[idx, sl]),
            "DoD_phi": np.asarray(self._data.aod_az[idx, sl]),
            "DoA_theta": np.asarray(self._data.aoa_el[idx, sl]),
            "DoA_phi": np.asarray(self._data.aoa_az[idx, sl]),
            "phase": np.asarray(self._data.phase[idx, sl]),
            "ToA": np.asarray(self._data.delay[idx, sl]),
            "power": np.asarray(self._data.power[idx, sl]),
            "LoS": los_per_path.astype(np.int32),
        }
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _cast(array: ArrayLike, dtype: np.dtype[Any]) -> np.ndarray:
|
| 91 |
+
arr = np.asarray(array)
|
| 92 |
+
if arr.dtype == dtype:
|
| 93 |
+
return arr
|
| 94 |
+
return arr.astype(dtype, copy=True)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _load_v4_dataset(
    scenario: str,
    *,
    scenarios_dir: Optional[Path],
    load_params: Optional[Dict[str, Any]],
    max_paths: Optional[int],
    array_dtype: np.dtype[Any],
    logger: Optional[logging.Logger],
) -> Dict[str, Any]:
    """Load a scenario through DeepMIMO v4 and repackage it in the v3-style layout.

    Returns a dict whose ``"user"`` section (``paths``/``LoS``/``location``)
    is backed by a columnar :class:`_PathTable` and a :class:`_LazyPathAccessor`.

    Raises:
        RuntimeError: If the ``deepmimo`` package could not be imported.
    """
    if not _HAS_DEEPMIMO_V4:
        raise RuntimeError("DeepMIMO v4 package is not available in the current environment")

    if scenarios_dir is not None:
        # Point DeepMIMO at a custom scenarios folder before loading.
        deepmimo_config.set("scenarios_folder", str(scenarios_dir))  # type: ignore[attr-defined]

    params = dict(load_params or {})
    if max_paths is not None:
        # setdefault: an explicit load_params["max_paths"] wins over the argument.
        params.setdefault("max_paths", int(max_paths))

    dataset = deepmimo.load(scenario, **params)  # type: ignore[call-arg]
    logger = logger or _LOGGER
    logger.info(
        "Loaded DeepMIMO v4 scenario '%s' with %s users and %s max paths",  # pragma: no cover - logging
        scenario,
        getattr(dataset, "n_ue", "unknown"),
        params.get("max_paths", "default"),
    )

    # If the scenario has several transmitters (leading axis of num_paths),
    # select the TX whose users collectively have the most paths.
    num_paths_raw = np.asarray(dataset.num_paths)
    tx_axis: Optional[int] = None
    if num_paths_raw.ndim > 1 and num_paths_raw.shape[0] > 1:
        axes = tuple(range(1, num_paths_raw.ndim))
        scores = num_paths_raw.sum(axis=axes)
        tx_axis = int(np.argmax(scores))

    def _select_tx(arr: Any, dtype: Optional[np.dtype[Any]] = None) -> np.ndarray:
        # Normalize a dataset attribute: densify object arrays, slice out the
        # selected TX when the leading axis matches, and cast if requested.
        out = np.asarray(arr)
        if out.dtype == object:
            out = np.stack([np.asarray(v) for v in out], axis=0)
        if tx_axis is not None and out.ndim >= 1 and out.shape[0] == num_paths_raw.shape[0]:
            out = out[tx_axis]
        if dtype is not None:
            out = out.astype(dtype, copy=False)
        return out

    num_paths = _select_tx(num_paths_raw, dtype=np.int32).reshape(-1)
    power = _select_tx(dataset.power, dtype=array_dtype)
    phase = _select_tx(dataset.phase, dtype=array_dtype)
    delay = _select_tx(dataset.delay, dtype=array_dtype)
    aoa_az = _select_tx(dataset.aoa_az, dtype=array_dtype)
    aoa_el = _select_tx(dataset.aoa_el, dtype=array_dtype)
    aod_az = _select_tx(dataset.aod_az, dtype=array_dtype)
    aod_el = _select_tx(dataset.aod_el, dtype=array_dtype)
    interactions = _select_tx(dataset.inter, dtype=array_dtype)
    los_raw = getattr(dataset, "los", None)
    if los_raw is None:
        # No LoS info exposed by this dataset version: default to all-NLoS zeros.
        los_selected = np.zeros_like(num_paths, dtype=np.int8)
    else:
        los_selected = _select_tx(los_raw, dtype=np.int8)
    locations = _select_tx(dataset.rx_pos, dtype=np.float32)
    # Normalize receiver positions to a 2-D (users, coords) array.
    if locations.ndim == 1:
        if locations.size % 3 == 0:
            locations = locations.reshape(-1, 3)
        else:
            locations = locations.reshape(-1, 1)
    if locations.ndim > 2:
        locations = locations.reshape(locations.shape[0], -1)

    path_table = _PathTable(
        power=power,
        phase=phase,
        delay=delay,
        aoa_az=aoa_az,
        aoa_el=aoa_el,
        aod_az=aod_az,
        aod_el=aod_el,
        interactions=interactions,
        num_paths=num_paths,
        los_user=los_selected.reshape(-1),
        locations=locations,
    )

    # Help GC release original dataset arrays early
    del dataset

    user_payload = {
        "paths": _LazyPathAccessor(path_table),
        "LoS": path_table.los_user,
        "location": path_table.locations,
    }

    return {
        "user": user_payload,
        "_path_data": path_table,
        "_source": "deepmimo_v4",
    }
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def load_deepmimo_user_data(
    scenario: str,
    *,
    scenarios_dir: Optional[Path] = None,
    load_params: Optional[Dict[str, Any]] = None,
    max_paths: Optional[int] = None,
    array_dtype: np.dtype[Any] = np.float32,
    logger: Optional[logging.Logger] = None,
) -> Dict[str, Any]:
    """Load DeepMIMO scenario data in a form compatible with legacy utilities.

    The returned dictionary mimics the structure produced by DeepMIMO v3's
    ``DeepMIMO_data_gen`` so downstream utilities (e.g., dynamic scenario
    generation) can operate without modification.

    Args:
        scenario: Name of the DeepMIMO scenario to load.
        scenarios_dir: Optional override for the DeepMIMO scenarios folder.
        load_params: Extra keyword arguments forwarded to ``deepmimo.load``.
        max_paths: Convenience cap on ray paths per user (only applied when
            ``load_params`` does not already set ``max_paths``).
        array_dtype: dtype used for the floating-point path arrays.
        logger: Logger for progress messages; defaults to the module logger.

    Returns:
        Dict with a ``"user"`` payload (``paths``/``LoS``/``location``) plus
        internal ``"_path_data"`` and ``"_source"`` entries.

    Raises:
        RuntimeError: When DeepMIMO v4 is unavailable. Note that even when
            the legacy ``DeepMIMO_data_gen`` is importable this function does
            not call it — the caller must invoke the legacy generator directly.
    """

    if _HAS_DEEPMIMO_V4:
        return _load_v4_dataset(
            scenario,
            scenarios_dir=scenarios_dir,
            load_params=load_params,
            max_paths=max_paths,
            array_dtype=array_dtype,
            logger=logger,
        )

    # Deliberate: even if the legacy generator imported successfully we raise,
    # because this adapter does not collect the integration parameters the
    # legacy interface needs. The message just differs to guide the user.
    if _legacy_data_gen is not None:
        raise RuntimeError(
            "DeepMIMO v4 is not installed. The repository still includes the legacy "
            "DeepMIMO_data_gen interface, but integration parameters must be provided "
            "explicitly. Please migrate to the official DeepMIMO package or invoke "
            "DeepMIMO_data_gen directly from your own tooling."
        )

    raise RuntimeError(
        "Neither DeepMIMO v4 nor the legacy DeepMIMO_data_gen function is available. "
        "Please install the DeepMIMO package or provide the legacy generator."
    )
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
__all__ = ["load_deepmimo_user_data"]
|
LWMTemporal/data/scenario_generation.py
ADDED
|
@@ -0,0 +1,1161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
import logging
|
| 5 |
+
import warnings
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 9 |
+
|
| 10 |
+
import matplotlib.pyplot as plt # type: ignore[import]
|
| 11 |
+
from matplotlib.lines import Line2D # type: ignore[attr-defined]
|
| 12 |
+
import networkx as nx # type: ignore[import]
|
| 13 |
+
import numpy as np # type: ignore[import]
|
| 14 |
+
from scipy.spatial import KDTree # type: ignore[import]
|
| 15 |
+
|
| 16 |
+
from .deepmimo_adapter import load_deepmimo_user_data
|
| 17 |
+
|
| 18 |
+
try: # pragma: no cover - DeepMIMO may not be installed in all environments
|
| 19 |
+
import DeepMIMOv3.consts as c # type: ignore
|
| 20 |
+
except Exception: # pragma: no cover - maintain compatibility when DeepMIMO is missing
|
| 21 |
+
class _C: # noqa: D401 - minimal stub to satisfy typing
|
| 22 |
+
"""Fallback constants when DeepMIMOv3 is not available."""
|
| 23 |
+
|
| 24 |
+
c = _C()
|
| 25 |
+
c.PARAMSET_OFDM = "OFDM"
|
| 26 |
+
c.PARAMSET_OFDM_BW = "bandwidth"
|
| 27 |
+
c.PARAMSET_OFDM_BW_MULT = 1e9
|
| 28 |
+
c.PARAMSET_OFDM_SC_SAMP = "selected_subcarriers"
|
| 29 |
+
c.PARAMSET_OFDM_SC_NUM = "subcarriers"
|
| 30 |
+
c.PARAMSET_OFDM_LPF = "LPF"
|
| 31 |
+
c.PARAMSET_FDTD = "FDTD"
|
| 32 |
+
c.PARAMSET_ANT_SHAPE = "shape"
|
| 33 |
+
c.PARAMSET_ANT_SPACING = "spacing"
|
| 34 |
+
c.PARAMSET_ANT_ROTATION = "rotation"
|
| 35 |
+
c.PARAMSET_ANT_FOV = "FoV"
|
| 36 |
+
c.PARAMSET_ANT_RAD_PAT = "radiation_pattern"
|
| 37 |
+
c.OUT_PATH_NUM = "num_paths"
|
| 38 |
+
c.OUT_PATH_DOD_THETA = "DoD_theta"
|
| 39 |
+
c.OUT_PATH_DOD_PHI = "DoD_phi"
|
| 40 |
+
c.OUT_PATH_DOA_THETA = "DoA_theta"
|
| 41 |
+
c.OUT_PATH_DOA_PHI = "DoA_phi"
|
| 42 |
+
c.OUT_PATH_PHASE = "phase"
|
| 43 |
+
c.OUT_PATH_TOA = "ToA"
|
| 44 |
+
c.OUT_PATH_RX_POW = "power"
|
| 45 |
+
c.OUT_PATH_DOP_VEL = "Doppler_vel"
|
| 46 |
+
c.OUT_PATH_DOP_ACC = "Doppler_acc"
|
| 47 |
+
c.PARAMSET_DOPPLER_EN = "Doppler"
|
| 48 |
+
c.PARAMSET_SCENARIO_PARAMS = "scenario_params"
|
| 49 |
+
c.PARAMSET_SCENARIO_PARAMS_DOPPLER_EN = "Doppler_enabled"
|
| 50 |
+
c.PARAMSET_SCENARIO_PARAMS_CF = "carrier_freq"
|
| 51 |
+
c.LIGHTSPEED = 3e8
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
_LOGGER = logging.getLogger(__name__)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dataclass
class AntennaArrayConfig:
    """Geometry of the transmit and receive antenna panels used for channel generation."""

    # Field order and defaults are part of the public constructor signature.
    tx_horizontal: int = 32
    tx_vertical: int = 1
    rx_horizontal: int = 1
    rx_vertical: int = 1
    subcarriers: int = 32
    spacing: float = 0.5
    tx_rotation: Tuple[float, float, float] = (0.0, 0.0, -135.0)
    rx_rotation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
    field_of_view: Tuple[float, float] = (360.0, 180.0)

    def total_tx_elements(self) -> int:
        """Total number of transmit elements across the panel."""
        element_count = self.tx_horizontal * self.tx_vertical
        return int(element_count)

    def tx_shape(self) -> np.ndarray:
        """Transmit panel shape as a ``[horizontal, vertical, 1]`` array."""
        return np.asarray((self.tx_horizontal, self.tx_vertical, 1))

    def rx_shape(self) -> np.ndarray:
        """Receive panel shape as a ``[horizontal, vertical, 1]`` array."""
        return np.asarray((self.rx_horizontal, self.rx_vertical, 1))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@dataclass
class TrafficConfig:
    """Parameters controlling vehicle and pedestrian traffic synthesis."""

    num_vehicles: int = 50  # number of vehicle trajectories to synthesize
    num_pedestrians: int = 10  # number of pedestrian trajectories to synthesize
    # Vehicle speed bounds in m/s (defaults are 5 and 60 km/h converted via /3.6).
    vehicle_speed_range: Tuple[float, float] = (5 / 3.6, 60 / 3.6)
    pedestrian_speed_range: Tuple[float, float] = (0.5, 2.0)  # m/s
    turn_probability: float = 0.1  # chance of turning at an intersection
    max_attempts: int = 300  # retry budget when sampling trajectories
    # Std-dev (radians) of the per-step heading perturbation for pedestrians.
    pedestrian_angle_std: float = 0.1
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass
class GridConfig:
    """Describe the road grid geometry to be inferred from DeepMIMO positions."""

    road_width: float = 6.0  # total lane band width in meters (half-width per side)
    road_center_spacing: float = 25.0  # spacing between road center lines, meters
    step_size: float = 1.0  # per-step movement distance on the road graph
    # When True, the step size is presumably derived from the positions
    # (e.g. via infer_grid_step) instead of step_size — confirm downstream.
    auto_step_size: bool = True
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@dataclass
class ScenarioSamplingConfig:
    """Temporal sampling configuration for the dynamic scenario."""

    time_steps: int = 20  # number of sampled time steps per sequence
    # Optional length of the continuous rollout; None presumably falls back
    # to time_steps — confirm against the generator.
    continuous_length: Optional[int] = None
    sample_dt: float = 1e-3  # sampling interval in seconds
    # NOTE(review): assumed to toggle one continuous trajectory vs. independent
    # snapshots — confirm downstream.
    continuous_mode: bool = True
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@dataclass
class ScenarioGenerationConfig:
    """All inputs required to synthesize a dynamic DeepMIMO scenario."""

    scenario: str  # DeepMIMO scenario name to load
    antenna: AntennaArrayConfig = field(default_factory=AntennaArrayConfig)  # array geometry
    sampling: ScenarioSamplingConfig = field(default_factory=ScenarioSamplingConfig)  # temporal sampling
    grid: GridConfig = field(default_factory=GridConfig)  # road grid geometry
    traffic: TrafficConfig = field(default_factory=TrafficConfig)  # traffic synthesis knobs
    carrier_frequency_hz: float = 3.5e9  # carrier frequency in Hz
    output_dir: Path = Path("examples/data")  # destination for exported data
    full_output_dir: Path = Path("examples/full_data")  # destination for full-size data
    figures_dir: Optional[Path] = Path("figs")  # where diagnostic figures are saved (None disables)
    export_environment_plot: bool = True  # whether to render the environment plot
    rng_seed: Optional[int] = None  # seed for reproducibility; None means nondeterministic
    scenarios_dir: Optional[Path] = None  # override for the DeepMIMO scenarios folder
    deepmimo_max_paths: Optional[int] = 6  # forwarded as max_paths to the DeepMIMO loader
    deepmimo_load_params: Optional[Dict[str, Any]] = None  # extra kwargs for deepmimo.load
    deepmimo_array_dtype: np.dtype[Any] = np.float32  # dtype for loaded path arrays
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@dataclass
class ScenarioGenerationResult:
    """Container for the generated dynamic scenario payload.

    Uses the same ``@dataclass`` decorator as the sibling config classes
    (previously ``@dataclasses.dataclass``) for consistency; behavior is
    identical.
    """

    payload: Dict[str, Any]  # generated scenario data
    output_path: Path  # destination path of the exported dataset
    full_output_path: Path  # destination path of the full dataset — presumably unreduced; confirm
    generated: bool  # whether generation actually ran (vs. reusing existing output — assumed)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# ---------------------------------------------------------------------------
|
| 146 |
+
# Geometry helpers
|
| 147 |
+
# ---------------------------------------------------------------------------
|
| 148 |
+
|
| 149 |
+
def infer_grid_step(
    positions: np.ndarray,
    max_samples: int = 20000,
    rng: Optional[np.random.Generator] = None,
) -> float:
    """Infer the base grid spacing from the DeepMIMO user positions.

    The spacing is estimated from nearest-neighbour distances between the
    unique (x, y) locations. When a clear modal minimum distance exists it is
    returned directly; otherwise a robust median over the lower percentiles
    of the nearest-neighbour distances is used.

    Args:
        positions: ``(N, >=2)`` array of positions; only x/y columns are used.
        max_samples: Cap on the number of unique points fed to the KD-tree.
        rng: Optional random generator used when subsampling, so callers can
            make the estimate reproducible (new, backward-compatible
            parameter; previously an unseeded generator was always created).

    Returns:
        Estimated grid step, or 1.0 when fewer than two distinct points exist.
    """

    unique_xy = np.unique(positions[:, :2], axis=0)
    if unique_xy.shape[0] < 2:
        return 1.0
    if unique_xy.shape[0] > max_samples:
        # Subsample for tractability; honour a caller-provided generator for
        # reproducibility instead of always drawing fresh entropy.
        if rng is None:
            rng = np.random.default_rng()
        unique_xy = unique_xy[rng.choice(unique_xy.shape[0], size=max_samples, replace=False)]
    tree = KDTree(unique_xy)
    distances, _ = tree.query(unique_xy, k=2)
    nearest = distances[:, 1]  # k=2: column 0 is the point itself (distance 0)
    nearest = nearest[nearest > 0]
    if nearest.size == 0:
        return 1.0
    min_distance = float(np.min(nearest))
    tolerance = min_distance * 0.1
    count_at_min = np.sum(np.abs(nearest - min_distance) < tolerance)
    # If at least 10% of points sit at the minimum spacing, accept it as the grid step.
    if count_at_min >= 0.1 * nearest.size:
        return min_distance
    # Otherwise fall back to a robust central estimate of the smaller spacings.
    lo, hi = np.percentile(nearest, [5, 30])
    mask = (nearest >= lo) & (nearest <= hi)
    candidate = nearest[mask]
    return float(np.median(candidate if candidate.size else nearest))
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def filter_road_positions(valid_positions: np.ndarray, road_width: float, spacing: float) -> Tuple[np.ndarray, Dict[Tuple[float, float, float], Tuple[Tuple[int, int], str]]]:
    """Classify positions lying on the road grid into lanes.

    A point belongs to a road when its offset from the nearest grid line
    (grid lines every ``spacing`` meters) is within half the road width.
    Returns the kept positions plus, per position, a ``(direction, lane_type)``
    tuple where ``lane_type`` is ``"vertical"``, ``"horizontal"`` or
    ``"intersection"``.
    """
    kept: List[np.ndarray] = []
    lanes: Dict[Tuple[float, float, float], Tuple[Tuple[int, int], str]] = {}
    half_width = road_width / 2.0
    for point in valid_positions:
        px, py, _ = point
        # Offset relative to the nearest road-center grid line.
        center_x = round(px / spacing) * spacing
        center_y = round(py / spacing) * spacing
        off_x = px - center_x
        off_y = py - center_y
        in_vertical_band = abs(off_x) < half_width
        in_horizontal_band = abs(off_y) < half_width
        if not (in_vertical_band or in_horizontal_band):
            continue  # off-road point
        if in_vertical_band and in_horizontal_band:
            # Both bands: an intersection; default heading follows the vertical lane rule.
            entry = ((0, 1) if off_x >= 0 else (0, -1), "intersection")
        elif in_vertical_band:
            # Right-hand side of a vertical road travels north, left side south.
            entry = ((0, 1) if off_x >= 0 else (0, -1), "vertical")
        else:
            # Lower side of a horizontal road travels east, upper side west.
            entry = ((1, 0) if off_y < 0 else (-1, 0), "horizontal")
        lanes[tuple(point)] = entry
        kept.append(point)
    return np.asarray(kept), lanes
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def create_grid_road_network(road_positions: np.ndarray, lane_info: Dict[Tuple[float, float, float], Tuple[Tuple[int, int], str]], step_size: float) -> nx.DiGraph:
    """Build a directed road graph over the filtered road positions.

    Nodes are road positions (indexed by row) annotated with lane direction
    and type; edges connect positions exactly one ``step_size`` apart, and
    respect lane direction except at intersections, where all four axis
    directions are allowed.
    """
    graph: nx.DiGraph = nx.DiGraph()
    pos_dict = {tuple(pos): idx for idx, pos in enumerate(road_positions)}
    for pos, idx in pos_dict.items():
        # Only positions classified by filter_road_positions become nodes.
        if pos not in lane_info:
            continue
        direction, lane_type = lane_info[pos]
        graph.add_node(idx, pos=np.array(pos), direction=direction, lane_type=lane_type)
    tree = KDTree(road_positions)
    for idx, pos in enumerate(road_positions):
        if idx not in graph.nodes:
            continue
        # Candidate neighbors within one step (+0.1 m slack for float noise).
        neighbors = tree.query_ball_point(pos, r=step_size + 0.1)
        for nb in neighbors:
            if nb == idx or nb not in graph.nodes:
                continue
            target = road_positions[nb]
            distance = np.linalg.norm(pos - target)
            # Keep only neighbors at (approximately) exactly one step.
            if not np.isclose(distance, step_size, atol=0.1):
                continue
            move_dir = (int(np.sign(target[0] - pos[0])), int(np.sign(target[1] - pos[1])))
            lane_type = graph.nodes[idx].get("lane_type", "vertical")
            if lane_type == "intersection":
                # Intersections allow turning onto any of the four axis directions.
                if move_dir in [(0, 1), (0, -1), (1, 0), (-1, 0)]:
                    graph.add_edge(idx, nb, weight=distance)
            elif move_dir == graph.nodes[idx]["direction"]:
                # Plain lanes only advance along the lane's travel direction.
                graph.add_edge(idx, nb, weight=distance)
    return graph
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _fallback_anywalk(road_positions: np.ndarray, step_size: float, length: int, start_idx: Optional[int] = None) -> np.ndarray:
    """Unconstrained random walk over road positions, ignoring lane direction.

    Used as a fallback when the directed graph walk cannot reach the desired
    length. Steps prefer neighbors exactly one ``step_size`` away; when none
    exist, the nearest point within an enlarged radius is taken — note this
    radius search includes the current point itself, so the walk may stay in
    place (consistent with the stay-put padding convention used elsewhere).
    Uses the global NumPy RNG, so results are not seed-isolated.
    """
    if start_idx is None:
        start_idx = np.random.randint(0, len(road_positions))
    pos = road_positions[start_idx]
    trajectory = [pos]
    tree = KDTree(road_positions)
    for _ in range(length - 1):
        # Neighbors within one step (10% slack).
        idxs = tree.query_ball_point(pos, r=step_size * 1.1)
        candidates = []
        for idx in idxs:
            if np.allclose(road_positions[idx], pos):
                continue  # skip the current position itself
            d = np.linalg.norm(road_positions[idx] - pos)
            if np.isclose(d, step_size, atol=0.15 * max(1.0, step_size)):
                candidates.append(idx)
        if not candidates:
            # No exact-step neighbor: widen the search and take the closest point.
            idxs = tree.query_ball_point(pos, r=max(1.5 * step_size, step_size + 0.5))
            if not idxs:
                break
            candidate_idx = min(idxs, key=lambda j: np.linalg.norm(road_positions[j] - pos))
        else:
            candidate_idx = np.random.choice(candidates)
        pos = road_positions[candidate_idx]
        trajectory.append(pos)
    return np.asarray(trajectory)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def generate_smooth_grid_trajectory(graph: nx.DiGraph, road_positions: np.ndarray, turn_probability: float, sequence_length: int = 12, start_node: Optional[int] = None) -> np.ndarray:
    """Walk the directed road graph, preferring straight motion with rare turns.

    At intersections the walk continues straight with probability
    ``1 - turn_probability`` (falling back to whatever neighbors exist);
    on plain lanes it only ever moves forward. A stuck walker pads the
    trajectory by repeating its current position until ``sequence_length``
    is reached. Uses the global NumPy RNG.
    """
    if start_node is None:
        nodes = list(graph.nodes)
        if not nodes:
            return np.empty((0, 3))
        start_node = np.random.choice(nodes)
    trajectory = [road_positions[start_node]]
    current = start_node
    previous: Optional[int] = None
    for _ in range(sequence_length - 1):
        if current not in graph:
            break
        neighbors = list(graph.neighbors(current))
        # Never immediately backtrack.
        if previous in neighbors:
            neighbors.remove(previous)
        if not neighbors:
            # Dead end: pad in place to the requested length.
            remaining = sequence_length - len(trajectory)
            if remaining > 0:
                trajectory.extend([road_positions[current]] * remaining)
            break
        node_data = graph.nodes[current]
        lane_type = node_data.get("lane_type", "vertical")
        if lane_type == "intersection" and previous is not None:
            # At an intersection "forward" means continuing the incoming heading.
            prev_pos = np.asarray(graph.nodes[previous]["pos"])
            curr_pos = np.asarray(node_data["pos"])
            incoming = (
                int(np.sign(curr_pos[0] - prev_pos[0])),
                int(np.sign(curr_pos[1] - prev_pos[1])),
            )
            default_direction = incoming
        else:
            default_direction = node_data.get("direction", None)
        position = np.asarray(node_data["pos"])
        # Partition neighbors into straight-ahead vs. turning moves.
        forward_neighbors: List[Tuple[int, float]] = []
        turn_neighbors: List[Tuple[int, float]] = []
        for nb in neighbors:
            neighbor_pos = np.asarray(graph.nodes[nb]["pos"])
            move_dir = (
                int(np.sign(neighbor_pos[0] - position[0])),
                int(np.sign(neighbor_pos[1] - position[1])),
            )
            distance = float(np.linalg.norm(neighbor_pos - position))
            if move_dir == default_direction:
                forward_neighbors.append((nb, distance))
            else:
                turn_neighbors.append((nb, distance))
        if lane_type == "intersection":
            # Randomly decide between straight and turning, then fall back to
            # whichever option is actually available.
            r = np.random.rand()
            if forward_neighbors and r > turn_probability:
                nxt = min(forward_neighbors, key=lambda item: item[1])[0]
            elif turn_neighbors and r < turn_probability:
                nxt = min(turn_neighbors, key=lambda item: item[1])[0]
            elif forward_neighbors:
                nxt = min(forward_neighbors, key=lambda item: item[1])[0]
            elif turn_neighbors:
                nxt = min(turn_neighbors, key=lambda item: item[1])[0]
            else:
                remaining = sequence_length - len(trajectory)
                if remaining > 0:
                    trajectory.extend([road_positions[current]] * remaining)
                break
        else:
            # Plain lane: forward only; otherwise pad in place and stop.
            if forward_neighbors:
                nxt = min(forward_neighbors, key=lambda item: item[1])[0]
            else:
                remaining = sequence_length - len(trajectory)
                if remaining > 0:
                    trajectory.extend([road_positions[current]] * remaining)
                break
        trajectory.append(road_positions[nxt])
        previous, current = current, nxt
    return np.asarray(trajectory)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def generate_n_smooth_grid_trajectories(graph: nx.DiGraph, road_positions: np.ndarray, n: int, sequence_length: int = 12, turn_probability: float = 0.15, max_attempts: int = 2000, step_size: float = 1.0) -> List[np.ndarray]:
    """Sample up to ``n`` vehicle trajectories on the road graph.

    Each attempt seeds a walk at the road node nearest to a uniformly random
    point inside the road bounding box. Walks that terminate before reaching
    ``sequence_length`` fall back to an unconstrained "any-walk" so that
    full-length sequences can still be produced.

    Returns a list of at most ``n`` arrays of shape (sequence_length, 3);
    a warning is logged when fewer than ``n`` trajectories were obtained.
    """
    trajectories: List[np.ndarray] = []
    attempts = 0
    # Overall budget: on average `max_attempts` tries per requested trajectory.
    hard_cap = n * max_attempts
    tree = KDTree(road_positions)
    min_x, min_y = np.min(road_positions[:, 0]), np.min(road_positions[:, 1])
    max_x, max_y = np.max(road_positions[:, 0]), np.max(road_positions[:, 1])
    while len(trajectories) < n and attempts < hard_cap:
        # Random seed point (z fixed at 0) snapped to the nearest road node.
        rand = [np.random.uniform(min_x, max_x), np.random.uniform(min_y, max_y), 0]
        _, start_idx = tree.query(rand)
        traj = generate_smooth_grid_trajectory(graph, road_positions, turn_probability, sequence_length, start_node=start_idx)
        if traj.shape[0] < sequence_length:
            # Graph walk stalled early (e.g. dead end) -- try a free walk instead.
            traj = _fallback_anywalk(road_positions, step_size, sequence_length, start_idx=start_idx)
        if traj.shape[0] >= sequence_length:
            trajectories.append(traj[:sequence_length])
        attempts += 1
    if len(trajectories) < n:
        _LOGGER.warning("Only generated %d/%d vehicle trajectories", len(trajectories), n)
    return trajectories
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def generate_pedestrian_trajectory(valid_positions: np.ndarray, sequence_length: int = 10, step_size: float = 2.5, angle_std: float = 0.1, start: Optional[np.ndarray] = None) -> np.ndarray:
    """Random-heading pedestrian walk snapped onto ``valid_positions``.

    A heading is drawn uniformly once, then perturbed by Gaussian noise of
    std ``angle_std`` at each step; every proposed step is projected back to
    the nearest valid grid point via a KD-tree query.

    Returns an array of shape (sequence_length, 3).
    """
    tree = KDTree(valid_positions)
    if start is None:
        start = valid_positions[np.random.choice(len(valid_positions))]
    points = [start]
    heading = np.random.uniform(0, 2 * np.pi)
    position = start
    for _ in range(sequence_length - 1):
        heading += np.random.normal(0, angle_std)
        step = np.array([step_size * np.cos(heading), step_size * np.sin(heading), 0])
        _, nearest = tree.query(position + step)
        position = valid_positions[nearest]
        points.append(position)
    return np.asarray(points)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def generate_n_pedestrian_trajectories(valid_positions: np.ndarray, n: int, sequence_length: int = 10, step_size: float = 2.5, angle_std: float = 0.1) -> List[np.ndarray]:
    """Generate ``n`` independent pedestrian walks (see
    :func:`generate_pedestrian_trajectory` for the per-walk model)."""
    walks: List[np.ndarray] = []
    for _ in range(n):
        walks.append(generate_pedestrian_trajectory(valid_positions, sequence_length, step_size, angle_std))
    return walks
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def get_trajectory_indices(trajectories: Sequence[np.ndarray], pos_total: np.ndarray) -> List[List[int]]:
    """Map each trajectory waypoint to its row index in ``pos_total``.

    Positions are keyed by their first three coordinates cast to float, so
    only exact coordinate matches resolve; waypoints absent from
    ``pos_total`` map to -1.
    """

    def _key(value: Any) -> Tuple[float, ...]:
        arr = np.asarray(value)
        if arr.ndim == 0:
            return (float(arr),)
        return tuple(float(v) for v in arr.reshape(-1)[:3])

    lookup = {_key(row): row_idx for row_idx, row in enumerate(pos_total)}
    return [[lookup.get(_key(point), -1) for point in traj] for traj in trajectories]
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def sample_continuous_along_polyline(
    traj_pos: np.ndarray,
    idxs: Sequence[int],
    speed: float,
    dt: float,
    n_samples: int,
    speed_profile: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, List[Tuple[int, int]], np.ndarray, np.ndarray]:
    """Resample a waypoint polyline at a fixed time step.

    Starting from the first waypoint, the travelled arc length (measured in
    the XY plane) after each of ``n_samples`` ticks of ``dt`` seconds is
    computed -- either at constant ``speed`` or from a per-sample
    ``speed_profile`` -- and mapped back onto the polyline by linear
    interpolation.

    Args:
        traj_pos: (T, 3) waypoint positions.
        idxs: grid indices of the waypoints, parallel to ``traj_pos``.
        speed: constant speed; ignored when ``speed_profile`` is given.
        dt: sampling interval in seconds.
        n_samples: number of continuous samples to produce.
        speed_profile: optional (n_samples,) per-sample speeds.

    Returns:
        ``(positions, pairs, alphas, directions)``: (n_samples, 3) float32
        positions; the (start, end) waypoint-index pair of the segment each
        sample lies on; within-segment interpolation weights in [0, 1); and
        the unit XY direction of that segment.

    Raises:
        ValueError: if ``speed_profile`` has a length other than ``n_samples``.
    """

    def _stationary(anchor_idx: int) -> Tuple[np.ndarray, List[Tuple[int, int]], np.ndarray, np.ndarray]:
        # Degenerate polyline (single waypoint or zero total length):
        # hold the first position with zero motion for every sample.
        p = np.repeat(traj_pos[:1], n_samples, axis=0)
        pairs = [(anchor_idx, anchor_idx)] * n_samples
        alphas = np.zeros(n_samples, np.float32)
        vdirs = np.zeros((n_samples, 2), np.float32)
        return p.astype(np.float32), pairs, alphas, vdirs

    traj_pos = np.asarray(traj_pos, float)
    if traj_pos.shape[0] < 2:
        return _stationary(idxs[0])
    segment_vectors = traj_pos[1:] - traj_pos[:-1]
    # Arc length is measured in XY only; z is linearly interpolated along.
    segment_lengths = np.linalg.norm(segment_vectors[:, :2], axis=1)
    cumulative = np.zeros(len(segment_lengths) + 1, float)
    cumulative[1:] = np.cumsum(segment_lengths)
    if cumulative[-1] <= 1e-12:
        return _stationary(idxs[0])
    if speed_profile is not None:
        speeds = np.asarray(speed_profile, float)
        if speeds.shape[0] != n_samples:
            raise ValueError("speed_profile must match number of samples")
        # Forward Euler distance integration: s[k] = sum_{j<k} v[j] * dt.
        ds = np.zeros(n_samples, float)
        ds[1:] = np.cumsum(speeds[:-1]) * dt
    else:
        ds = speed * dt * np.arange(n_samples, dtype=float)
    # Clamp just below the total length so searchsorted never overruns.
    ds = np.clip(ds, 0.0, max(cumulative[-1] - 1e-12, 0.0))
    pos_c = np.zeros((n_samples, 3), float)
    alphas = np.zeros(n_samples, float)
    vdirs = np.zeros((n_samples, 2), float)
    pairs: List[Tuple[int, int]] = []
    for k, distance in enumerate(ds):
        seg_idx = int(np.searchsorted(cumulative, distance, side="right") - 1)
        seg_idx = min(max(seg_idx, 0), len(segment_lengths) - 1)
        seg_length = segment_lengths[seg_idx]
        safe_len = seg_length if seg_length > 1e-12 else 1.0
        alpha = (distance - cumulative[seg_idx]) / safe_len
        alpha = min(max(alpha, 0.0), 1.0 - 1e-9)
        p0 = traj_pos[seg_idx]
        pos_c[k] = p0 + alpha * (traj_pos[seg_idx + 1] - p0)
        alphas[k] = alpha
        vdirs[k] = segment_vectors[seg_idx, :2] / safe_len
        pairs.append((idxs[seg_idx], idxs[seg_idx + 1]))
    return pos_c.astype(np.float32), pairs, alphas.astype(np.float32), vdirs.astype(np.float32)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
# ---------------------------------------------------------------------------
|
| 449 |
+
# Channel construction helpers
|
| 450 |
+
# ---------------------------------------------------------------------------
|
| 451 |
+
|
| 452 |
+
def array_response_phase(theta: np.ndarray, phi: np.ndarray, kd: float) -> np.ndarray:
    """Per-path phase-gradient vectors for a uniform array.

    For each direction ``(theta, phi)`` the gradient is
    ``j*kd*[sin(t)cos(p), sin(t)sin(p), cos(t)]``. Returns a
    (num_paths, 3) complex array.
    """
    sin_t = np.sin(theta)
    components = (
        sin_t * np.cos(phi),
        sin_t * np.sin(phi),
        np.cos(theta),
    )
    return (1j * kd) * np.vstack(components).T
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def array_response(ant_indices: np.ndarray, theta: np.ndarray, phi: np.ndarray, kd: float) -> np.ndarray:
    """Array response matrix of shape (num_elements, num_paths).

    ``ant_indices`` holds the integer (x, y, z) element positions; the
    per-path phase gradients come from :func:`array_response_phase` and the
    response is ``exp(P @ Gamma^T)``.
    """
    gamma = array_response_phase(theta, phi, kd)
    return np.exp(ant_indices @ gamma.T)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def ant_indices(panel_size: np.ndarray) -> np.ndarray:
    """Integer (x, y, z) grid indices for a planar antenna panel.

    The panel lies in the y-z plane: x is always 0, the y index cycles
    fastest over ``panel_size[0]`` columns, and the z index advances every
    ``panel_size[0]`` elements over ``panel_size[1]`` rows.

    Returns an (panel_size[0] * panel_size[1], 3) integer array.
    """
    n_y = int(panel_size[0])
    n_z = int(panel_size[1])
    # x is identically zero for a planar (y-z) panel; the original built this
    # with a no-op tile of arange(1) and a redundant repeat(..., 1).
    gx = np.zeros(n_y * n_z, dtype=int)
    gy = np.tile(np.arange(n_y), n_z)
    gz = np.repeat(np.arange(n_z), n_y)
    return np.vstack([gx, gy, gz]).T
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def apply_fov(fov: Sequence[float], theta: np.ndarray, phi: np.ndarray) -> np.ndarray:
    """Boolean mask of paths inside the antenna field of view.

    ``fov`` is (azimuth, elevation) aperture in degrees; angles are wrapped
    to [0, 2*pi). A path passes when its azimuth lies within +/- fov[0]/2 of
    boresight (phi = 0) and its elevation within +/- fov[1]/2 of broadside
    (theta = pi/2).
    """
    theta_wrapped = np.mod(theta, 2 * np.pi)
    phi_wrapped = np.mod(phi, 2 * np.pi)
    fov_rad = np.deg2rad(fov)
    half_az = fov_rad[0] / 2
    half_el = fov_rad[1] / 2
    in_azimuth = (phi_wrapped <= half_az) | (phi_wrapped >= 2 * np.pi - half_az)
    in_elevation = (np.pi / 2 - half_el <= theta_wrapped) & (theta_wrapped <= np.pi / 2 + half_el)
    return in_azimuth & in_elevation
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def rotate_angles(rotation: Optional[np.ndarray], theta: np.ndarray, phi: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Convert path angles from degrees to radians, optionally rotating them
    into the antenna's local frame.

    ``rotation`` is a three-component Euler rotation in degrees (applied
    about the x, y and z axes respectively); when ``None`` the angles are
    only converted to radians. Returns ``(theta, phi)`` in radians.
    """
    theta_rad = np.deg2rad(theta)
    phi_rad = np.deg2rad(phi)
    if rotation is None:
        return theta_rad, phi_rad
    rot = np.deg2rad(rotation)
    sin_az, cos_az = np.sin(phi_rad - rot[2]), np.cos(phi_rad - rot[2])
    sin_b, cos_b = np.sin(rot[1]), np.cos(rot[1])
    sin_g, cos_g = np.sin(rot[0]), np.cos(rot[0])
    sin_t, cos_t = np.sin(theta_rad), np.cos(theta_rad)
    theta_out = np.arccos(cos_b * cos_g * cos_t + sin_t * (sin_b * cos_g * cos_az - sin_g * sin_az))
    # Azimuth recovered via the complex argument to stay in (-pi, pi].
    phi_out = np.angle(
        cos_b * sin_t * cos_az
        - sin_b * cos_t
        + 1j * (cos_b * sin_g * cos_t + sin_t * (sin_b * sin_g * cos_az + cos_g * sin_az))
    )
    return theta_out, phi_out
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class OFDMPathGenerator:
    """Builds per-path OFDM coefficients from ray-tracing data.

    Two modes, selected by ``PARAMSET_OFDM_LPF``: ``no_lpf`` applies each
    path's delay as an exact phase ramp across the sampled subcarriers;
    ``with_lpf`` models each path as a band-limited (sinc) pulse over integer
    delay bins, to be projected onto subcarriers via ``delay_to_ofdm``.
    """

    def __init__(self, params: Dict[str, Any], subcarriers: np.ndarray) -> None:
        self.params = params
        self.ofdm = params[c.PARAMSET_OFDM]
        self.subcarriers = subcarriers  # sampled subcarrier indices
        self.total_subc = self.ofdm[c.PARAMSET_OFDM_SC_NUM]  # FFT size
        # NOTE(review): delay-bin count comes from ofdm["subcarriers"] while
        # the FFT size comes from PARAMSET_OFDM_SC_NUM -- presumably the same
        # value; confirm upstream.
        self.delay_bins = np.arange(self.ofdm["subcarriers"])
        # DFT-like matrix mapping delay-domain bins onto the sampled subcarriers.
        self.delay_to_ofdm = np.exp(-1j * 2 * np.pi / self.total_subc * np.outer(self.delay_bins, self.subcarriers))

    def _doppler_phase(self, raydata: Dict[str, Any]) -> Optional[np.ndarray]:
        """Per-path Doppler phase factor, or ``None`` when Doppler is disabled
        either globally or by the scenario parameters."""
        if not (
            self.params[c.PARAMSET_DOPPLER_EN]
            and self.params[c.PARAMSET_SCENARIO_PARAMS][c.PARAMSET_SCENARIO_PARAMS_DOPPLER_EN]
        ):
            return None
        fc = self.params[c.PARAMSET_SCENARIO_PARAMS][c.PARAMSET_SCENARIO_PARAMS_CF]
        velocities = np.asarray(raydata.get(c.OUT_PATH_DOP_VEL, 0.0)).reshape(-1, 1)
        elapsed = np.asarray(raydata.get("elapsed_time", 0.0)).reshape(-1, 1)
        # Phase advance exp(-j*2*pi*(fc/c)*v*t) from radial velocity over time.
        return np.exp(-1j * 2 * np.pi * (fc / c.LIGHTSPEED) * (velocities * elapsed))

    def generate(self, raydata: Dict[str, Any], ts: float) -> np.ndarray:
        """Dispatch to the LPF or no-LPF path model; ``ts`` is the sample period."""
        if self.ofdm[c.PARAMSET_OFDM_LPF] == 0:
            return self.no_lpf(raydata, ts)
        return self.with_lpf(raydata, ts)

    def no_lpf(self, raydata: Dict[str, Any], ts: float) -> np.ndarray:
        """Per-path subcarrier coefficients without low-pass filtering.

        Returns a (num_paths, num_sampled_subcarriers) complex array; paths
        whose delay exceeds the delay-bin span get zero power.
        NOTE(review): ``power[over] = 0`` writes through reshaped views, so
        the arrays inside ``raydata`` may be mutated in place -- confirm
        callers do not rely on the originals.
        """
        power = raydata[c.OUT_PATH_RX_POW].reshape(-1, 1)
        delay_n = (raydata[c.OUT_PATH_TOA] / ts).reshape(-1, 1)  # delay in samples
        phase = raydata[c.OUT_PATH_PHASE].reshape(-1, 1)
        over = delay_n >= self.ofdm["subcarriers"]
        power[over] = 0  # out-of-window paths contribute nothing
        delay_n[over] = self.ofdm["subcarriers"]
        path_const = np.sqrt(power / self.total_subc) * np.exp(
            1j * (np.deg2rad(phase) - (2 * np.pi / self.total_subc) * np.outer(delay_n, self.subcarriers))
        )
        doppler = self._doppler_phase(raydata)
        if doppler is not None:
            path_const *= doppler
        return path_const

    def with_lpf(self, raydata: Dict[str, Any], ts: float) -> np.ndarray:
        """Per-path delay-bin coefficients using a sinc (band-limited) pulse.

        Returns a (num_paths, num_delay_bins) complex array; multiply by
        ``delay_to_ofdm`` to obtain subcarrier coefficients.
        """
        power = raydata[c.OUT_PATH_RX_POW].reshape(-1, 1)
        delay_n = (raydata[c.OUT_PATH_TOA] / ts).reshape(-1, 1)
        phase = raydata[c.OUT_PATH_PHASE].reshape(-1, 1)
        over = delay_n >= self.ofdm["subcarriers"]
        power[over] = 0
        delay_n[over] = self.ofdm["subcarriers"]
        # Fractional delay spreads over neighbouring bins via the sinc kernel.
        pulse = np.sinc(self.delay_bins - delay_n) * np.sqrt(power / self.total_subc) * np.exp(1j * np.deg2rad(phase))
        doppler = self._doppler_phase(raydata)
        if doppler is not None:
            pulse *= doppler
        return pulse
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def generate_mimo_channel(raydata: Sequence[Dict[str, Any]], params: Dict[str, Any], tx_params: Dict[str, Any], rx_params: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
    """Assemble per-user MIMO OFDM channels from ray data.

    Args:
        raydata: one ray dict per user (angles in degrees, ToA, power, ...).
        params: DeepMIMO-style parameter set (OFDM, Doppler, scenario info).
        tx_params / rx_params: antenna dicts (shape, spacing, rotation, FoV).

    Returns:
        ``(channel, los)``: ``channel`` of shape
        (num_users, M_rx, M_tx, num_subcarriers) and ``los`` holding the
        per-user count of LoS paths, -1 for users left with no paths
        (-2 marks entries the loop never reached).

    Side effect: each ``raydata`` entry is filtered in place down to the
    paths that fall inside both fields of view.
    """
    bw = params[c.PARAMSET_OFDM][c.PARAMSET_OFDM_BW] * c.PARAMSET_OFDM_BW_MULT
    # kd = 2*pi*d/lambda with spacing expressed in wavelengths.
    kd_tx = 2 * np.pi * tx_params[c.PARAMSET_ANT_SPACING]
    kd_rx = 2 * np.pi * rx_params[c.PARAMSET_ANT_SPACING]
    ts = 1 / bw  # sample period
    subcarriers = params[c.PARAMSET_OFDM][c.PARAMSET_OFDM_SC_SAMP]
    generator = OFDMPathGenerator(params, subcarriers)
    m_tx = int(np.prod(tx_params[c.PARAMSET_ANT_SHAPE]))
    tx_indices = ant_indices(tx_params[c.PARAMSET_ANT_SHAPE])
    m_rx = int(np.prod(rx_params[c.PARAMSET_ANT_SHAPE]))
    rx_indices = ant_indices(rx_params[c.PARAMSET_ANT_SHAPE])
    channel = np.zeros((len(raydata), m_rx, m_tx, len(subcarriers)), dtype=np.csingle)
    los = np.zeros((len(raydata)), dtype=np.int8) - 2  # -2 = not processed
    for idx, item in enumerate(raydata):
        if item[c.OUT_PATH_NUM] == 0:
            los[idx] = -1  # user has no paths at all
            continue
        # Rotate path angles into each array's local frame, then mask out
        # paths outside either field of view.
        dod_theta, dod_phi = rotate_angles(tx_params[c.PARAMSET_ANT_ROTATION], item[c.OUT_PATH_DOD_THETA], item[c.OUT_PATH_DOD_PHI])
        doa_theta, doa_phi = rotate_angles(rx_params[c.PARAMSET_ANT_ROTATION], item[c.OUT_PATH_DOA_THETA], item[c.OUT_PATH_DOA_PHI])
        include_tx = apply_fov(tx_params[c.PARAMSET_ANT_FOV], dod_theta, dod_phi)
        include_rx = apply_fov(rx_params[c.PARAMSET_ANT_FOV], doa_theta, doa_phi)
        include = np.logical_and(include_tx, include_rx)
        dod_theta, dod_phi, doa_theta, doa_phi = dod_theta[include], dod_phi[include], doa_theta[include], doa_phi[include]
        # Filter every per-path array inside the ray dict (in place) so the
        # path count and arrays stay consistent for the generator below.
        for key in list(item.keys()):
            if key == c.OUT_PATH_NUM:
                item[key] = include.sum()
            elif isinstance(item[key], np.ndarray) and item[key].shape[0] == include.shape[0]:
                item[key] = item[key][include]
        if item[c.OUT_PATH_NUM] == 0:
            los[idx] = -1  # everything was filtered out by the FoV
            continue
        los[idx] = int(np.sum(item.get("LoS", np.zeros(int(item[c.OUT_PATH_NUM])))))
        tx_response = array_response(tx_indices, dod_theta, dod_phi, kd_tx)
        rx_response = array_response(rx_indices, doa_theta, doa_phi, kd_rx)
        factors = generator.generate(item, ts)
        if params[c.PARAMSET_OFDM][c.PARAMSET_OFDM_LPF] == 0:
            # Sum over paths of the rx-by-tx outer product, per subcarrier.
            channel[idx] = np.sum(
                rx_response[:, None, None, :] * tx_response[None, :, None, :] * factors.T[None, None, :, :],
                axis=3,
            )
        else:
            # Accumulate in the delay domain, then project onto subcarriers.
            channel[idx] = (
                np.sum(
                    rx_response[:, None, None, :] * tx_response[None, :, None, :] * factors.T[None, None, :, :],
                    axis=3,
                )
            ) @ generator.delay_to_ofdm
    return channel, los
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
def generate_channel_from_interpolated_ray(ray_interp: Dict[str, Any], antenna_cfg: AntennaArrayConfig, carrier_frequency_hz: float) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
    """Build a single-user MIMO channel from interpolated ray parameters.

    Wraps ``ray_interp`` into the dict layout expected by
    :func:`generate_mimo_channel`, fabricates a minimal parameter set
    (Doppler enabled, LPF disabled, all subcarriers sampled) and returns
    ``(channel, None, los)``; the middle element is a placeholder kept for
    interface compatibility.
    """
    raydata = [
        {
            c.OUT_PATH_NUM: int(ray_interp["num_paths"]),
            c.OUT_PATH_DOD_THETA: np.asarray(ray_interp["DoD_theta"]),
            c.OUT_PATH_DOD_PHI: np.asarray(ray_interp["DoD_phi"]),
            c.OUT_PATH_DOA_THETA: np.asarray(ray_interp["DoA_theta"]),
            c.OUT_PATH_DOA_PHI: np.asarray(ray_interp["DoA_phi"]),
            c.OUT_PATH_PHASE: np.asarray(ray_interp["phase"]),
            c.OUT_PATH_TOA: np.asarray(ray_interp["ToA"]),
            c.OUT_PATH_RX_POW: np.asarray(ray_interp["power"]),
            # Optional fields default to all-zero per-path arrays.
            "LoS": np.asarray(ray_interp.get("LoS", np.zeros(int(ray_interp["num_paths"])))),
            c.OUT_PATH_DOP_VEL: np.asarray(ray_interp.get("Doppler_vel", np.zeros(int(ray_interp["num_paths"])))),
            "elapsed_time": np.asarray(ray_interp.get("elapsed_time", np.zeros(int(ray_interp["num_paths"])))),
        }
    ]
    params = {
        c.PARAMSET_OFDM: {
            # NOTE(review): 1.92e-3 combined with PARAMSET_OFDM_BW_MULT in
            # generate_mimo_channel presumably yields the target bandwidth in
            # Hz -- confirm the units.
            c.PARAMSET_OFDM_BW: 1.92e-3,
            c.PARAMSET_OFDM_SC_SAMP: np.arange(antenna_cfg.subcarriers),
            c.PARAMSET_OFDM_SC_NUM: antenna_cfg.subcarriers,
            c.PARAMSET_OFDM_LPF: 0,
            "subcarriers": antenna_cfg.subcarriers,
        },
        c.PARAMSET_FDTD: True,
        c.PARAMSET_DOPPLER_EN: True,
        c.PARAMSET_SCENARIO_PARAMS: {
            c.PARAMSET_SCENARIO_PARAMS_DOPPLER_EN: True,
            c.PARAMSET_SCENARIO_PARAMS_CF: float(carrier_frequency_hz),
        },
    }
    tx = {
        c.PARAMSET_ANT_SHAPE: antenna_cfg.tx_shape(),
        c.PARAMSET_ANT_SPACING: antenna_cfg.spacing,
        c.PARAMSET_ANT_ROTATION: np.asarray(antenna_cfg.tx_rotation),
        c.PARAMSET_ANT_FOV: list(antenna_cfg.field_of_view),
        c.PARAMSET_ANT_RAD_PAT: "isotropic",
    }
    rx = {
        c.PARAMSET_ANT_SHAPE: antenna_cfg.rx_shape(),
        c.PARAMSET_ANT_SPACING: antenna_cfg.spacing,
        c.PARAMSET_ANT_ROTATION: np.asarray(antenna_cfg.rx_rotation),
        c.PARAMSET_ANT_FOV: list(antenna_cfg.field_of_view),
        c.PARAMSET_ANT_RAD_PAT: "isotropic",
    }
    channel, los = generate_mimo_channel(raydata, params, tx, rx)
    return channel, None, los
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def unwrap_angle_deg(a0: np.ndarray, a1: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return ``(a0, a1')`` where ``a1'`` equals ``a1`` modulo 360 degrees
    but lies within 180 degrees of ``a0`` (shortest-arc unwrapping)."""
    raw = a1 - a0
    shortest = (raw + 180.0) % 360.0 - 180.0
    return a0, a0 + shortest
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
def interpolate_ray_params(deepmimo_data: Dict[str, Any], idx0: int, idx1: int, alpha: float) -> Dict[str, Any]:
    """Linearly interpolate ray parameters between two grid users.

    The ``min(num_paths)`` strongest paths of each endpoint are paired by
    power rank; powers and ToAs are blended linearly, angles along the
    shortest arc in degrees, and phases along the shortest arc on the unit
    circle. ``LoS`` is the logical OR of the endpoint flags; a missing
    ``LoS`` entry now defaults to all-NLoS instead of raising ``KeyError``.
    If either endpoint has no paths, an empty ray dict is returned.

    Args:
        deepmimo_data: DeepMIMO dict exposing ``["user"]["paths"]``.
        idx0, idx1: endpoint user indices.
        alpha: interpolation weight in [0, 1] (0 -> idx0, 1 -> idx1).
    """
    p0 = deepmimo_data["user"]["paths"][idx0]
    p1 = deepmimo_data["user"]["paths"][idx1]
    l0, l1 = int(p0["num_paths"]), int(p1["num_paths"])
    if l0 == 0 or l1 == 0:
        # No common paths: report an empty ray set.
        return {
            "num_paths": 0,
            "DoD_theta": np.array([]),
            "DoD_phi": np.array([]),
            "DoA_theta": np.array([]),
            "DoA_phi": np.array([]),
            "phase": np.array([]),
            "ToA": np.array([]),
            "power": np.array([]),
            "LoS": np.array([], int),
        }

    def _select_paths(path: Dict[str, Any], count: int) -> np.ndarray:
        # Indices of the `count` strongest paths (descending power).
        scores = np.asarray(path["power"]).flatten()
        return np.argsort(-scores)[:count]

    count = min(l0, l1)
    idxs0 = _select_paths(p0, count)
    idxs1 = _select_paths(p1, count)

    def _g(source: Dict[str, Any], key: str, selection: np.ndarray) -> np.ndarray:
        return np.asarray(source[key]).flatten()[selection].astype(float)

    power = (1 - alpha) * _g(p0, "power", idxs0) + alpha * _g(p1, "power", idxs1)
    toa = (1 - alpha) * _g(p0, "ToA", idxs0) + alpha * _g(p1, "ToA", idxs1)

    def _ainterp(key: str) -> np.ndarray:
        # Shortest-arc interpolation in degrees, vectorised over paths.
        a0, a1 = _g(p0, key, idxs0), _g(p1, key, idxs1)
        delta = ((a1 - a0 + 180.0) % 360.0) - 180.0
        return a0 + alpha * delta

    dod_theta, dod_phi = _ainterp("DoD_theta"), _ainterp("DoD_phi")
    doa_theta, doa_phi = _ainterp("DoA_theta"), _ainterp("DoA_phi")
    # Phase interpolation on the unit circle avoids 360-degree jumps.
    phase0 = np.deg2rad(_g(p0, "phase", idxs0))
    phase1 = np.deg2rad(_g(p1, "phase", idxs1))
    dphi = np.angle(np.exp(1j * (phase1 - phase0)))
    phase = np.rad2deg(phase0 + alpha * dphi)

    def _los(path: Dict[str, Any], selection: np.ndarray, n_paths: int) -> np.ndarray:
        # Missing LoS information is treated as NLoS (robustness fix).
        flags = np.asarray(path.get("LoS", np.zeros(n_paths))).flatten()
        return flags[selection].astype(int)

    los = (_los(p0, idxs0, l0) + _los(p1, idxs1, l1) > 0).astype(int)

    return {
        "num_paths": count,
        "DoD_theta": dod_theta,
        "DoD_phi": dod_phi,
        "DoA_theta": doa_theta,
        "DoA_phi": doa_phi,
        "phase": phase,
        "ToA": toa,
        "power": power,
        "LoS": los,
    }
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
# ---------------------------------------------------------------------------
|
| 721 |
+
# Scenario generator
|
| 722 |
+
# ---------------------------------------------------------------------------
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
class DynamicScenarioGenerator:
|
| 726 |
+
"""High-level orchestrator for dynamic DeepMIMO scenario synthesis."""
|
| 727 |
+
|
| 728 |
+
    def __init__(self, config: ScenarioGenerationConfig, logger: Optional[logging.Logger] = None) -> None:
        """Store the configuration and optionally seed NumPy's global RNG.

        Note: ``np.random.seed`` is process-global, so generator instances
        created in the same process share one random stream.
        """
        self.config = config
        self.logger = logger or _LOGGER
        if config.rng_seed is not None:
            np.random.seed(config.rng_seed)
|
| 733 |
+
|
| 734 |
+
# ------------------------------------------------------------------
|
| 735 |
+
# public API
|
| 736 |
+
# ------------------------------------------------------------------
|
| 737 |
+
    def generate(self, overwrite: bool = False) -> ScenarioGenerationResult:
        """Generate the dynamic scenario payload, or load it from cache.

        Pipeline: build/load DeepMIMO user data, extract road positions and
        a road graph, sample vehicle and pedestrian trajectories, convert
        them to continuous tracks, synthesize continuous channels, and
        pickle both the channel-only file and the full payload.

        Args:
            overwrite: when ``True``, regenerate even if cached files exist.

        Returns:
            A :class:`ScenarioGenerationResult`; ``generated`` is ``False``
            when the cached payload was loaded instead.
        """
        cfg = self.config
        cfg.output_dir.mkdir(parents=True, exist_ok=True)
        cfg.full_output_dir.mkdir(parents=True, exist_ok=True)
        filename = self._build_filename()
        data_path = cfg.output_dir / filename
        full_path = cfg.full_output_dir / filename.replace(".p", "_full.p")

        # Cache hit: both files must exist and overwrite must be off.
        if data_path.exists() and full_path.exists() and not overwrite:
            payload = self._load_pickle(data_path)
            self.logger.info("Loaded cached scenario from %s", data_path)
            return ScenarioGenerationResult(payload=payload, output_path=data_path, full_output_path=full_path, generated=False)

        deepmimo_data = self._load_or_build_deepmimo()
        path_exist = np.asarray(deepmimo_data["user"]["LoS"])
        pos_total = np.asarray(deepmimo_data["user"]["location"])

        grid_step = self._infer_grid_step(pos_total)
        # Only users with a defined LoS state (0 = NLoS, 1 = LoS) are
        # candidates for road/walkable positions.
        road_positions, lane_info = filter_road_positions(
            pos_total[(path_exist == 0) | (path_exist == 1)], cfg.grid.road_width, cfg.grid.road_center_spacing
        )
        road_graph = create_grid_road_network(road_positions, lane_info, grid_step)
        self.logger.info("Road graph has %d nodes", len(road_graph.nodes))

        vehicle_trajs = generate_n_smooth_grid_trajectories(
            road_graph,
            road_positions,
            cfg.traffic.num_vehicles,
            sequence_length=cfg.sampling.time_steps,
            turn_probability=cfg.traffic.turn_probability,
            max_attempts=cfg.traffic.max_attempts,
            step_size=grid_step,
        )
        pedestrian_trajs = generate_n_pedestrian_trajectories(
            pos_total[(path_exist == 0) | (path_exist == 1)],
            cfg.traffic.num_pedestrians,
            sequence_length=cfg.sampling.time_steps,
            step_size=grid_step,
            angle_std=cfg.traffic.pedestrian_angle_std,
        )

        # Map trajectory waypoints back to global user-grid indices.
        veh_idx = get_trajectory_indices(vehicle_trajs, pos_total)
        ped_idx = get_trajectory_indices(pedestrian_trajs, pos_total)

        car_tracks = self._build_tracks(vehicle_trajs, veh_idx, cfg.traffic.vehicle_speed_range, cfg.sampling)
        ped_tracks = self._build_tracks(pedestrian_trajs, ped_idx, cfg.traffic.pedestrian_speed_range, cfg.sampling)

        if cfg.export_environment_plot and cfg.figures_dir is not None:
            self._save_environment_plot(
                cfg.figures_dir,
                pos_total,
                path_exist,
                road_positions=road_positions,
                vehicle_trajs=[track["pos"] for track in car_tracks],
                ped_trajs=[track["pos"] for track in ped_tracks],
            )

        channel_payload = self._channels_for_tracks(
            deepmimo_data,
            car_tracks + ped_tracks,
            cfg.antenna,
            cfg.carrier_frequency_hz,
            cfg.sampling,
        )

        # NOTE(review): allocated as zeros and never populated in this
        # method -- presumably a placeholder for discrete-time channels.
        discrete_channels = np.zeros(
            (
                cfg.traffic.num_vehicles + cfg.traffic.num_pedestrians,
                cfg.sampling.time_steps,
                cfg.antenna.total_tx_elements(),
                cfg.antenna.subcarriers,
            ),
            np.complex128,
        )

        payload: Dict[str, Any] = {
            "scenario": cfg.scenario,
            "index_discrete": veh_idx + ped_idx,
            "grid_step": grid_step,
            "sample_dt": cfg.sampling.sample_dt,
            "car_speed_range": cfg.traffic.vehicle_speed_range,
            "ped_speed_range": cfg.traffic.pedestrian_speed_range,
            "los": path_exist,
            "channel_cont": channel_payload["channel"],
            "pos_cont": channel_payload["pos"],
            "vel_cont": channel_payload["vel"],
            "acc_cont": channel_payload["acc"],
            "doppler_vel_cont": channel_payload["doppler"],
            "angle_cont": channel_payload["angle"],
            "delay_cont": channel_payload["delay"],
            "pos_step_xy": channel_payload["step_xy"],
            "channel_discrete": discrete_channels,
            "pos_discrete": vehicle_trajs + pedestrian_trajs,
        }

        # The unsuffixed keys alias whichever representation is "primary".
        if cfg.sampling.continuous_mode:
            payload.update(
                {
                    "channel": channel_payload["channel"],
                    "pos": channel_payload["pos"],
                    "vel": channel_payload["vel"],
                    "acc": channel_payload["acc"],
                    "doppler_vel": channel_payload["doppler"],
                    "angle": channel_payload["angle"],
                    "delay": channel_payload["delay"],
                }
            )
        else:
            payload["channel"] = discrete_channels

        # Channel-only pickle next to the full payload pickle.
        self._dump_pickle(data_path, payload["channel"])
        self._dump_pickle(full_path, payload)
        self.logger.info("Generated scenario saved to %s", data_path)
        return ScenarioGenerationResult(payload=payload, output_path=data_path, full_output_path=full_path, generated=True)
|
| 851 |
+
|
| 852 |
+
# ------------------------------------------------------------------
|
| 853 |
+
# internal helpers
|
| 854 |
+
# ------------------------------------------------------------------
|
| 855 |
+
    def _load_or_build_deepmimo(self) -> Dict[str, Any]:
        """Load the DeepMIMO user data for the configured scenario.

        Thin delegation to :func:`load_deepmimo_user_data` with all options
        taken from ``self.config``.
        """
        cfg = self.config
        return load_deepmimo_user_data(
            cfg.scenario,
            scenarios_dir=cfg.scenarios_dir,
            load_params=cfg.deepmimo_load_params,
            max_paths=cfg.deepmimo_max_paths,
            array_dtype=cfg.deepmimo_array_dtype,
            logger=self.logger,
        )
|
| 865 |
+
|
| 866 |
+
    def _save_environment_plot(
        self,
        figures_dir: Path,
        positions: np.ndarray,
        los: np.ndarray,
        road_positions: Optional[np.ndarray] = None,
        vehicle_trajs: Optional[Sequence[np.ndarray]] = None,
        ped_trajs: Optional[Sequence[np.ndarray]] = None,
    ) -> None:
        """Render a top-down map of users (colour-coded by LoS state), road
        nodes and sampled trajectories to ``figures_dir / "environment.png"``."""
        figures_dir.mkdir(parents=True, exist_ok=True)
        x, y = positions[:, 0], positions[:, 1]
        los_array = np.asarray(los)
        if los_array.ndim > 1:
            # Multi-dimensional LoS (presumably one entry per BS) is
            # collapsed to a single flag per user: LoS anywhere wins, then
            # NLoS, else "inactive".
            N = x.shape[0]
            shape = list(los_array.shape)
            if N in shape:
                axis_user = shape.index(N)
                los_matrix = np.moveaxis(los_array, axis_user, -1).reshape(-1, N)
                los_per_user = np.full(N, -1, np.int8)
                any_los = (los_matrix == 1).any(axis=0)
                los_per_user[any_los] = 1
                any_nlos = (los_matrix == 0).any(axis=0) & ~any_los
                los_per_user[any_nlos] = 0
            else:
                # User axis not identifiable: mark everyone inactive.
                los_per_user = np.full(x.shape[0], -1, np.int8)
        else:
            los_per_user = los_array.astype(np.int8, copy=False)

        colors = {
            1: "#2ca8c2",  # teal for LoS
            0: "#f39c12",  # amber for NLoS
            -1: "#5d6d7e",  # muted slate for inactive users
        }
        plt.figure(figsize=(8, 8), dpi=500)
        legend_handles = []
        legend_labels = []
        for value, color in colors.items():
            mask = los_per_user == value
            if np.any(mask):
                handle = plt.scatter(x[mask], y[mask], s=3, c=color, alpha=0.6)
                legend_handles.append(handle)
                legend_labels.append(f"LoS={value}")

        if road_positions is not None and len(road_positions):
            rp = np.asarray(road_positions)
            handle = plt.scatter(rp[:, 0], rp[:, 1], s=1, c="#a0a0a0", alpha=0.25)
            legend_handles.append(handle)
            legend_labels.append("Road nodes")

        if vehicle_trajs:
            cmap_vehicle = plt.cm.get_cmap("tab20", max(len(vehicle_trajs), 1))
            for idx, traj in enumerate(vehicle_trajs):
                arr = np.asarray(traj, dtype=float)
                arr = arr.reshape(arr.shape[0], -1)
                if arr.size == 0:
                    continue
                color = cmap_vehicle(idx % cmap_vehicle.N)
                plt.plot(arr[:, 0], arr[:, 1], color=color, alpha=0.85, linewidth=1.2)
                # Hollow marker at each trajectory start.
                plt.scatter(
                    arr[0:1, 0],
                    arr[0:1, 1],
                    s=20,
                    marker="o",
                    facecolor="white",
                    edgecolor=color,
                    linewidths=1.2,
                    zorder=3,
                )
            legend_handles.append(Line2D([0], [0], color=cmap_vehicle(0), linewidth=1.2))
            legend_labels.append("Vehicle trajectories")
            legend_handles.append(Line2D([0], [0], marker="o", color="white", markeredgecolor=cmap_vehicle(0), markerfacecolor="white", linestyle="None", markersize=5, markeredgewidth=1.2))
            legend_labels.append("Vehicle start")

        if ped_trajs:
            cmap_ped = plt.cm.get_cmap("Set2", max(len(ped_trajs), 1))
            for idx, traj in enumerate(ped_trajs):
                arr = np.asarray(traj, dtype=float)
                arr = arr.reshape(arr.shape[0], -1)
                if arr.size == 0:
                    continue
                color = cmap_ped(idx % cmap_ped.N)
                # Dashed lines distinguish pedestrians from vehicles.
                plt.plot(arr[:, 0], arr[:, 1], color=color, alpha=0.85, linewidth=1.2, linestyle="--")
                plt.scatter(
                    arr[0:1, 0],
                    arr[0:1, 1],
                    s=20,
                    marker="o",
                    facecolor="white",
                    edgecolor=color,
                    linewidths=1.2,
                    zorder=3,
                )
            legend_handles.append(Line2D([0], [0], color=cmap_ped(0), linewidth=1.2, linestyle="--"))
            legend_labels.append("Pedestrian trajectories")
            legend_handles.append(Line2D([0], [0], marker="o", color="white", markeredgecolor=cmap_ped(0), markerfacecolor="white", linestyle="None", markersize=5, markeredgewidth=1.2))
            legend_labels.append("Pedestrian start")

        if legend_handles:
            plt.legend(legend_handles, legend_labels, loc="best", fontsize=8)
        # plt.grid(True)
        plt.gca().set_aspect("equal", adjustable="box")
        plt.savefig(figures_dir / "environment.png", bbox_inches="tight", pad_inches=0.1)
        plt.close()
|
| 969 |
+
|
| 970 |
+
def _infer_grid_step(self, positions: np.ndarray) -> float:
|
| 971 |
+
cfg = self.config
|
| 972 |
+
if cfg.grid.auto_step_size:
|
| 973 |
+
inferred = infer_grid_step(positions)
|
| 974 |
+
self.logger.info("Inferred road step size %.6f m", inferred)
|
| 975 |
+
return inferred
|
| 976 |
+
return float(cfg.grid.step_size)
|
| 977 |
+
|
| 978 |
+
def _build_tracks(
    self,
    trajectories: Sequence[np.ndarray],
    indices: Sequence[Sequence[int]],
    speed_range: Tuple[float, float],
    sampling_cfg: ScenarioSamplingConfig,
) -> List[Dict[str, Any]]:
    """Turn discrete waypoint trajectories into continuously sampled tracks.

    For each trajectory a base speed is drawn uniformly from *speed_range*,
    a per-sample speed profile is built (zero for stationary users, and
    tapered to zero near the last sample at which the user still moved),
    and the trajectory is resampled continuously along its polyline.

    Parameters
    ----------
    trajectories:
        Per-user waypoint polylines.
    indices:
        Per-user waypoint-index sequences, one entry per time step; a
        repeated index means the user did not advance during that step.
    speed_range:
        ``(min, max)`` bounds for the uniform base-speed draw.
    sampling_cfg:
        Sampling settings; ``continuous_length`` (when set) or
        ``time_steps`` defines the horizon.

    Returns
    -------
    list of dict
        One record per track with keys ``speed``, ``speed_profile``,
        ``pos``, ``pairs``, ``alpha`` and ``vdir`` (the latter four come
        from ``sample_continuous_along_polyline``).
    """
    tracks: List[Dict[str, Any]] = []
    # Horizon: prefer the continuous-sampling length when configured.
    horizon = sampling_cfg.continuous_length or sampling_cfg.time_steps
    for traj, idxs in zip(trajectories, indices):
        # Per-track base speed; uses the global NumPy RNG.
        speed = float(np.random.uniform(*speed_range))
        idxs_list = list(idxs)
        speed_profile = np.full(horizon, speed, float)
        # Index of the last step at which the waypoint actually changed;
        # -1 means the user never moved.
        last_change = -1
        for k in range(len(idxs_list) - 1):
            if idxs_list[k] != idxs_list[k + 1]:
                last_change = k
        if last_change == -1:
            # Fully stationary track: zero speed everywhere.
            speed_profile[:] = 0.0
        else:
            last_move_sample = last_change + 1
            if last_move_sample < horizon - 1:
                # The user stops before the end of the horizon: ramp the
                # speed down to zero over at most 5 samples ending at the
                # last movement sample, so the deceleration is smooth.
                ramp = min(5, last_move_sample + 1)
                start_idx = max(0, last_move_sample - ramp + 1)
                speed_profile[:start_idx] = speed
                for k in range(start_idx, last_move_sample + 1):
                    # Linear taper: factor goes from ~1 down toward 1/ramp.
                    factor = float(last_move_sample - k + 1) / float(ramp)
                    speed_profile[k] = speed * factor
                # After the last movement the user stands still.
                speed_profile[last_move_sample + 1 :] = 0.0
        pos_c, pairs, alpha, vdir = sample_continuous_along_polyline(
            traj,
            idxs,
            speed,
            sampling_cfg.sample_dt,
            horizon,
            speed_profile=speed_profile,
        )
        tracks.append(
            {
                "speed": speed,
                "speed_profile": speed_profile,
                "pos": pos_c,
                "pairs": pairs,
                "alpha": alpha,
                "vdir": vdir,
            }
        )
    return tracks
|
| 1026 |
+
|
| 1027 |
+
def _channels_for_tracks(
    self,
    deepmimo_data: Dict[str, Any],
    tracks: Sequence[Dict[str, Any]],
    antenna_cfg: AntennaArrayConfig,
    carrier_frequency_hz: float,
    sampling_cfg: ScenarioSamplingConfig,
) -> Dict[str, Any]:
    """Generate channel sequences and kinematic/ray metadata for each track.

    For every time sample of every track the ray parameters are
    interpolated between the two enclosing DeepMIMO grid users, a
    per-path Doppler velocity is computed from the instantaneous speed
    projected onto each arrival direction, and a channel matrix of shape
    (tx_elements, subcarriers) is synthesized.

    Returns
    -------
    dict
        ``channel`` (tracks, horizon, tx_elements, subcarriers complex),
        ``pos``/``vel``/``acc``/``step_xy`` (float32 arrays stacked over
        tracks), and ragged per-path lists ``doppler``, ``angle``,
        ``delay`` (one list of arrays per track).
    """
    channels: List[np.ndarray] = []
    positions: List[np.ndarray] = []
    velocities: List[np.ndarray] = []
    accelerations: List[np.ndarray] = []
    dopplers: List[List[np.ndarray]] = []
    angles: List[List[np.ndarray]] = []
    delays: List[List[np.ndarray]] = []
    steps_xy: List[np.ndarray] = []
    for track in tracks:
        base_speed = track["speed"]
        # May be absent on legacy tracks, hence .get().
        speed_profile = track.get("speed_profile")
        pos = track["pos"]
        pairs = track["pairs"]
        alpha = track["alpha"]
        vdir = track["vdir"]
        horizon = len(alpha)
        channel_sequence: List[np.ndarray] = []
        doppler_list: List[np.ndarray] = []
        angle_list: List[np.ndarray] = []
        delay_list: List[np.ndarray] = []
        # Planar (x, y) displacement between consecutive samples.
        step_xy = np.zeros(horizon, float)
        step_xy[1:] = np.linalg.norm(pos[1:, :2] - pos[:-1, :2], axis=1)
        # Instantaneous speed from finite differences of the positions.
        vel_series = np.zeros(horizon, float)
        if horizon > 1:
            vel_series[1:] = step_xy[1:] / sampling_cfg.sample_dt
        if vel_series[0] == 0.0:
            # No displacement info at t=0: fall back to the planned speed.
            if speed_profile is not None:
                vel_series[0] = float(speed_profile[0])
            else:
                vel_series[0] = base_speed
        # Acceleration from finite differences of the speed series; the
        # first sample copies the second (no backward difference exists).
        acc_series = np.zeros_like(vel_series)
        if horizon > 1:
            acc_series[1:] = (vel_series[1:] - vel_series[:-1]) / sampling_cfg.sample_dt
        acc_series[0] = acc_series[1] if horizon > 1 else 0.0
        for idx, (pair, a, vd) in enumerate(zip(pairs, alpha, vdir)):
            i0, i1 = pair
            # Interpolate rays between grid users i0 and i1 at fraction a.
            ray_interp = interpolate_ray_params(deepmimo_data, i0, i1, float(a))
            if ray_interp["num_paths"] == 0:
                # Blockage / no valid paths: emit an all-zero channel and
                # empty per-path metadata for this sample.
                channel_sequence.append(np.zeros((antenna_cfg.total_tx_elements(), antenna_cfg.subcarriers), np.complex128))
                doppler_list.append(np.zeros((0,), np.float32))
                angle_list.append(np.zeros((0,), np.float32))
                delay_list.append(np.zeros((0,), np.float32))
                continue
            # Unit vectors of the arrival azimuths in the x-y plane.
            doa_phi = np.deg2rad(ray_interp["DoA_phi"])
            aoa_unit = np.stack([np.cos(doa_phi), np.sin(doa_phi)], axis=1)
            speed_inst = vel_series[idx] if idx < vel_series.shape[0] else base_speed
            # Radial velocity per path: motion projected onto each DoA
            # (negated so approach along the arrival direction is positive
            # closing speed — NOTE(review): sign convention, confirm
            # against generate_channel_from_interpolated_ray).
            v_proj = -np.sum(aoa_unit * vd[None, :], axis=1) * speed_inst
            ray_interp["Doppler_vel"] = v_proj.astype(np.float32)
            # Same elapsed time for every path of this sample.
            ray_interp["elapsed_time"] = np.ones_like(ray_interp["power"], dtype=np.float32) * (idx * sampling_cfg.sample_dt)
            pred, _, _ = generate_channel_from_interpolated_ray(ray_interp, antenna_cfg, carrier_frequency_hz)
            # pred[0] presumably has a leading singleton (user) axis — squeeze it.
            channel_sequence.append(np.asarray(pred[0]).squeeze(0))
            doppler_list.append(v_proj.astype(np.float32))
            angle_list.append(ray_interp["DoA_phi"].astype(np.float32))
            delay_list.append(ray_interp["ToA"].astype(np.float32))
        channels.append(np.stack(channel_sequence, axis=0))
        positions.append(pos.astype(np.float32))
        velocities.append(vel_series.astype(np.float32))
        accelerations.append(acc_series.astype(np.float32))
        dopplers.append(doppler_list)
        angles.append(angle_list)
        delays.append(delay_list)
        steps_xy.append(step_xy.astype(np.float32))
    return {
        "channel": np.stack(channels, axis=0),
        "pos": np.stack(positions, axis=0),
        "vel": np.stack(velocities, axis=0),
        "acc": np.stack(accelerations, axis=0),
        "doppler": dopplers,
        "angle": angles,
        "delay": delays,
        "step_xy": np.stack(steps_xy, axis=0),
    }
|
| 1107 |
+
|
| 1108 |
+
def _build_filename(self) -> str:
|
| 1109 |
+
cfg = self.config
|
| 1110 |
+
return f"{cfg.scenario}_{cfg.sampling.time_steps}_{cfg.antenna.total_tx_elements()}_{cfg.antenna.subcarriers}.p"
|
| 1111 |
+
|
| 1112 |
+
@staticmethod
|
| 1113 |
+
def _load_pickle(path: Path) -> Any:
|
| 1114 |
+
with path.open("rb") as handle:
|
| 1115 |
+
import pickle
|
| 1116 |
+
|
| 1117 |
+
return pickle.load(handle)
|
| 1118 |
+
|
| 1119 |
+
@staticmethod
|
| 1120 |
+
def _dump_pickle(path: Path, obj: Any) -> None:
|
| 1121 |
+
with path.open("wb") as handle:
|
| 1122 |
+
import pickle
|
| 1123 |
+
|
| 1124 |
+
pickle.dump(obj, handle)
|
| 1125 |
+
|
| 1126 |
+
|
| 1127 |
+
# Backwards compatibility helper -------------------------------------------------
|
| 1128 |
+
|
| 1129 |
+
def generate_dynamic_scenario_dataset(
    config: ScenarioGenerationConfig,
    overwrite: bool = False,
    logger: Optional[logging.Logger] = None,
) -> ScenarioGenerationResult:
    """Run one dynamic-scenario generation pass for *config*.

    Thin functional wrapper: builds a ``DynamicScenarioGenerator`` and
    delegates to its ``generate`` method.

    Parameters
    ----------
    config:
        Full scenario-generation configuration.
    overwrite:
        When True, regenerate even if cached outputs exist.
    logger:
        Optional logger forwarded to the generator.
    """
    return DynamicScenarioGenerator(config, logger=logger).generate(overwrite=overwrite)
|
| 1136 |
+
|
| 1137 |
+
|
| 1138 |
+
def generate_dynamic_scenario(
    config: ScenarioGenerationConfig,
    overwrite: bool = False,
    logger: Optional[logging.Logger] = None,
) -> ScenarioGenerationResult:
    """Deprecated alias for :func:`generate_dynamic_scenario_dataset`.

    Emits a ``DeprecationWarning`` pointing at the replacement, then
    forwards all arguments unchanged.
    """
    warnings.warn(
        "`generate_dynamic_scenario` is deprecated. Use `generate_dynamic_scenario_dataset` instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return generate_dynamic_scenario_dataset(config, overwrite=overwrite, logger=logger)
|
| 1149 |
+
|
| 1150 |
+
|
| 1151 |
+
# Public API of the scenario-generation module (mirrored by the package
# __init__ re-exports).
__all__ = [
    "AntennaArrayConfig",
    "TrafficConfig",
    "GridConfig",
    "ScenarioSamplingConfig",
    "ScenarioGenerationConfig",
    "ScenarioGenerationResult",
    "DynamicScenarioGenerator",
    "generate_dynamic_scenario_dataset",
    "generate_dynamic_scenario",
]
|
docs/dynamic_scenario_pipeline.md
CHANGED
|
@@ -214,14 +214,14 @@ from LWMTemporal.data.scenario_generation import (
|
|
| 214 |
config = ScenarioGenerationConfig(
|
| 215 |
scenario="city_0_newyork_3p5",
|
| 216 |
antenna=AntennaArrayConfig(tx_horizontal=32, tx_vertical=1, subcarriers=32),
|
| 217 |
-
sampling=ScenarioSamplingConfig(time_steps=
|
| 218 |
-
traffic=TrafficConfig(num_vehicles=120, num_pedestrians=20, turn_probability=0.
|
| 219 |
-
grid=GridConfig(road_width=
|
| 220 |
output_dir=Path("examples/data"),
|
| 221 |
full_output_dir=Path("examples/full_data"),
|
| 222 |
figures_dir=Path("examples/figs/newyork"),
|
| 223 |
scenarios_dir=Path("deepmimo_scenarios"),
|
| 224 |
-
deepmimo_max_paths=
|
| 225 |
)
|
| 226 |
|
| 227 |
generator = DynamicScenarioGenerator(config)
|
|
|
|
| 214 |
config = ScenarioGenerationConfig(
|
| 215 |
scenario="city_0_newyork_3p5",
|
| 216 |
antenna=AntennaArrayConfig(tx_horizontal=32, tx_vertical=1, subcarriers=32),
|
| 217 |
+
sampling=ScenarioSamplingConfig(time_steps=20, sample_dt=0.1),
|
| 218 |
+
traffic=TrafficConfig(num_vehicles=120, num_pedestrians=20, turn_probability=0.1),
|
| 219 |
+
grid=GridConfig(road_width=2.0, road_center_spacing=8.0),
|
| 220 |
output_dir=Path("examples/data"),
|
| 221 |
full_output_dir=Path("examples/full_data"),
|
| 222 |
figures_dir=Path("examples/figs/newyork"),
|
| 223 |
scenarios_dir=Path("deepmimo_scenarios"),
|
| 224 |
+
deepmimo_max_paths=25,
|
| 225 |
)
|
| 226 |
|
| 227 |
generator = DynamicScenarioGenerator(config)
|