# Source: wi-lab Hugging Face Space, file app.py
# Commit message: "Update app.py"
# Commit: f3c3f04 (verified)
from __future__ import annotations
import math
import pickle
import shutil
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import imageio
import numpy as np
import plotly.graph_objects as go # type: ignore
import deepmimo
import subprocess
import sys
import os
# Upgrade Gradio and its client at import time, before `gradio` is imported
# below. pip's stdout/stderr are discarded; check_call still raises
# CalledProcessError if the install exits with a non-zero status.
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "--upgrade", "--quiet", "--no-cache-dir",
    "gradio>=6.0", "gradio_client>=2.0"
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
# Request that Python warnings be ignored.
# NOTE(review): PYTHONWARNINGS is normally read at interpreter startup, so
# setting it here presumably only affects child Python processes — confirm
# this is the intended effect.
os.environ["PYTHONWARNINGS"] = "ignore"
import gradio as gr # type: ignore
import torch # type: ignore
from torch.utils.data import DataLoader # type: ignore
# Install DeepMIMO if not already installed
# try:
# import deepmimo
# except ImportError:
# print("Installing DeepMIMO from GitHub...")
# subprocess.check_call([
# sys.executable, "-m", "pip", "install",
# "git+https://github.com/DeepMIMO/DeepMIMO.git"
# ])
# import deepmimo
# import sys, subprocess, sysconfig, pathlib
# from pathlib import Path
# def ensure_deepmimo():
# try:
# import deepmimo # type: ignore # already installed and importable
# return
# except ModuleNotFoundError:
# print("Installing DeepMIMO from GitHub...")
# subprocess.check_call([
# sys.executable,
# "-m",
# "pip",
# "install",
# "git+https://github.com/DeepMIMO/DeepMIMO.git"
# ])
# # ---- Patch DeepMIMO's dataset.py to avoid the union-typing bug ----
# site_pkgs = (
# Path(sys.executable).parent.parent
# / "lib"
# / f"python{sys.version_info.major}.{sys.version_info.minor}"
# / "site-packages"
# )
# dataset_py = site_pkgs / "deepmimo" / "generator" / "dataset.py"
# if dataset_py.exists():
# txt = dataset_py.read_text()
# # handle both variants just in case
# before_1 = "def append(self, dataset: Dataset | 'MacroDataset'):"
# before_2 = "def append(self, dataset: Dataset | MacroDataset):"
# if before_1 in txt or before_2 in txt:
# print("Patching DeepMIMO MacroDataset.append type annotation...")
# txt = txt.replace(before_1, "def append(self, dataset):")
# txt = txt.replace(before_2, "def append(self, dataset):")
# dataset_py.write_text(txt)
# import deepmimo # type: ignore
# print("DeepMIMO imported successfully.")
# # Call this once near the top of app.py
# ensure_deepmimo()
import builtins
# Force DeepMIMO to never ask questions: replace the process-wide `input`
# builtin so that any interactive prompt is answered without blocking.
def no_input(*args, **kwargs):
    """Stand-in for ``input`` that ignores its arguments and answers "y"."""
    return "y"  # auto-answer "yes" to everything
# NOTE: this affects every caller of input() in the process, not only DeepMIMO.
builtins.input = no_input
# NOTE(review): presumably read by DeepMIMO to disable prompts — confirm the
# library actually honours this variable.
os.environ["DEEPMIMO_NO_INTERACTIVE"] = "1"
# ---------------------------------------------------------------------------
# Import LWMTemporal either from an installed package or from a private
# Hugging Face repo (useful when this Space is separate from the main codebase)
# ---------------------------------------------------------------------------
# Location and identity of the LWM-Temporal snapshot.
#
# These names are bound *before* the import attempt because later module-level
# code reads REMOTE_DIR (directory listing, checkpoint path) no matter which
# import path succeeded. Binding them only inside the fallback branch would
# raise NameError whenever the installed package imports cleanly.
#
# Configure via environment:
#   LWM_TEMPORAL_REPO_ID   (e.g. "your-username/lwm-temporal")
#   LWM_TEMPORAL_REVISION  (optional branch/tag/commit)
#   HF_TOKEN               must have access to the private repo.
BASE_DIR = Path(__file__).resolve().parents[1]
REMOTE_DIR = BASE_DIR / "_remote_lwm_temporal"
REPO_ID = os.getenv("LWM_TEMPORAL_REPO_ID", "wi-lab/lwm-temporal")
REVISION = os.getenv("LWM_TEMPORAL_REVISION") or None

try:
    from LWMTemporal.data.angle_delay import AngleDelayConfig, AngleDelayProcessor
    from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
    from LWMTemporal.data.scenario_generation import (
        AntennaArrayConfig,
        DynamicScenarioGenerator,
        GridConfig,
        ScenarioGenerationConfig,
        ScenarioSamplingConfig,
        TrafficConfig,
    )
    from LWMTemporal.models.lwm import LWMBackbone, LWMConfig, NeighborIndexer, masked_nmse_loss
    from LWMTemporal.tasks.channel_prediction import (
        ChannelPredictionArgs,
        ChannelPredictionTrainer,
        DatasetArgs,
        ModelArgs,
        PredictionArgs,
        TrainingArgs,
        compute_nmse,
    )
    from LWMTemporal.tasks.pretraining import MaskArgs, MaskGenerator
    from LWMTemporal.utils.logging import setup_logging
except ImportError:
    # Fallback: fetch the repository snapshot from the Hugging Face Hub (once)
    # and import the package from the downloaded directory.
    from huggingface_hub import snapshot_download

    if not REMOTE_DIR.exists():
        snapshot_download(
            repo_id=REPO_ID,
            revision=REVISION,
            repo_type="model",
            local_dir=REMOTE_DIR,
            token=os.getenv("HF_TOKEN"),
        )
    # Put the snapshot first on sys.path so its LWMTemporal package is found.
    sys.path.insert(0, str(REMOTE_DIR))
    from LWMTemporal.data.angle_delay import AngleDelayConfig, AngleDelayProcessor
    from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
    from LWMTemporal.data.scenario_generation import (
        AntennaArrayConfig,
        DynamicScenarioGenerator,
        GridConfig,
        ScenarioGenerationConfig,
        ScenarioSamplingConfig,
        TrafficConfig,
    )
    from LWMTemporal.models.lwm import LWMBackbone, LWMConfig, NeighborIndexer, masked_nmse_loss
    from LWMTemporal.tasks.channel_prediction import (
        ChannelPredictionArgs,
        ChannelPredictionTrainer,
        DatasetArgs,
        ModelArgs,
        PredictionArgs,
        TrainingArgs,
        compute_nmse,
    )
    from LWMTemporal.tasks.pretraining import MaskArgs, MaskGenerator
    from LWMTemporal.utils.logging import setup_logging
# ---------------------------------------------------------------------------
# DeepMIMO download helper (optional)
# ---------------------------------------------------------------------------
# Bind `deepmimo_download` to DeepMIMO's download function when it can be
# imported, trying two module paths in turn; leave it as None otherwise so
# callers can detect that the helper is unavailable.
try:  # first try the installed Python package name
    from deepmimo import download as deepmimo_download  # type: ignore
except Exception:  # pragma: no cover
    try:  # fallback to the vendored package path in this repo
        from DeepMIMO.deepmimo import download as deepmimo_download  # type: ignore
    except Exception:  # pragma: no cover
        # Neither import worked: mark the helper as missing instead of failing.
        deepmimo_download = None  # type: ignore
# Startup diagnostic: print REMOTE_DIR and then every directory and file that
# os.walk finds beneath it (nothing is listed if the directory does not exist).
print("REMOTE_DIR =", REMOTE_DIR)
for root, dirs, files in os.walk(REMOTE_DIR):
    print(root)
    for f in files:
        print(" -", f)
# Working directories, all anchored at the parent of the directory that
# contains this file.
ROOT = Path(__file__).resolve().parents[1]
DATA_DIR = ROOT / "examples" / "data" / "lab"
FULL_DATA_DIR = ROOT / "examples" / "full_data" / "lab"
FIG_DIR = ROOT / "figs" / "lab"
SCENARIO_DIR = ROOT / "deepmimo_scenarios"
# The checkpoint is read from the downloaded snapshot rather than from ROOT:
# CHECKPOINT_PATH = ROOT / "checkpoints" / "pytorch_model.bin"
CHECKPOINT_PATH = Path(REMOTE_DIR / "checkpoints" / "pytorch_model.bin")
# Fail fast at import time if the pretrained weights are not on disk.
if not CHECKPOINT_PATH.exists():
    raise FileNotFoundError(
        f"Pretrained model not found at {CHECKPOINT_PATH}\n"
        "Make sure your private repo 'wi-lab/lwm-temporal' contains checkpoints/pytorch_model.bin"
    )
# Per-feature figure output directories.
CHANNEL_PRED_VIZ = FIG_DIR / "channel_prediction"
ANGLE_DELAY_VIZ = FIG_DIR / "angle_delay"
ENV_FIG_DIR = FIG_DIR / "environment"
MASK_VIZ_DIR = FIG_DIR / "masking"
ATTN_VIZ_DIR = FIG_DIR / "attention"
# Create every data/figure directory up front (SCENARIO_DIR is not created here).
for directory in [
    DATA_DIR,
    FULL_DATA_DIR,
    FIG_DIR,
    CHANNEL_PRED_VIZ,
    ANGLE_DELAY_VIZ,
    ENV_FIG_DIR,
    MASK_VIZ_DIR,
    ATTN_VIZ_DIR,
]:
    directory.mkdir(parents=True, exist_ok=True)
# Static list of scenario names offered by the app. The derived sub-family
# lists defined after it rely on the name patterns used here: a trailing "_s",
# the "city"/"asu"/"boston" prefixes, and the "3p5"/"28" suffixes.
DEFAULT_SCENARIOS: List[str] = [
    "asu_campus_3p5",
    "city_0_newyork_3p5",
    "city_1_losangeles_3p5",
    "city_2_chicago_3p5",
    "city_3_houston_3p5",
    "city_4_phoenix_3p5",
    "city_5_philadelphia_3p5",
    "city_6_miami_3p5",
    "city_7_sandiego_3p5",
    "city_8_dallas_3p5",
    "city_9_sanfrancisco_3p5",
    "city_10_austin_3p5",
    "city_11_santaclara_3p5",
    "city_12_fortworth_3p5",
    "city_13_columbus_3p5",
    "city_14_charlotte_3p5",
    "city_15_indianapolis_3p5",
    "city_16_sanfrancisco_3p5",
    "city_17_seattle_3p5",
    "city_18_denver_3p5",
    "city_19_oklahoma_3p5",
    "city_21_taito_city_3p5",
    "city_22_rome_3p5",
    "city_23_beijing_3p5",
    "city_24_cairo_3p5",
    "city_25_instanbul_3p5",
    "city_26_dubai_3p5",
    "city_27_rio_de_janeiro_3p5",
    "city_28_mumbai_3p5",
    "city_29_centro_3p5",
    "city_30_singapore_3p5",
    "city_31_barcelona_3p5",
    "city_32_sydney_3p5",
    "city_33_hong_kong_3p5",
    "city_34_amsterdam_3p5",
    "city_35_san_francisco_3p5",
    "city_36_bangkok_3p5",
    "city_37_seoul_3p5",
    "city_38_toronto_3p5",
    "city_39_jerusalem_3p5",
    "city_40_prague_3p5",
    "city_41_kyoto_3p5",
    "city_42_san_nicolas_3p5",
    "city_43_cape_town_3p5",
    "city_44_lisboa_3p5",
    "city_45_stockholm_3p5",
    "city_46_la_habana_3p5",
    "city_47_chicago_3p5",
    "city_48_gurbchen_3p5",
    "city_49_new_delhi_3p5",
    "city_50_edinburgh_3p5",
    "city_51_firenze_3p5",
    "city_52_marrakesh_3p5",
    "city_53_sankt-peterburg_3p5",
    "city_54_north_jakarta_3p5",
    "city_55_madrid_3p5",
    "city_56_montreal_3p5",
    "city_57_taipei_3p5",
    "city_58_sumida_city_3p5",
    "city_59_roma_3p5",
    "city_60_toronto_3p5",
    "city_61_hatsukaichi_3p5",
    "city_62_santiago_3p5",
    "city_63_athens_3p5",
    "city_64_new_york_3p5",
    "city_65_granada_3p5",
    "city_66_bruxelles_3p5",
    "city_67_fujiyoshida_3p5",
    "city_68_reykjavik_3p5",
    "city_69_warszawa_3p5",
    "city_70_stockholm_3p5",
    "city_71_helsinki_3p5",
    "city_72_capetown_3p5",
    "city_73_casablanca_3p5",
    "city_74_chiyoda_3p5",
    "city_75_dongcheng_3p5",
    "city_76_houston_3p5",
    "city_77_melbourne_3p5",
    "city_78_newtaipei_3p5",
    "city_79_parow_3p5",
    "city_80_philadelphia_3p5",
    "city_81_seoul_3p5",
    "city_82_tempe_3p5",
    "city_0_newyork_28",
    "city_1_losangeles_28",
    "city_2_chicago_28",
    "city_3_houston_28",
    "city_4_phoenix_28",
    "city_5_philadelphia_28",
    "city_6_miami_28",
    "city_7_sandiego_28",
    "city_8_dallas_28",
    "city_9_sanfrancisco_28",
    "city_10_austin_28",
    "city_11_santaclara_28",
    "city_12_fortworth_28",
    "city_13_columbus_28",
    "city_14_charlotte_28",
    "city_15_indianapolis_28",
    "city_16_sanfrancisco_28",
    "city_17_seattle_28",
    "city_18_denver_28",
    "city_19_oklahoma_28",
    "boston5g_3p5",
    "boston5g_28",
    "city_0_newyork_3p5_s",
    "city_1_losangeles_3p5_s",
    "city_2_chicago_3p5_s",
    "city_3_houston_3p5_s",
    "city_4_phoenix_3p5_s",
    "city_5_philadelphia_3p5_s",
    "city_6_miami_3p5_s",
    "city_7_sandiego_3p5_s",
    "city_8_dallas_3p5_s",
    "city_9_sanfrancisco_3p5_s",
    "city_10_austin_3p5_s",
    "city_11_santaclara_3p5_s",
    "city_12_fortworth_3p5_s",
    "city_13_columbus_3p5_s",
    "city_14_charlotte_3p5_s",
    "city_15_indianapolis_3p5_s",
    "city_16_sanfrancisco_3p5_s",
    "city_17_seattle_3p5_s",
    "city_18_denver_3p5_s",
    "city_19_oklahoma_3p5_s",
    "i1_2p5",
    "i2_28b",
    "o1_3p5",
    "o1_3p4",
    "o1b_3p5",
    "o1b_28",
    "o1_140",
    "i1_2p4",
    "o1_28",
    "o1_60",
    "o1_drone_200",
]
# LWM-Temporal curated scenario subsets
LWM_TEMPORAL_TRAIN_SCENARIOS: List[str] = [
    "city_72_capetown_3p5",
    "city_73_casablanca_3p5",
    "city_74_chiyoda_3p5",
    "city_75_dongcheng_3p5",
    "city_76_houston_3p5",
    "city_77_melbourne_3p5",
    "city_78_newtaipei_3p5",
    "city_79_parow_3p5",
    "city_80_philadelphia_3p5",
    "city_81_seoul_3p5",
]
LWM_TEMPORAL_TEST_SCENARIOS: List[str] = [
    "city_82_tempe_3p5",
]
# Alias of DEFAULT_SCENARIOS (the same list object, not a copy).
DEEPMIMO_V4_SCENARIOS: List[str] = DEFAULT_SCENARIOS
# General DeepMIMO v4 sub-families, derived purely from name patterns.
# Names ending in "_s":
DEEPMIMO_SIONNA_SCENARIOS: List[str] = [
    name for name in DEEPMIMO_V4_SCENARIOS if name.endswith("_s")
]
_WI_CITY_PREFIXES = ("city", "asu", "boston")
# Names without the "_s" suffix that start with one of the prefixes above:
DEEPMIMO_WI_CITY_SCENARIOS: List[str] = [
    name
    for name in DEEPMIMO_V4_SCENARIOS
    if not name.endswith("_s") and name.startswith(_WI_CITY_PREFIXES)
]
# Split of the list above by name suffix ("3p5" / "28"):
DEEPMIMO_WI_CITY_3P5_SCENARIOS: List[str] = [
    name for name in DEEPMIMO_WI_CITY_SCENARIOS if name.endswith("3p5")
]
DEEPMIMO_WI_CITY_28_SCENARIOS: List[str] = [
    name for name in DEEPMIMO_WI_CITY_SCENARIOS if name.endswith("28")
]
# Every remaining name without the "_s" suffix:
DEEPMIMO_WI_OTHER_SCENARIOS: List[str] = [
    name
    for name in DEEPMIMO_V4_SCENARIOS
    if not name.endswith("_s") and name not in DEEPMIMO_WI_CITY_SCENARIOS
]
# Display labels for the sub-families.
# NOTE(review): presumably shown in a UI selector and mapped onto the lists
# above elsewhere in the file — confirm against the caller.
DEEPMIMO_GENERAL_SUBFAMILIES: List[str] = [
    "All DeepMIMO v4 scenarios",
    "Sionna: City scenarios (3.5 GHz)",
    "Wireless InSite: City scenarios (3.5 GHz)",
    "Wireless InSite: City scenarios (28 GHz)",
    "Wireless InSite: Other scenarios",
]
# Shared matplotlib colormap handle ("magma").
MOCK_CMAP = plt.get_cmap("magma")
# ---------------------------------------------------------------------------
# General helpers
# ---------------------------------------------------------------------------
def _resolve_device(choice: str) -> str:
    """Translate a device choice into a concrete device string.

    ``"auto"`` resolves to ``"cuda"`` when CUDA is available and to ``"cpu"``
    otherwise; every other value is returned unchanged.
    """
    if choice != "auto":
        return choice
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def _ensure_state(state: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return ``state`` unchanged when it is truthy.

    Raises:
        gr.Error: if ``state`` is ``None`` or empty, i.e. no scenario has been
            generated or loaded yet.
    """
    if state:
        return state
    raise gr.Error("Generate or load a scenario first.")
def _load_payload(full_path: Path) -> Dict[str, Any]:
    """Unpickle the file at ``full_path`` and return it if it is a dict.

    Raises:
        gr.Error: if the unpickled object is not a dictionary.
    """
    # NOTE: unpickling can execute arbitrary code; only use this on files
    # written by this application.
    with full_path.open("rb") as stream:
        data = pickle.load(stream)
    if isinstance(data, dict):
        return data
    raise gr.Error("Expected a dictionary payload. Regenerate the scenario.")
def _summarize_payload(payload: Dict[str, Any]) -> str:
    """Build a short text summary of a payload.

    For a dict, the channel is taken from ``"channel_discrete"`` and, failing
    that, ``"channel"``; any other object is treated as the channel itself.
    The summary lists the channel shape, the sorted payload keys (``n/a`` for
    non-dicts) and, when ``"los"`` is present and non-empty, the count of each
    distinct LoS value.
    """
    is_mapping = isinstance(payload, dict)
    if is_mapping:
        tensor = payload.get("channel_discrete")
        if tensor is None:
            tensor = payload.get("channel")
    else:
        tensor = payload
    if tensor is None:
        return "Payload missing channel tensor."

    shape_line = f"Channel tensor shape: {tuple(tensor.shape)}\n"

    # Count how often each LoS value occurs (dict payloads only).
    los_line = ""
    los_values = payload.get("los") if is_mapping else None
    if los_values is not None:
        values, counts = torch.unique(torch.as_tensor(los_values), return_counts=True)
        distribution = ", ".join(
            f"{int(value.item())}: {int(count.item())}" for value, count in zip(values, counts)
        )
        if distribution:
            los_line = f"LoS distribution: {distribution}\n"

    key_text = ", ".join(sorted(payload.keys())) if is_mapping else "n/a"
    keys_line = f"Payload keys: {key_text}\n"
    return shape_line + keys_line + los_line
def _angle_delay_paths(run_tag: str) -> Tuple[Path, Path, Path]:
    """Create ``ANGLE_DELAY_VIZ / run_tag`` and return its three output paths.

    Returns:
        ``(angle_delay.gif, channel.gif, bins.png)`` inside the run directory.
    """
    target_dir = ANGLE_DELAY_VIZ / run_tag
    target_dir.mkdir(parents=True, exist_ok=True)
    angle_gif = target_dir / "angle_delay.gif"
    channel_gif = target_dir / "channel.gif"
    bins_png = target_dir / "bins.png"
    return angle_gif, channel_gif, bins_png
def _save_angle_delay_gif_lab(
    tensor: torch.Tensor,
    output_path: Path,
    fps: int = 6,
    pause_seconds: float = 0.7,
) -> None:
    """Render the magnitude of ``tensor`` frame by frame into a looping GIF.

    One dark-themed heat map is drawn per index along the first dimension; all
    frames share a single color scale (global min/max of the magnitude). The
    last frame is repeated so the animation appears to pause before looping.

    Args:
        tensor: Tensor whose first dimension indexes frames; it is cast to
            ``complex64`` before taking the absolute value.
        output_path: Destination file; parent directories are created.
        fps: Frame rate passed to the GIF writer.
        pause_seconds: Approximate hold time on the final frame (at least one
            extra frame is always appended).
    """
    background = "#0b0e11"
    label_color = "#cbd5f5"

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    magnitude = tensor.to(torch.complex64).abs().cpu()
    low = float(magnitude.min())
    high = float(magnitude.max())

    def _render(frame_idx: int) -> np.ndarray:
        # Draw one frame off-screen and return its RGBA pixel buffer.
        fig, ax = plt.subplots(figsize=(8, 6))
        fig.patch.set_facecolor(background)
        ax.set_facecolor(background)
        ax.tick_params(colors=label_color)
        for spine in ax.spines.values():
            spine.set_color("#374151")
        image = ax.imshow(
            magnitude[frame_idx].numpy(),
            cmap="magma",
            origin="lower",
            aspect="auto",
            vmin=low,
            vmax=high,
        )
        ax.set_xlabel("Delay bins", color=label_color)
        ax.set_ylabel("Angle bins", color=label_color)
        colorbar = fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
        colorbar.ax.yaxis.set_tick_params(color=label_color)
        plt.setp(colorbar.ax.get_yticklabels(), color=label_color)
        # NOTE(review): the label says dB, but the plotted values are the raw
        # absolute value computed above — confirm which one is intended.
        colorbar.set_label("|H| (dB)", color=label_color)
        ax.set_title(
            f"Angle-Delay Intensity — Frame {frame_idx}",
            color="#f8fafc",
            fontsize=12,
            fontweight="semibold",
        )
        fig.canvas.draw()
        pixels = np.asarray(fig.canvas.buffer_rgba())
        plt.close(fig)
        return pixels

    frames: List[np.ndarray] = [_render(idx) for idx in range(magnitude.size(0))]

    # Repeat the final frame instead of relying on per-frame durations, which
    # some GIF viewers ignore.
    hold_seconds = max(pause_seconds, 0.0)
    hold_count = max(1, int(round(hold_seconds * max(fps, 1))))
    if frames and hold_count > 0:
        frames.extend([frames[-1]] * hold_count)

    # loop=0 means "loop forever" in GIF metadata.
    imageio.mimsave(output_path, frames, fps=fps, loop=0)
def _add_pause_to_gif(
    gif_path: Path,
    fps: int = 5,
    pause_seconds: float = 0.7,
) -> None:
    """Rewrite an existing GIF with its last frame repeated a few times.

    The repeated frames give the appearance of a pause at the end of each
    loop, even in viewers that ignore per-frame durations. This is
    best-effort: a missing file, an unreadable file or an empty frame list
    leaves the file untouched and returns silently.
    """
    gif_path = Path(gif_path)
    if not gif_path.exists():
        return

    try:
        frames = imageio.mimread(gif_path)
    except Exception:
        # Unreadable GIF: keep the original file as it is.
        return
    if not frames:
        return

    hold_seconds = max(pause_seconds, 0.0)
    hold_count = max(1, int(round(hold_seconds * max(fps, 1))))
    if hold_count > 0:
        frames.extend([frames[-1]] * hold_count)
    imageio.mimsave(gif_path, frames, fps=fps, loop=0)
def _build_dataset(raw_path: Path, keep_percentage: float) -> AngleDelaySequenceDataset:
    """Construct an ``AngleDelaySequenceDataset`` with the lab's fixed settings.

    Only ``raw_path`` and ``keep_percentage`` vary; normalization, cache
    directory (``ROOT / "cache"``), patch size and phase mode are fixed here.
    """
    settings = dict(
        raw_path=raw_path,
        keep_percentage=keep_percentage,
        normalize="per_sample_rms",
        cache_dir=ROOT / "cache",
        patch_size=(1, 1),
        phase_mode="real_imag",
    )
    return AngleDelaySequenceDataset(AngleDelayDatasetConfig(**settings))
def _refresh_scenarios() -> Tuple[List[str], str]:
    """Return the static scenario list together with a status message.

    Falls back to a single default name if ``DEFAULT_SCENARIOS`` is empty.
    """
    scenario_names = DEFAULT_SCENARIOS or ["asu_campus_3p5"]
    status = "Loaded static scenario list."
    return scenario_names, status
def _ensure_scenario_available(scenario: str, auto_download: bool) -> None:
    """Make sure ``SCENARIO_DIR / scenario`` exists, downloading it if allowed.

    Raises:
        gr.Error: if the scenario is missing and ``auto_download`` is off, if
            the DeepMIMO download helper could not be imported, or if the
            download itself fails.
    """
    target = SCENARIO_DIR / scenario
    if target.exists():
        return

    if auto_download and deepmimo_download is not None:
        try:
            deepmimo_download(scenario, output_dir=str(SCENARIO_DIR))
        except Exception as exc:
            raise gr.Error(f"Failed to download scenario '{scenario}': {exc}") from exc
        return

    if not auto_download:
        raise gr.Error(
            f"Scenario '{scenario}' not found under '{SCENARIO_DIR}'. Enable the download toggle to fetch it automatically."
        )
    # Download was requested, but the helper is unavailable.
    raise gr.Error(
        "DeepMIMO download helper is unavailable. Install DeepMIMO dependencies or place the scenario manually."
    )
# ---------------------------------------------------------------------------
# Angle-delay helpers (ported from examples/ad_temporal_evolutiton.py)
# ---------------------------------------------------------------------------
# def pick_bins(sequence: torch.Tensor, k: int) -> List[Tuple[int, int]]:
# if sequence.ndim != 3:
# raise ValueError(f"Expected angle-delay tensor (T, H, W); got {tuple(sequence.shape)}")
# _, H, W = sequence.shape
# mag = sequence.abs().mean(dim=0)
# topk = torch.topk(mag.flatten(), k=min(k, H * W))
# picks: List[Tuple[int, int]] = []
# for idx in topk.indices.tolist():
# n = idx // W
# m = idx % W
# picks.append((n, m))
# return picks
# def plot_curves(sequence: torch.Tensor, picks: List[Tuple[int, int]], out_path: Path, title: str) -> None:
# if not picks:
# raise ValueError("No bins selected for plotting.")
# T = sequence.size(0)
# fig, axes = plt.subplots(len(picks), 2, figsize=(12, 3 * len(picks)), dpi=150)
# axes = np.atleast_2d(axes)
# times = np.arange(T)
# for row, (n, m) in enumerate(picks):
# series = sequence[:, n, m]
# mag = series.abs().cpu().numpy()
# phase = np.unwrap(torch.angle(series).cpu().numpy())
# ax_mag, ax_phase = axes[row]
# ax_mag.imshow(mag[None, :], aspect="auto", cmap="magma")
# ax_mag.set_title(f"|H| for bin (n={n}, m={m})")
# ax_mag.set_yticks([])
# ax_mag.set_xlabel("time index")
# ax_phase.plot(times, phase, color="#f87171")
# ax_phase.set_title("Phase (rad)")
# ax_phase.set_xlabel("time index")
# ax_phase.grid(True, alpha=0.3)
# fig.suptitle(title)
# out_path.parent.mkdir(parents=True, exist_ok=True)
# fig.tight_layout()
# fig.savefig(out_path)
# plt.close(fig)
# ---------------------------------------------------------------------------
# Scenario generation + visualization
# ---------------------------------------------------------------------------
def _collect_track_xy(payload: Dict[str, Any]) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Extract flattened x/y coordinates from a payload, or ``(None, None)``.

    ``payload["pos_cont"]`` is preferred; ``payload["pos_discrete"]`` (an
    iterable of trajectories) is the fallback. In both cases only arrays with
    at least two dimensions and a last axis of size >= 2 are used.
    """
    # Continuous positions
    pos_cont = payload.get("pos_cont")
    if pos_cont is not None:
        arr = np.asarray(pos_cont)
        if arr.ndim >= 2 and arr.shape[-1] >= 2:
            return arr[..., 0].reshape(-1), arr[..., 1].reshape(-1)

    # Fallback to discrete trajectories
    pos_disc = payload.get("pos_discrete")
    if pos_disc is None:
        return None, None
    all_pts: List[np.ndarray] = []
    for traj in pos_disc:
        t_arr = np.asarray(traj)
        if t_arr.ndim >= 2 and t_arr.shape[-1] >= 2:
            all_pts.append(t_arr[..., :2].reshape(-1, 2))
    if not all_pts:
        return None, None
    stacked = np.concatenate(all_pts, axis=0)
    return stacked[:, 0], stacked[:, 1]


def _build_environment_figure(x_vals: np.ndarray, y_vals: np.ndarray) -> go.Figure:
    """Build the dark-themed Plotly scatter of user positions (zoom/pan/select)."""
    env_plot = go.Figure()
    env_plot.add_trace(
        go.Scattergl(
            x=x_vals,
            y=y_vals,
            mode="markers",
            marker=dict(
                size=4,
                opacity=0.7,
                color="#38bdf8",  # cyan-ish blue
                line=dict(width=0),
            ),
            name="Users / tracks",
            hovertemplate="x: %{x:.2f} m<br>y: %{y:.2f} m<extra></extra>",
            # Make box/lasso selection visually obvious:
            selected=dict(
                marker=dict(
                    size=6,
                    opacity=1.0,
                    color="#f97316",  # orange for selected
                )
            ),
            unselected=dict(
                marker=dict(
                    opacity=0.15,
                )
            ),
        )
    )
    env_plot.update_layout(
        template="plotly_dark",
        font=dict(
            family="Inter, system-ui, -apple-system, BlinkMacSystemFont, sans-serif",
            size=14,
            color="#e5e7eb",
        ),
        paper_bgcolor="#020617",
        plot_bgcolor="#020617",
        margin=dict(l=40, r=20, t=50, b=40),
        height=550,
        hovermode="closest",
        # Use "zoom" (default) so zoom/box-select/lasso all work via toolbar
        dragmode="zoom",
        showlegend=False,
    )
    env_plot.update_xaxes(
        title="x (m)",
        showgrid=True,
        gridcolor="rgba(148, 163, 184, 0.25)",
        zeroline=False,
        showline=True,
        linecolor="rgba(148, 163, 184, 0.6)",
        mirror=True,
        # Lock the aspect ratio so distances look the same on both axes.
        scaleanchor="y",
        scaleratio=1,
    )
    env_plot.update_yaxes(
        title="y (m)",
        showgrid=True,
        gridcolor="rgba(148, 163, 184, 0.25)",
        zeroline=False,
        showline=True,
        linecolor="rgba(148, 163, 184, 0.6)",
        mirror=True,
    )
    return env_plot


def generate_scenario_action(
    prev_state: Optional[Dict[str, Any]],
    scenario: str,
    time_steps: int,
    sample_dt: float,
    vehicles: int,
    pedestrians: int,
    turn_prob: float,
    vehicle_speed_max: float,
    road_width: float,
    road_spacing: float,
    keep_percentage: float,
    overwrite: bool,
    auto_download: bool,
) -> Tuple[str, str, Optional[str], Optional[go.Figure], Dict[str, Any]]:
    """Generate a dynamic scenario and prepare its summary, figures and state.

    Args:
        prev_state: Previous session state. Currently not consulted, because
            the scenario is always regenerated.
        scenario: Scenario name; must be non-empty.
        time_steps, sample_dt: Sampling configuration.
        vehicles, pedestrians, turn_prob: Traffic configuration.
        vehicle_speed_max: Upper vehicle speed; divided by 3.6 before being
            passed on (i.e. treated as km/h and converted to m/s).
        road_width, road_spacing: Grid configuration.
        keep_percentage: Stored in the returned state for later steps.
        overwrite: Accepted for interface compatibility but ignored; the
            generator is always run with ``overwrite=True``.
        auto_download: Whether a missing scenario may be downloaded.

    Returns:
        ``(summary_markdown, stats_text, env_png_path_or_None,
        plotly_figure_or_None, state_dict)``.

    Raises:
        gr.Error: if ``scenario`` is empty or the scenario cannot be obtained.
    """
    if not scenario:
        raise gr.Error("Select or enter a scenario name, then try again. Click 'Refresh scenario list' if needed.")
    progress = gr.Progress(track_tqdm=False)
    progress(0.05, desc="Configuring scenario")
    _ensure_scenario_available(scenario, auto_download)

    # Fingerprint of the generation parameters, stored in the returned state.
    signature = "|".join(
        map(
            str,
            [
                scenario,
                time_steps,
                sample_dt,
                vehicles,
                pedestrians,
                turn_prob,
                vehicle_speed_max,
                road_width,
                road_spacing,
                keep_percentage,
            ],
        )
    )
    # Always regenerate the scenario to avoid reusing cached payloads.
    should_overwrite = True

    env_dir = ENV_FIG_DIR / scenario
    config = ScenarioGenerationConfig(
        scenario=scenario,
        antenna=AntennaArrayConfig(tx_horizontal=32, tx_vertical=1, subcarriers=32),
        sampling=ScenarioSamplingConfig(time_steps=time_steps, sample_dt=sample_dt),
        traffic=TrafficConfig(
            num_vehicles=vehicles,
            num_pedestrians=pedestrians,
            turn_probability=turn_prob,
            vehicle_speed_range=(0 / 3.6, vehicle_speed_max / 3.6),
        ),
        grid=GridConfig(
            road_width=float(road_width),
            road_center_spacing=float(road_spacing),
        ),
        output_dir=DATA_DIR,
        full_output_dir=FULL_DATA_DIR,
        figures_dir=env_dir,
        scenarios_dir=SCENARIO_DIR,
        deepmimo_max_paths=25,
    )
    generator = DynamicScenarioGenerator(config)
    progress(0.2, desc="Running generator")
    result = generator.generate(overwrite=should_overwrite)
    payload = result.payload

    summary = (
        f"{'Generated' if result.generated else 'Loaded cached'} scenario **{scenario}**\n"
        f"- time steps: {time_steps}\n"
        f"- vehicles/pedestrians: {vehicles}/{pedestrians}\n"
        f"- output: `{result.output_path.name}`\n"
    )
    stats = _summarize_payload(payload)

    # Static PNG environment figure (from the library helper). It is copied to
    # a timestamped file name; if the copy fails, the original path is used.
    env_path = env_dir / "environment.png"
    env_display_path: Optional[str] = None
    if env_path.exists():
        stamped = ENV_FIG_DIR / f"{scenario}_env_{int(time.time())}.png"
        try:
            shutil.copy(env_path, stamped)
            env_display_path = str(stamped)
        except Exception:
            env_display_path = str(env_path)

    # Interactive "map-like" environment figure using Plotly (zoom & pan +
    # selection). Deliberately best-effort: any failure simply omits the plot.
    env_plot: Optional[go.Figure] = None
    try:
        x_vals, y_vals = _collect_track_xy(payload)
        if x_vals is not None and y_vals is not None and x_vals.size:
            env_plot = _build_environment_figure(x_vals, y_vals)
    except Exception:
        env_plot = None

    state = {
        "scenario": scenario,
        "slim_path": str(result.output_path),
        "full_path": str(result.full_output_path),
        "keep_percentage": keep_percentage,
        "time_steps": time_steps,
        "signature": signature,
    }
    return summary, stats, env_display_path, env_plot, state
def visualize_angle_delay_action(
    state: Optional[Dict[str, Any]],
    ue_index: int,
    keep_percentage: float,
) -> Tuple[Optional[str], Optional[str], Optional[str], str]:
    """Render the angle-delay GIF, channel GIF and bin-evolution plot for one UE.

    ``ue_index`` is clamped to the valid range of the first dimension of
    ``payload["channel_discrete"]``. Outputs go to a fresh, timestamped run
    directory.

    Returns:
        ``(angle_delay_gif, channel_gif, bins_png, summary)`` with the three
        paths as strings.

    Raises:
        gr.Error: if no scenario state is available or the payload is invalid.
    """
    info = _ensure_state(state)
    payload = _load_payload(Path(info["full_path"]))

    num_users = torch.as_tensor(payload["channel_discrete"]).size(0)
    # Clamp the requested index into [0, num_users - 1].
    ue = max(0, min(num_users - 1, int(ue_index)))

    processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=keep_percentage))
    run_tag = f"{info['scenario']}_ue{ue}_{int(time.time())}"
    gif_path, channel_gif_path, bins_path = _angle_delay_paths(run_tag)

    ue_discrete = torch.as_tensor(payload["channel_discrete"][ue])
    ue_continuous = torch.as_tensor(payload["channel_cont"][ue])

    angle_delay = processor.forward(ue_discrete)
    trimmed, _ = processor.truncate_delay_bins(angle_delay)

    # Angle-delay GIF written with the local helper so Gradio gets a true
    # multi-frame GIF.
    _save_angle_delay_gif_lab(trimmed, gif_path, fps=6, pause_seconds=0.7)
    # Channel animation from the package helper, then a short pause between loops.
    processor.save_channel_animation(ue_continuous, channel_gif_path, show=False)
    _add_pause_to_gif(channel_gif_path, fps=6, pause_seconds=0.7)
    # Dominant-bin evolution plot from the package helper.
    processor.save_bin_evolution_plot(ue_discrete, bins_path, show=False)

    summary = (
        f"Rendered UE {ue} / {num_users}.\n"
        f"Angle-delay GIF: `{gif_path.name}` | Channel GIF: `{channel_gif_path.name}`"
    )
    return str(gif_path), str(channel_gif_path), str(bins_path), summary
# ---------------------------------------------------------------------------
# Masked modeling sanity check
# ---------------------------------------------------------------------------
def masked_modeling_action(
    state: Optional[Dict[str, Any]],
    mask_ratio: float,
    batch_size: int,
    device_choice: str,
    max_batches: int,
) -> str:
    """Run masked reconstruction with the pretrained backbone and report NMSE.

    Tokens are masked at random (combined with each sample's ``base_mask``),
    zeroed, reconstructed by the backbone, and scored with
    ``masked_nmse_loss``. At least one batch is processed; iteration stops
    once ``max_batches`` batches have been seen.

    Returns:
        A one-line report with the batch count and the batch-size-weighted
        mean NMSE (linear and dB).

    Raises:
        gr.Error: if the state, dataset file or checkpoint is missing, or the
            dataset is empty.
    """
    info = _ensure_state(state)
    raw_path = Path(info["slim_path"])
    if not raw_path.exists():
        raise gr.Error(f"Dataset not found at {raw_path}")
    if not CHECKPOINT_PATH.exists():
        raise gr.Error(f"Checkpoint {CHECKPOINT_PATH} missing. Train or copy a backbone first.")

    dataset = _build_dataset(raw_path, info["keep_percentage"])
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    if not len(dataset):
        raise gr.Error("Dataset is empty. Regenerate the scenario.")

    # The first sample's shape fixes the sequence length the model must support.
    example_shape = dataset[0]["shape"]
    T, H, W = (int(example_shape[dim].item()) for dim in range(3))
    tokens_per_sample = T * H * W

    device = torch.device(_resolve_device(device_choice))
    cfg = LWMConfig(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        same_frame_window=2,
        temporal_offsets=(-4, -3, -2, -1, 1, 2, 3),
        temporal_spatial_window=2,
        temporal_drift_h=1,
        temporal_drift_w=1,
        routing_topk_enable=True,
        topk_per_head=True,
        max_seq_len=tokens_per_sample,
    )
    backbone = LWMBackbone.from_pretrained(CHECKPOINT_PATH, config=cfg).to(device)
    backbone.eval()
    masker = MaskGenerator(MaskArgs(mask_ratio=mask_ratio, mask_mode="random"))

    weighted_nmse = 0.0
    sample_count = 0
    processed = 0
    with torch.no_grad():
        for batch_number, batch in enumerate(loader, start=1):
            tokens = batch["tokens"].to(device)
            base_mask = batch["base_mask"].to(device)
            seq = batch["sequence"].to(device)
            T, H, W = seq[0].shape
            B, _, _ = tokens.shape

            # One independent random mask per sample, merged with the base mask.
            random_masks = [masker(T, H, W, device).view(-1) for _ in range(B)]
            mask = torch.logical_or(torch.stack(random_masks), base_mask)

            corrupted = tokens.masked_fill(mask.unsqueeze(-1), 0.0)
            outputs = backbone.forward_tokens(corrupted, mask, T, H, W, return_cls=False)
            nmse = masked_nmse_loss(outputs["reconstruction"], tokens, mask)

            weighted_nmse += nmse.item() * B
            sample_count += B
            processed = batch_number
            if processed >= max_batches:
                break

    avg_nmse = weighted_nmse / max(1, sample_count)
    # Floor the value before the log to avoid log10(0).
    avg_nmse_db = 10.0 * math.log10(max(avg_nmse, 1e-12))
    return f"Processed {processed} batch(es). NMSE={avg_nmse:.4e} ({avg_nmse_db:.2f} dB)"
# ---------------------------------------------------------------------------
# Channel prediction inference
# ---------------------------------------------------------------------------
def _evaluate_channel_prediction(trainer: ChannelPredictionTrainer, loader: DataLoader) -> Tuple[float, float]:
    """Evaluate autoregressive channel prediction over every batch in ``loader``.

    Returns:
        ``(mean_loss, mean_nmse)`` averaged per batch; both are ``0.0`` when
        the loader yields no batches.
    """

    def _mean(values: List[float]) -> float:
        # Guard the denominator so an empty list yields 0.0.
        return float(sum(values) / max(1, len(values)))

    trainer.model.eval()
    loss_history: List[float] = []
    nmse_history: List[float] = []
    with torch.no_grad():
        for batch in loader:
            tokens, _, H, W = trainer._prepare_batch(batch)
            preds, target, mask = trainer.engine.autoregressive_rollout(
                trainer.model,
                tokens,
                trainer.args.prediction.Tpast,
                trainer.args.prediction.horizon,
                H,
                W,
            )
            batch_loss = trainer._compute_loss(preds, target, mask)
            loss_history.append(batch_loss.item())
            nmse_history.append(compute_nmse(preds, target, mask))
    return _mean(loss_history), _mean(nmse_history)
def channel_prediction_action(
    state: Optional[Dict[str, Any]],
    tpast: int,
    horizon: int,
    device_choice: str,
    train_limit: int,
    val_limit: int,
) -> Tuple[str, Optional[str]]:
    """Run inference-only channel prediction and save one visualization.

    Builds a ``ChannelPredictionTrainer`` around the pretrained checkpoint,
    evaluates it on the validation loader (or the training loader if the
    validation loader is falsy), and renders a figure for the first batch.

    Returns:
        ``(summary, viz_path)`` where ``viz_path`` is the first
        ``sample_*.png`` in the run's visualization directory, or ``None``.

    Raises:
        gr.Error: if the state or checkpoint is missing, or no loader is
            available.
    """
    info = _ensure_state(state)
    if not CHECKPOINT_PATH.exists():
        raise gr.Error(f"Checkpoint {CHECKPOINT_PATH} missing. Train or copy a backbone first.")

    data_cfg = DatasetArgs(
        data_path=Path(info["slim_path"]),
        keep_percentage=info["keep_percentage"],
        normalize="per_sample_rms",
        seed=0,
        train_limit=train_limit,
        val_limit=val_limit,
    )
    # Timestamped directory so each run's figures are kept apart.
    pred_viz_dir = CHANNEL_PRED_VIZ / f"{info['scenario']}_{int(time.time())}"
    model_cfg = ModelArgs(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        same_frame_window=2,
        # Offsets -tpast .. -1 (past frames only).
        temporal_offsets=tuple(range(-tpast, 0)),
        temporal_spatial_window=2,
        temporal_drift_h=1,
        temporal_drift_w=1,
        routing_topk_enable=True,
        routing_topk_fraction=0.2,
        routing_topk_max=32,
        pretrained=CHECKPOINT_PATH,
    )
    train_cfg = TrainingArgs(
        device=_resolve_device(device_choice),
        epochs=1,
        batch_size=1,
        lr=1e-4,
        weight_decay=1e-4,
        inference_only=True,
        inference_split="val",
        use_dataparallel=False,
        verbose_inference=False,
        save_dir=ROOT / "models",
    )
    pred_cfg = PredictionArgs(
        Tpast=tpast,
        horizon=horizon,
        viz_dir=pred_viz_dir,
        num_visual_samples=1,
    )
    logger = setup_logging("lab_channel_prediction", log_dir=ROOT / "logs" / "lab")
    run_args = ChannelPredictionArgs(
        dataset=data_cfg,
        model=model_cfg,
        training=train_cfg,
        prediction=pred_cfg,
    )
    trainer = ChannelPredictionTrainer(run_args, logger=logger)

    # Use the validation loader when it is truthy, else the training loader.
    val_loader = trainer.data.val_loader(batch_size=1) or trainer.data.train_loader(batch_size=1)
    if val_loader is None:
        raise gr.Error("Dataset too small for evaluation.")

    loss, nmse = _evaluate_channel_prediction(trainer, val_loader)

    # Render the visualization from the first batch only.
    first_batch = next(iter(val_loader))
    tokens, _, H, W = trainer._prepare_batch(first_batch)
    trainer.viz.save(trainer.model, tokens, H, W, trainer.args.prediction)
    rendered = sorted(pred_viz_dir.glob("sample_*.png"))
    viz_path = str(rendered[0]) if rendered else None

    # Floor the value before the log to avoid log10(0).
    nmse_db = 10.0 * math.log10(max(nmse, 1e-12))
    summary = f"Eval batches={len(val_loader)} | loss={loss:.4e} | NMSE={nmse:.4e} ({nmse_db:.2f} dB)"
    return summary, viz_path
# ---------------------------------------------------------------------------
# Masking gallery (toy visualization for different modes)
# ---------------------------------------------------------------------------
def mask_gallery_action(
    T: int,
    H: int,
    W: int,
    mask_mode: str,
    mask_ratio: float,
    num_examples: int,
) -> Tuple[Optional[str], str]:
    """Render one example mask as a row of per-frame panels and report coverage.

    Args:
        T: Number of frames; clamped to [1, 12].
        H: Grid height; clamped to [1, 32].
        W: Grid width; clamped to [1, 32].
        mask_mode: One of "auto", "random", "rect", "tube", "comb".
        mask_ratio: Target mask ratio forwarded to ``MaskArgs``.
        num_examples: Number of masks to draw; clamped to [1, 8].

    Returns:
        ``(png_path, summary)``. The PNG shows the first drawn mask; the
        coverage figure is the mean masked fraction over *all* drawn masks.

    Raises:
        gr.Error: If ``mask_mode`` is not a supported mode.
    """
    T = max(1, min(12, int(T)))
    H = max(1, min(32, int(H)))
    W = max(1, min(32, int(W)))
    if mask_mode not in {"auto", "random", "rect", "tube", "comb"}:
        raise gr.Error("Unknown mask mode. Choose from auto, random, rect, tube, comb.")
    num_examples = max(1, min(8, int(num_examples)))

    device = torch.device("cpu")
    generator = MaskGenerator(MaskArgs(mask_ratio=mask_ratio, mask_mode=mask_mode, random_fraction=0.3))

    # Draw several masks to estimate coverage; only the first one is visualized.
    # num_examples >= 1 after clamping, so the list is never empty.
    masks = [
        generator(T, H, W, device=device).cpu().numpy().astype(float)  # 1 = masked
        for _ in range(num_examples)
    ]
    mask_seq = masks[0]  # indexed per frame below — assumes (T, H, W) layout, TODO confirm
    # Average over every drawn sample so the reported number matches the
    # "over N samples" wording in the title and summary.
    avg_cov = float(np.mean([m.mean() for m in masks]))

    # Dark-mode palette: only masked tokens are colored.
    cmap = ListedColormap(
        [
            "#020617",  # 0: background (unmasked)
            "#1d4ed8",  # 1: masked (deep blue)
        ]
    )
    out_path = MASK_VIZ_DIR / f"mask_gallery_{int(time.time())}.png"
    fig, axes = plt.subplots(1, T, figsize=(3 * T, 3), dpi=180)
    try:
        fig.patch.set_facecolor("#020617")
        # plt.subplots returns a bare Axes when T == 1; normalize to an array.
        for t, ax in enumerate(np.atleast_1d(axes)):
            ax.set_facecolor("#020617")
            ax.imshow(
                mask_seq[t],
                cmap=cmap,
                origin="lower",
                aspect="auto",
                vmin=0.0,
                vmax=1.0,
            )
            # Minor ticks on cell edges give a grid that outlines every bin.
            ax.set_xticks(np.arange(-0.5, W, 1), minor=True)
            ax.set_yticks(np.arange(-0.5, H, 1), minor=True)
            ax.grid(which="minor", color="#6b7280", linestyle="-", linewidth=0.8, alpha=0.9)
            ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
            ax.set_title(f"t = {t}", fontsize=9, pad=6, color="#e5e7eb")
        fig.suptitle(
            f"Mask mode = {mask_mode} | T={T}, H={H}, W={W}, target≈{mask_ratio:.2f}, "
            f"observed≈{avg_cov:.2f} over {num_examples} samples",
            fontsize=11,
            color="#f9fafb",
        )
        fig.tight_layout(rect=(0, 0, 1, 0.95))
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=180)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)

    summary = f"Mode: {mask_mode} | average masked tokens ≈ {avg_cov * 100:.1f}% over {num_examples} samples."
    return str(out_path), summary
# ---------------------------------------------------------------------------
# SSTA vs full attention visualization
# ---------------------------------------------------------------------------
def _token_to_indices(token_idx: int, H: int, W: int) -> Tuple[int, int, int]:
t = token_idx // (H * W)
rem = token_idx % (H * W)
h = rem // W
w = rem % W
return t, h, w
def _save_attention_overlay_figure(overlays: np.ndarray, title: str, out_path: Path) -> None:
    """Draw one dark-mode panel per frame of ``overlays`` and save it to ``out_path``.

    ``overlays`` is a ``(T, H, W)`` integer array with cell codes
    0 = background, 1 = attended token, 2 = query token.
    """
    num_frames, height, width = overlays.shape
    palette = ListedColormap(
        [
            "#020617",  # 0: background (very dark)
            "#22d3ee",  # 1: neighbors
            "#f97316",  # 2: query
        ]
    )
    fig, axes = plt.subplots(1, num_frames, figsize=(3 * num_frames, 3), dpi=180)
    try:
        fig.patch.set_facecolor("#020617")
        # plt.subplots returns a bare Axes for a single frame; normalize to an array.
        for t, ax in enumerate(np.atleast_1d(axes)):
            ax.set_facecolor("#020617")
            ax.imshow(
                overlays[t],
                cmap=palette,
                origin="lower",
                aspect="auto",
                vmin=0,
                vmax=2,
            )
            # Minor ticks on cell edges give a grid that outlines every bin.
            ax.set_xticks(np.arange(-0.5, width, 1), minor=True)
            ax.set_yticks(np.arange(-0.5, height, 1), minor=True)
            ax.grid(which="minor", color="#6b7280", linestyle="-", linewidth=0.8, alpha=0.9)
            ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
            ax.set_title(f"t = {t}", fontsize=10, pad=6, color="#e5e7eb")
        fig.suptitle(title, fontsize=12, color="#f9fafb")
        fig.tight_layout(rect=(0, 0, 1, 0.9))
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)


def ssta_attention_action(
    T: int,
    H: int,
    W: int,
    same_window: int,
    temporal_window: int,
    drift: int,
    ssta_mode: str,
    offsets_text: str,
    query_t: int,
    query_h: int,
    query_w: int,
) -> Tuple[Optional[str], Optional[str], str]:
    """Visualize the SSTA neighborhood of one query token next to full attention.

    Args:
        T, H, W: Toy grid size; clamped to [1, 12], [1, 32] and [1, 32].
        same_window: Same-frame window; clamped to [0, 3].
        temporal_window: Spatial window on other frames; clamped to [0, 4].
        drift: Temporal drift used for both axes; clamped to [0, 4].
        ssta_mode: Mode label; any label containing "causal" (case-insensitive)
            selects causal attention, anything else selects bidirectional.
        offsets_text: Comma-separated integer temporal offsets.
        query_t, query_h, query_w: Query coordinates; clamped into the grid.

    Returns:
        ``(ssta_png_path, full_attention_png_path, summary)``.

    Raises:
        gr.Error: If ``offsets_text`` contains a non-integer entry.
    """
    T = max(1, min(12, int(T)))
    H = max(1, min(32, int(H)))
    W = max(1, min(32, int(W)))
    same_window = max(0, min(3, int(same_window)))
    temporal_window = max(0, min(4, int(temporal_window)))
    drift = max(0, min(4, int(drift)))

    try:
        # A cleared textbox may arrive as None; treat it like an empty string.
        offsets = [int(x.strip()) for x in (offsets_text or "").split(",") if x.strip()]
    except ValueError as exc:
        raise gr.Error("Offsets must be comma-separated integers.") from exc

    is_causal = "causal" in (ssta_mode or "").lower()
    if is_causal:
        # Causal: query only attends to current/past frames (non-positive offsets).
        offsets = [o for o in offsets if o <= 0]
        if not offsets:
            # Nothing usable was supplied. This check must run before 0 is
            # appended, otherwise the fallback can never trigger.
            offsets = [0, -1]
        elif 0 not in offsets:
            offsets.append(0)
    else:
        # Bidirectional: allow both past and future context, but drop pure zero.
        offsets = [o for o in offsets if o != 0]
        if not offsets:
            offsets = [-1, 1]
    offsets = sorted(set(offsets))

    config = LWMConfig(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=1,
        num_heads=1,
        mlp_ratio=2.0,
        same_frame_window=same_window,
        temporal_offsets=tuple(offsets),
        temporal_spatial_window=temporal_window,
        temporal_drift_h=drift,
        temporal_drift_w=drift,
        routing_topk_enable=False,
        spatial_only=False,
    )
    indexer = NeighborIndexer()
    neighbors = indexer.get(T, H, W, False, config, torch.device("cpu"))

    seq_len = T * H * W
    qt = max(0, min(T - 1, int(query_t)))
    qh = max(0, min(H - 1, int(query_h)))
    qw = max(0, min(W - 1, int(query_w)))
    q_idx = qt * H * W + qh * W + qw

    # Mark every token listed as a neighbor of the query; negative entries are
    # skipped (presumably padding for out-of-range neighbors — TODO confirm).
    local = neighbors[q_idx]
    mask = torch.zeros(seq_len, dtype=torch.bool)
    mask[local[local >= 0]] = True
    mask_3d = mask.view(T, H, W).numpy().astype(bool)

    # Overlay codes: 0 = background, 1 = attended, 2 = query (drawn on top).
    ssta_overlay = mask_3d.astype(int)
    ssta_overlay[qt, qh, qw] = 2
    full_overlay = np.ones((T, H, W), dtype=int)  # full attention sees every token
    full_overlay[qt, qh, qw] = 2

    mode_title = "Causal SSTA (prediction-time)" if is_causal else "Bidirectional SSTA (pretraining)"
    ts = int(time.time())
    ssta_path = ATTN_VIZ_DIR / f"ssta_{ts}.png"
    full_path = ATTN_VIZ_DIR / f"full_attn_{ts}.png"
    _save_attention_overlay_figure(
        ssta_overlay,
        f"{mode_title} — query (t={qt}, h={qh}, w={qw})",
        ssta_path,
    )
    _save_attention_overlay_figure(
        full_overlay,
        f"Full attention — query (t={qt}, h={qh}, w={qw})",
        full_path,
    )

    mode_expl = (
        "Bidirectional SSTA is used during pretraining to learn generic space–frequency–time dependencies."
        if not is_causal
        else "Causal SSTA is used for generative tasks like channel prediction, where future tokens are not visible."
    )
    summary = (
        f"Mode: {'Causal' if is_causal else 'Bidirectional'} SSTA. {mode_expl}\n"
        f"T={T}, H={H}, W={W}. Sparse neighbors: {int(mask.sum())}/{seq_len} tokens.\n"
        f"Offsets={offsets}, same_frame_window={same_window}, "
        f"temporal_window={temporal_window}, drift={drift}."
    )
    return str(ssta_path), str(full_path), summary
# ---------------------------------------------------------------------------
# UI assembly
# ---------------------------------------------------------------------------
def build_lab() -> gr.Blocks:
    """Assemble the LWM-Temporal lab UI and return the ``gr.Blocks`` app.

    The app has six tabs (scenario generation, angle–delay visuals, masked
    modeling, channel prediction, masking modes, SSTA). Each tab wires its
    controls to the matching ``*_action`` callback. ``scenario_state`` is
    written by the scenario tab and passed as an input to the
    visualization/evaluation tabs.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` instance.
    """
    # NOTE(review): container nesting below follows the inline layout comments;
    # confirm against the deployed UI.

    # Traffic defaults, shared with the UE slider so its initial range matches
    # what `_update_ue_slider_from_counts` would compute.
    default_vehicles = 300
    default_pedestrians = 50
    # Default toy-grid size for the SSTA tab.
    default_t, default_h, default_w = 8, 8, 8

    with gr.Blocks(title="LWM-Temporal Lab") as demo:
        gr.Markdown(
            "\n"
            "# LWM-Temporal Interactive Lab\n"
            "A playground for dynamic DeepMIMO scenarios, angle–delay evolution,\n"
            "zero-shot performances, masking strategies, and Sparse Spatio-Temporal Attention (SSTA)\n"
        )
        scenario_state = gr.State()

        with gr.Tab("Scenario & Data"):
            gr.Markdown(
                "Design a **dynamic wireless world** in a few clicks.\n\n"
                "LWM-Temporal extends the DeepMIMO library with a fully-flexible dynamic scenario generator, so you can:\n"
                "- Mix **realistic ray-traced environments** from engines like *Wireless InSite*, *Sionna*, and *AODT* from cities all around the globe with DeepMIMO v4 (https://deepmimo.net/)\n"
                "- Configure the **antenna array** (horizontal / vertical elements) and **number of subcarriers**\n"
                "- Control the **temporal resolution** (time steps and sampling rate) of your channel sequence\n"
                "- Shape the **traffic**: number of vehicles and pedestrians, their speeds, and turn probability\n"
                "- Tune the **road geometry** (road width and center spacing) and the **delay-bin keep percentage** in the angle–delay domain\n\n"
                "These controls let you pretrain LWM-Temporal at scale on diverse, richly-parameterized scenarios before moving to downstream tasks"
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    scenario_family = gr.Dropdown(
                        choices=[
                            "General DeepMIMO v4 scenarios",
                            "LWM-Temporal train scenarios",
                            "LWM-Temporal test scenarios",
                        ],
                        value="General DeepMIMO v4 scenarios",
                        label="Scenario family",
                    )
                    general_subfamily = gr.Dropdown(
                        choices=DEEPMIMO_GENERAL_SUBFAMILIES,
                        value=DEEPMIMO_GENERAL_SUBFAMILIES[0],
                        label="DeepMIMO v4 sub-family",
                        visible=True,
                    )
                    scenario_input = gr.Dropdown(
                        choices=DEEPMIMO_V4_SCENARIOS,
                        value=DEFAULT_SCENARIOS[16],
                        label="Selected DeepMIMO scenario",
                        interactive=True,
                        allow_custom_value=True,
                    )
                    gr.Markdown(
                        "We provide three scenario families:\n"
                        "- **General DeepMIMO v4 scenarios**: Full DeepMIMO v4 catalog\n"
                        "- **LWM-Temporal train scenarios**: Dense scenarios with 10 cm grid spacing\n"
                        "- **LWM-Temporal test scenarios**: Held-out dense scenarios for robust evaluation\n\n"
                        "Pick a family, then choose a specific scenario or type any valid DeepMIMO v4 scenario name"
                    )
                    time_steps_input = gr.Slider(1, 64, value=11, step=1, label="Time steps")
                    sample_dt_input = gr.Slider(1e-3, 1e-1, value=1e-3, step=1e-4, label="Sample Δt (s)")
                    vehicle_input = gr.Slider(1, 1000, value=default_vehicles, step=10, label="Vehicles")
                    pedestrian_input = gr.Slider(0, 100, value=default_pedestrians, step=5, label="Pedestrians")
                    turn_prob_input = gr.Slider(0.0, 0.4, value=0.1, step=0.01, label="Turn probability")
                    speed_input = gr.Slider(0, 150, value=20, step=5, label="Vehicle speed max (km/h)")
                    road_width_input = gr.Slider(1.0, 6.0, value=2.0, step=0.5, label="Road width (m)")
                    road_spacing_input = gr.Slider(6.0, 40.0, value=8.0, step=1.0, label="Road center spacing (m)")
                    keep_pct_input = gr.Slider(0.1, 1.0, value=0.25, step=0.05, label="Angle-delay keep %")
                    with gr.Row():
                        overwrite_toggle = gr.Checkbox(value=False, label="Overwrite cached files")
                        download_toggle = gr.Checkbox(value=True, label="Download scenario if missing")
                    run_button = gr.Button("Generate scenario", variant="primary")
                with gr.Column(scale=2):
                    env_image = gr.Image(
                        type="filepath",
                        label="Environment & Dynamic UEs",
                        interactive=False,
                        height=400,
                    )
                    env_plot = gr.Plot(
                        label="Dynamic UEs (pan & zoom)",
                    )
                    scenario_summary = gr.Markdown(label="Scenario status")
                    payload_stats = gr.Textbox(lines=5, label="Payload summary", max_lines=6)

            # When the scenario family changes, update the available choices in the scenario dropdown.
            def _update_scenarios(family: str, subfamily: str) -> Any:
                # Guard against a cleared dropdown (None), as _toggle_general_subfamily does.
                family_label = (family or "").lower()
                if "train" in family_label:
                    choices = LWM_TEMPORAL_TRAIN_SCENARIOS
                elif "test" in family_label:
                    choices = LWM_TEMPORAL_TEST_SCENARIOS
                else:
                    # General DeepMIMO v4 scenarios, further split into sub-families
                    label = (subfamily or "").lower()
                    if "sionna" in label or "_s" in label:
                        choices = DEEPMIMO_SIONNA_SCENARIOS
                    elif "3.5" in label or "3p5" in label:
                        choices = DEEPMIMO_WI_CITY_3P5_SCENARIOS
                    elif "28" in label:  # also matches "28ghz"
                        choices = DEEPMIMO_WI_CITY_28_SCENARIOS
                    elif "other" in label:
                        choices = DEEPMIMO_WI_OTHER_SCENARIOS
                    else:
                        choices = DEEPMIMO_V4_SCENARIOS
                value = choices[0] if choices else None
                # Use generic gr.update for compatibility with older Gradio versions.
                return gr.update(choices=choices, value=value)

            # Both the family and the sub-family determine the scenario list.
            for trigger in (scenario_family, general_subfamily):
                trigger.change(
                    fn=_update_scenarios,
                    inputs=[scenario_family, general_subfamily],
                    outputs=scenario_input,
                )

            # Show the sub-family dropdown only when the general DeepMIMO family is active.
            def _toggle_general_subfamily(family: str) -> Any:
                is_general = "general deepmimo v4" in (family or "").lower()
                return gr.update(visible=is_general)

            scenario_family.change(
                fn=_toggle_general_subfamily,
                inputs=scenario_family,
                outputs=general_subfamily,
            )
            run_button.click(
                fn=generate_scenario_action,
                inputs=[
                    scenario_state,
                    scenario_input,
                    time_steps_input,
                    sample_dt_input,
                    vehicle_input,
                    pedestrian_input,
                    turn_prob_input,
                    speed_input,
                    road_width_input,
                    road_spacing_input,
                    keep_pct_input,
                    overwrite_toggle,
                    download_toggle,
                ],
                outputs=[scenario_summary, payload_stats, env_image, env_plot, scenario_state],
            )

        with gr.Tab("Angle–Delay Visuals"):
            gr.Markdown(
                "Explore the **angle–delay domain** where LWM-Temporal actually thinks\n\n"
                "- Represent rich **space–frequency channels** in a compact, sparse grid over angle and delay\n"
                "- Work directly with the **physical multipath structure** (rays and echoes) instead of opaque frequency-domain tensors\n"
                "- Train LWM-Temporal in a **physics-informed** way and interpret its attention maps over angle–delay bins\n\n"
                "Use this tab to compare how a single UE’s channel evolves over time in both continuous and discretized angle–delay form"
            )
            with gr.Column():
                # First row: controls (dropdowns / sliders / button)
                with gr.Row():
                    # Initial range matches the default traffic counts (0-based index).
                    ue_slider = gr.Slider(
                        0,
                        max(0, default_vehicles + default_pedestrians - 1),
                        value=0,
                        step=1,
                        label="UE index",
                    )
                    keep_slider = gr.Slider(0.1, 1.0, value=0.25, step=0.05, label="Keep percentage")
                    viz_button = gr.Button("Render visualizations")
                # Second row: three figures
                with gr.Row():
                    angle_delay_gif = gr.Image(
                        type="filepath",
                        label="Angle–delay GIF",
                        interactive=False,
                        height=320,
                    )
                    channel_gif = gr.Image(
                        type="filepath",
                        label="Channel GIF",
                        interactive=False,
                        height=320,
                    )
                    curve_png = gr.Image(
                        type="filepath",
                        label="Dominant bin traces",
                        interactive=False,
                        height=320,
                    )
                # Third row: status / stats under the figures
                viz_summary = gr.Textbox(lines=3, label="Status")

            viz_button.click(
                fn=visualize_angle_delay_action,
                inputs=[scenario_state, ue_slider, keep_slider],
                outputs=[angle_delay_gif, channel_gif, curve_png, viz_summary],
            )

            # Dynamically adapt the UE index slider max to the number of users
            # (vehicles + pedestrians) chosen in the first tab.
            def _update_ue_slider_from_counts(
                vehicles: float,
                pedestrians: float,
                current_value: float,
            ) -> Any:
                total_users = max(0, int(vehicles) + int(pedestrians))
                # UE indices are 0-based, so the maximum valid index is total_users - 1.
                max_idx = max(0, total_users - 1)
                clamped_value = max(0, min(int(current_value), max_idx))
                return gr.update(maximum=max_idx, value=clamped_value)

            for count_slider in (vehicle_input, pedestrian_input):
                count_slider.change(
                    fn=_update_ue_slider_from_counts,
                    inputs=[vehicle_input, pedestrian_input, ue_slider],
                    outputs=ue_slider,
                )

        with gr.Tab("Zero-Shot Channel Estimation"):
            gr.Markdown(
                "Peek under the hood of **physics-informed pretraining** with Masked Channel Modeling (MCM)\n\n"
                "- LWM-Temporal builds on our earlier LWM 1.0 / 1.1 models and extends their **self-supervised MCM** pretraining to dynamic, angle–delay channels\n"
                "- MCM hides parts of the channel tensor and asks the model to reconstruct them, forcing it to internalize the **underlying propagation physics** instead of just memorizing patterns\n"
                "- Here, MCM becomes **physics-informed** via four complementary masking strategies (random, rectangular, tube, comb) that stress-test different aspects of the channel\n\n"
                "Use this tab to run a quick **zero-shot reconstruction check** on a single batch after pretraining with our bidirectional SSTA attention"
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    mask_ratio_input = gr.Slider(0.1, 0.9, value=0.6, step=0.05, label="Mask ratio")
                    batch_size_input = gr.Slider(1, 32, value=8, step=1, label="Batch size")
                    max_batches_input = gr.Slider(1, 10, value=1, step=1, label="Max batches")
                    device_radio = gr.Radio(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
                    masked_button = gr.Button("Run masked modeling")
                with gr.Column(scale=1):
                    masked_output = gr.Textbox(lines=4, label="NMSE summary")
            masked_button.click(
                fn=masked_modeling_action,
                inputs=[scenario_state, mask_ratio_input, batch_size_input, device_radio, max_batches_input],
                outputs=masked_output,
            )

        with gr.Tab("Zero-Shot Channel Prediction"):
            gr.Markdown(
                "Turn LWM-Temporal into a **generative predictor** for dynamic channels\n\n"
                "- Beyond understanding tasks, many downstream applications are **temporal and generative** (e.g., predicting future channels or strongest beams)\n"
                "- This tab showcases **channel prediction**: given T_past observed channels, the model predicts the next T_horizon future channels\n"
                "- For this task, we fine-tune the pretrained backbone with **causal SSTA**, so the model only attends to current and past frames—never to the future\n\n"
                "Use this tab to inspect the **zero-shot prediction performance** of the model on your selected scenario"
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    tpast_slider = gr.Slider(4, 12, value=10, step=1, label="Tpast")
                    horizon_slider = gr.Slider(1, 4, value=1, step=1, label="Horizon")
                    cp_device = gr.Radio(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
                    train_limit_slider = gr.Slider(1, 200, value=1, step=1, label="Context pool (train split)")
                    val_limit_slider = gr.Slider(1, 100, value=1, step=1, label="Context pool (val split)")
                    channel_button = gr.Button("Run channel prediction inference")
                with gr.Column(scale=1):
                    channel_summary = gr.Textbox(lines=4, label="Inference summary")
                    channel_image = gr.Image(
                        type="filepath",
                        label="Prediction visualization",
                        interactive=False,
                        height=260,
                    )
            channel_button.click(
                fn=channel_prediction_action,
                inputs=[scenario_state, tpast_slider, horizon_slider, cp_device, train_limit_slider, val_limit_slider],
                outputs=[channel_summary, channel_image],
            )

        with gr.Tab("Masking Modes"):
            gr.Markdown(
                "Toy example: Visualize the **four physics-informed masking strategies** used during pretraining\n\n"
                "- **Random**: masks tokens randomly across the angle–delay grid\n"
                "- **Rectangular**: masks contiguous rectangular regions of the grid\n"
                "- **Tube**: masks tubes of tokens along the delay axis (capturing delay-local structures)\n"
                "- **Comb**: masks comb-like patterns along the angle axis (capturing angular diversity)\n\n"
                "Use this tab to see how each strategy sculpts the information given to the model and how robustly it can reconstruct the masked channels"
            )
            with gr.Row():
                T_mask = gr.Slider(2, 12, value=6, step=1, label="Frames (T)")
                H_mask = gr.Slider(4, 32, value=16, step=1, label="Angle bins (H)")
                W_mask = gr.Slider(4, 32, value=16, step=1, label="Delay bins (W)")
            with gr.Row():
                mask_mode_dd = gr.Dropdown(
                    choices=["auto", "random", "rect", "tube", "comb"],
                    value="auto",
                    label="Masking mode",
                )
                gallery_ratio = gr.Slider(0.1, 0.9, value=0.6, step=0.05, label="Target mask ratio")
                gallery_examples = gr.Slider(1, 4, value=3, step=1, label="Examples to draw")
            gallery_button = gr.Button("Render mask gallery")
            gallery_image = gr.Image(
                type="filepath",
                label="Mask gallery",
                interactive=False,
                height=260,
            )
            gallery_summary = gr.Textbox(lines=4, label="Coverage summary")
            gallery_button.click(
                fn=mask_gallery_action,
                inputs=[T_mask, H_mask, W_mask, mask_mode_dd, gallery_ratio, gallery_examples],
                outputs=[gallery_image, gallery_summary],
            )

        with gr.Tab("Sparse Spatio-Temporal Attention (SSTA)"):
            gr.Markdown(
                "Toy example: visualize how Sparse Spatio-Temporal Attention (SSTA) selects neighbors for a single query token and how it differs from full attention in efficiency and physics-informedness\n\n"
                "- **Bidirectional SSTA** (pretraining): attends to both past and future frames to learn generic spatial–spectral–temporal dependencies\n"
                "- **Causal SSTA** (generative downstream tasks): attends only to current and past frames (no future leaks) for generative tasks like channel prediction"
            )
            # Controls on top
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    T_slider = gr.Slider(1, 11, value=default_t, step=1, label="Frames (T)")
                    H_slider = gr.Slider(4, 32, value=default_h, step=1, label="Angle bins (H)")
                    W_slider = gr.Slider(4, 32, value=default_w, step=1, label="Delay bins (W)")
                with gr.Column(scale=1):
                    same_window_slider = gr.Slider(0, 3, value=1, step=1, label="Same-frame window")
                    temporal_window_slider = gr.Slider(0, 4, value=1, step=1, label="Temporal spatial window")
                    drift_slider = gr.Slider(0, 4, value=1, step=1, label="Temporal drift")
                    ssta_mode_radio = gr.Radio(
                        choices=[
                            "Bidirectional SSTA (pretraining)",
                            "Causal SSTA (generative downstream tasks)",
                        ],
                        value="Bidirectional SSTA (pretraining)",
                        label="Attention mode",
                    )
                with gr.Column(scale=1):
                    offsets_text = gr.Textbox(value="-2,-1,0,1,2", label="Temporal offsets (comma separated)")
                    # Query indices are 0-based, so the largest valid index is size - 1.
                    query_t_slider = gr.Slider(0, default_t - 1, value=3, step=1, label="Query t (time index)")
                    query_h_slider = gr.Slider(0, default_h - 1, value=4, step=1, label="Query angle index (h)")
                    query_w_slider = gr.Slider(0, default_w - 1, value=5, step=1, label="Query delay index (w)")
                    attention_button = gr.Button("Visualize attention neighborhood")
            # Visualization underneath: SSTA on top, full attention below, summary at the end.
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    attention_ssta_image = gr.Image(
                        type="filepath",
                        label="SSTA neighborhood: O(N)",
                        interactive=False,
                        height=200,
                    )
                    attention_full_image = gr.Image(
                        type="filepath",
                        label="Full attention: O(N^2)",
                        interactive=False,
                        height=200,
                    )
                    attention_summary = gr.Textbox(lines=4, label="Neighborhood summary")

            # Keep query sliders in-range whenever T, H, or W sliders change.
            def _sync_query_index(dim_size: float, current_q: float) -> Any:
                # 0-based index: the maximum is dim_size - 1, matching the
                # clamping done inside ssta_attention_action.
                max_idx = max(0, int(dim_size) - 1)
                clamped = max(0, min(int(current_q), max_idx))
                return gr.update(maximum=max_idx, value=clamped)

            for dim_slider, query_slider in (
                (T_slider, query_t_slider),
                (H_slider, query_h_slider),
                (W_slider, query_w_slider),
            ):
                dim_slider.change(
                    fn=_sync_query_index,
                    inputs=[dim_slider, query_slider],
                    outputs=query_slider,
                )

            attention_button.click(
                fn=ssta_attention_action,
                inputs=[
                    T_slider,
                    H_slider,
                    W_slider,
                    same_window_slider,
                    temporal_window_slider,
                    drift_slider,
                    ssta_mode_radio,
                    offsets_text,
                    query_t_slider,
                    query_h_slider,
                    query_w_slider,
                ],
                outputs=[attention_ssta_image, attention_full_image, attention_summary],
            )

        gr.Markdown(
            "\n"
            "---\n"
            "## Notes\n"
            "- Steps reuse cached artifacts on disk so you can iterate without re-running the entire pipeline.\n"
            "- Heavy stages (channel prediction) run with conservative defaults. Increase limits as needed.\n"
            "- All artifacts are written under `figs/lab`, `examples/data/lab`, and `checkpoints/`.\n"
        )
    return demo
# Built at import time, so `demo` exists as a module attribute even when this
# file is imported rather than run as a script.
demo = build_lab()

if __name__ == "__main__":
    # Listen on all interfaces on port 7860, request a share link, surface
    # callback errors in the UI, and allow files under /figs to be served.
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "allowed_paths": ["/figs"],
    }
    demo.launch(**launch_options)