# NOTE: Hugging Face Spaces status-banner text ("Spaces: Running") was
# accidentally pasted at the top of this file; removed — it is not code.
| from __future__ import annotations | |
| import math | |
| import pickle | |
| import shutil | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from matplotlib.colors import ListedColormap | |
| import imageio | |
| import numpy as np | |
| import plotly.graph_objects as go # type: ignore | |
| import deepmimo | |
| import subprocess | |
| import sys | |
| import os | |
# Upgrade Gradio at startup so the Space runs a known-good UI version; all pip
# output is routed to /dev/null to keep the container logs readable.
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "--upgrade", "--quiet", "--no-cache-dir",
    "gradio>=6.0", "gradio_client>=2.0"
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
# Silence Python warnings emitted while Hugging Face probes optional imports.
os.environ["PYTHONWARNINGS"] = "ignore"
| import gradio as gr # type: ignore | |
| import torch # type: ignore | |
| from torch.utils.data import DataLoader # type: ignore | |
| # Install DeepMIMO if not already installed | |
| # try: | |
| # import deepmimo | |
| # except ImportError: | |
| # print("Installing DeepMIMO from GitHub...") | |
| # subprocess.check_call([ | |
| # sys.executable, "-m", "pip", "install", | |
| # "git+https://github.com/DeepMIMO/DeepMIMO.git" | |
| # ]) | |
| # import deepmimo | |
| # import sys, subprocess, sysconfig, pathlib | |
| # from pathlib import Path | |
| # def ensure_deepmimo(): | |
| # try: | |
| # import deepmimo # type: ignore # already installed and importable | |
| # return | |
| # except ModuleNotFoundError: | |
| # print("Installing DeepMIMO from GitHub...") | |
| # subprocess.check_call([ | |
| # sys.executable, | |
| # "-m", | |
| # "pip", | |
| # "install", | |
| # "git+https://github.com/DeepMIMO/DeepMIMO.git" | |
| # ]) | |
| # # ---- Patch DeepMIMO's dataset.py to avoid the union-typing bug ---- | |
| # site_pkgs = ( | |
| # Path(sys.executable).parent.parent | |
| # / "lib" | |
| # / f"python{sys.version_info.major}.{sys.version_info.minor}" | |
| # / "site-packages" | |
| # ) | |
| # dataset_py = site_pkgs / "deepmimo" / "generator" / "dataset.py" | |
| # if dataset_py.exists(): | |
| # txt = dataset_py.read_text() | |
| # # handle both variants just in case | |
| # before_1 = "def append(self, dataset: Dataset | 'MacroDataset'):" | |
| # before_2 = "def append(self, dataset: Dataset | MacroDataset):" | |
| # if before_1 in txt or before_2 in txt: | |
| # print("Patching DeepMIMO MacroDataset.append type annotation...") | |
| # txt = txt.replace(before_1, "def append(self, dataset):") | |
| # txt = txt.replace(before_2, "def append(self, dataset):") | |
| # dataset_py.write_text(txt) | |
| # import deepmimo # type: ignore | |
| # print("DeepMIMO imported successfully.") | |
| # # Call this once near the top of app.py | |
| # ensure_deepmimo() | |
import builtins
# Force DeepMIMO to never ask questions: a headless Space has no TTY, so any
# interactive prompt would hang the app forever.
def no_input(*args, **kwargs):
    """Replacement for builtins.input that auto-answers "y" to every prompt."""
    return "y"  # auto-answer "yes" to everything
builtins.input = no_input
# Presumably read by DeepMIMO to skip interactive flows — confirm against
# the DeepMIMO docs; harmless if ignored.
os.environ["DEEPMIMO_NO_INTERACTIVE"] = "1"
| # --------------------------------------------------------------------------- | |
| # Import LWMTemporal either from an installed package or from a private | |
| # Hugging Face repo (useful when this Space is separate from the main codebase) | |
| # --------------------------------------------------------------------------- | |
| try: | |
| from LWMTemporal.data.angle_delay import AngleDelayConfig, AngleDelayProcessor | |
| from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset | |
| from LWMTemporal.data.scenario_generation import ( | |
| AntennaArrayConfig, | |
| DynamicScenarioGenerator, | |
| GridConfig, | |
| ScenarioGenerationConfig, | |
| ScenarioSamplingConfig, | |
| TrafficConfig, | |
| ) | |
| from LWMTemporal.models.lwm import LWMBackbone, LWMConfig, NeighborIndexer, masked_nmse_loss | |
| from LWMTemporal.tasks.channel_prediction import ( | |
| ChannelPredictionArgs, | |
| ChannelPredictionTrainer, | |
| DatasetArgs, | |
| ModelArgs, | |
| PredictionArgs, | |
| TrainingArgs, | |
| compute_nmse, | |
| ) | |
| from LWMTemporal.tasks.pretraining import MaskArgs, MaskGenerator | |
| from LWMTemporal.utils.logging import setup_logging | |
| except ImportError: | |
| # Fallback: clone from a private HF repo and import from there. | |
| # Configure via environment: | |
| # LWM_TEMPORAL_REPO_ID (e.g. "your-username/lwm-temporal") | |
| # LWM_TEMPORAL_REVISION (optional branch/tag/commit) | |
| # HF_TOKEN must have access to the private repo. | |
| import os | |
| import sys | |
| from huggingface_hub import snapshot_download | |
| BASE_DIR = Path(__file__).resolve().parents[1] | |
| REMOTE_DIR = BASE_DIR / "_remote_lwm_temporal" | |
| REPO_ID = os.getenv("LWM_TEMPORAL_REPO_ID", "wi-lab/lwm-temporal") | |
| REVISION = os.getenv("LWM_TEMPORAL_REVISION") or None | |
| if not REMOTE_DIR.exists(): | |
| snapshot_download( | |
| repo_id=REPO_ID, | |
| revision=REVISION, | |
| repo_type="model", | |
| local_dir=REMOTE_DIR, | |
| token=os.getenv("HF_TOKEN"), | |
| ) | |
| # snapshot_download( | |
| # repo_id=REPO_ID, | |
| # revision=REVISION, | |
| # repo_type="model", | |
| # local_dir=REMOTE_DIR, | |
| # token=os.getenv("HF_TOKEN"), | |
| # allow_patterns=[ | |
| # "*.bin", | |
| # ], | |
| # ) | |
| sys.path.insert(0, str(REMOTE_DIR)) | |
| from LWMTemporal.data.angle_delay import AngleDelayConfig, AngleDelayProcessor | |
| from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset | |
| from LWMTemporal.data.scenario_generation import ( | |
| AntennaArrayConfig, | |
| DynamicScenarioGenerator, | |
| GridConfig, | |
| ScenarioGenerationConfig, | |
| ScenarioSamplingConfig, | |
| TrafficConfig, | |
| ) | |
| from LWMTemporal.models.lwm import LWMBackbone, LWMConfig, NeighborIndexer, masked_nmse_loss | |
| from LWMTemporal.tasks.channel_prediction import ( | |
| ChannelPredictionArgs, | |
| ChannelPredictionTrainer, | |
| DatasetArgs, | |
| ModelArgs, | |
| PredictionArgs, | |
| TrainingArgs, | |
| compute_nmse, | |
| ) | |
| from LWMTemporal.tasks.pretraining import MaskArgs, MaskGenerator | |
| from LWMTemporal.utils.logging import setup_logging | |
| # --------------------------------------------------------------------------- | |
| # DeepMIMO download helper (optional) | |
| # --------------------------------------------------------------------------- | |
| try: # first try the installed Python package name | |
| from deepmimo import download as deepmimo_download # type: ignore | |
| except Exception: # pragma: no cover | |
| try: # fallback to the vendored package path in this repo | |
| from DeepMIMO.deepmimo import download as deepmimo_download # type: ignore | |
| except Exception: # pragma: no cover | |
| deepmimo_download = None # type: ignore | |
| print("REMOTE_DIR =", REMOTE_DIR) | |
| for root, dirs, files in os.walk(REMOTE_DIR): | |
| print(root) | |
| for f in files: | |
| print(" -", f) | |
# Directory layout for datasets, figures, and DeepMIMO scenario folders.
ROOT = Path(__file__).resolve().parents[1]
DATA_DIR = ROOT / "examples" / "data" / "lab"
FULL_DATA_DIR = ROOT / "examples" / "full_data" / "lab"
FIG_DIR = ROOT / "figs" / "lab"
SCENARIO_DIR = ROOT / "deepmimo_scenarios"
# Prefer the checkpoint bundled with the downloaded snapshot. REMOTE_DIR only
# exists when the ImportError fallback ran; when LWMTemporal was importable
# directly, fall back to the local checkout's checkpoints directory instead of
# crashing with a NameError.
if "REMOTE_DIR" in globals():
    CHECKPOINT_PATH = Path(REMOTE_DIR) / "checkpoints" / "pytorch_model.bin"
else:
    CHECKPOINT_PATH = ROOT / "checkpoints" / "pytorch_model.bin"
if not CHECKPOINT_PATH.exists():
    raise FileNotFoundError(
        f"Pretrained model not found at {CHECKPOINT_PATH}\n"
        "Make sure your private repo 'wi-lab/lwm-temporal' contains checkpoints/pytorch_model.bin"
    )
# Per-feature figure subdirectories.
CHANNEL_PRED_VIZ = FIG_DIR / "channel_prediction"
ANGLE_DELAY_VIZ = FIG_DIR / "angle_delay"
ENV_FIG_DIR = FIG_DIR / "environment"
MASK_VIZ_DIR = FIG_DIR / "masking"
ATTN_VIZ_DIR = FIG_DIR / "attention"
# Create every output directory up front so downstream writers never fail.
for directory in [
    DATA_DIR,
    FULL_DATA_DIR,
    FIG_DIR,
    CHANNEL_PRED_VIZ,
    ANGLE_DELAY_VIZ,
    ENV_FIG_DIR,
    MASK_VIZ_DIR,
    ATTN_VIZ_DIR,
]:
    directory.mkdir(parents=True, exist_ok=True)
# Static catalogue of DeepMIMO v4 scenario names selectable in the UI.
# Naming convention (inferred from the sub-family filters below): trailing
# "3p5" = 3.5 GHz carrier, "28" = 28 GHz, and a "_s" suffix marks a
# Sionna-generated variant. Names must match the upstream DeepMIMO ids
# exactly — do not "fix" apparent typos (e.g. "instanbul").
DEFAULT_SCENARIOS: List[str] = [
    # Wireless InSite city/campus scenarios at 3.5 GHz
    "asu_campus_3p5",
    "city_0_newyork_3p5",
    "city_1_losangeles_3p5",
    "city_2_chicago_3p5",
    "city_3_houston_3p5",
    "city_4_phoenix_3p5",
    "city_5_philadelphia_3p5",
    "city_6_miami_3p5",
    "city_7_sandiego_3p5",
    "city_8_dallas_3p5",
    "city_9_sanfrancisco_3p5",
    "city_10_austin_3p5",
    "city_11_santaclara_3p5",
    "city_12_fortworth_3p5",
    "city_13_columbus_3p5",
    "city_14_charlotte_3p5",
    "city_15_indianapolis_3p5",
    "city_16_sanfrancisco_3p5",
    "city_17_seattle_3p5",
    "city_18_denver_3p5",
    "city_19_oklahoma_3p5",
    "city_21_taito_city_3p5",
    "city_22_rome_3p5",
    "city_23_beijing_3p5",
    "city_24_cairo_3p5",
    "city_25_instanbul_3p5",  # (sic — upstream scenario id)
    "city_26_dubai_3p5",
    "city_27_rio_de_janeiro_3p5",
    "city_28_mumbai_3p5",
    "city_29_centro_3p5",
    "city_30_singapore_3p5",
    "city_31_barcelona_3p5",
    "city_32_sydney_3p5",
    "city_33_hong_kong_3p5",
    "city_34_amsterdam_3p5",
    "city_35_san_francisco_3p5",
    "city_36_bangkok_3p5",
    "city_37_seoul_3p5",
    "city_38_toronto_3p5",
    "city_39_jerusalem_3p5",
    "city_40_prague_3p5",
    "city_41_kyoto_3p5",
    "city_42_san_nicolas_3p5",
    "city_43_cape_town_3p5",
    "city_44_lisboa_3p5",
    "city_45_stockholm_3p5",
    "city_46_la_habana_3p5",
    "city_47_chicago_3p5",
    "city_48_gurbchen_3p5",
    "city_49_new_delhi_3p5",
    "city_50_edinburgh_3p5",
    "city_51_firenze_3p5",
    "city_52_marrakesh_3p5",
    "city_53_sankt-peterburg_3p5",
    "city_54_north_jakarta_3p5",
    "city_55_madrid_3p5",
    "city_56_montreal_3p5",
    "city_57_taipei_3p5",
    "city_58_sumida_city_3p5",
    "city_59_roma_3p5",
    "city_60_toronto_3p5",
    "city_61_hatsukaichi_3p5",
    "city_62_santiago_3p5",
    "city_63_athens_3p5",
    "city_64_new_york_3p5",
    "city_65_granada_3p5",
    "city_66_bruxelles_3p5",
    "city_67_fujiyoshida_3p5",
    "city_68_reykjavik_3p5",
    "city_69_warszawa_3p5",
    "city_70_stockholm_3p5",
    "city_71_helsinki_3p5",
    "city_72_capetown_3p5",
    "city_73_casablanca_3p5",
    "city_74_chiyoda_3p5",
    "city_75_dongcheng_3p5",
    "city_76_houston_3p5",
    "city_77_melbourne_3p5",
    "city_78_newtaipei_3p5",
    "city_79_parow_3p5",
    "city_80_philadelphia_3p5",
    "city_81_seoul_3p5",
    "city_82_tempe_3p5",
    # Wireless InSite city scenarios at 28 GHz
    "city_0_newyork_28",
    "city_1_losangeles_28",
    "city_2_chicago_28",
    "city_3_houston_28",
    "city_4_phoenix_28",
    "city_5_philadelphia_28",
    "city_6_miami_28",
    "city_7_sandiego_28",
    "city_8_dallas_28",
    "city_9_sanfrancisco_28",
    "city_10_austin_28",
    "city_11_santaclara_28",
    "city_12_fortworth_28",
    "city_13_columbus_28",
    "city_14_charlotte_28",
    "city_15_indianapolis_28",
    "city_16_sanfrancisco_28",
    "city_17_seattle_28",
    "city_18_denver_28",
    "city_19_oklahoma_28",
    # Boston 5G scenarios (both carriers)
    "boston5g_3p5",
    "boston5g_28",
    # Sionna-generated city variants (suffix "_s")
    "city_0_newyork_3p5_s",
    "city_1_losangeles_3p5_s",
    "city_2_chicago_3p5_s",
    "city_3_houston_3p5_s",
    "city_4_phoenix_3p5_s",
    "city_5_philadelphia_3p5_s",
    "city_6_miami_3p5_s",
    "city_7_sandiego_3p5_s",
    "city_8_dallas_3p5_s",
    "city_9_sanfrancisco_3p5_s",
    "city_10_austin_3p5_s",
    "city_11_santaclara_3p5_s",
    "city_12_fortworth_3p5_s",
    "city_13_columbus_3p5_s",
    "city_14_charlotte_3p5_s",
    "city_15_indianapolis_3p5_s",
    "city_16_sanfrancisco_3p5_s",
    "city_17_seattle_3p5_s",
    "city_18_denver_3p5_s",
    "city_19_oklahoma_3p5_s",
    # Canonical indoor (i*) / outdoor (o*) scenarios
    "i1_2p5",
    "i2_28b",
    "o1_3p5",
    "o1_3p4",
    "o1b_3p5",
    "o1b_28",
    "o1_140",
    "i1_2p4",
    "o1_28",
    "o1_60",
    "o1_drone_200",
]
# LWM-Temporal curated scenario subsets
# Ten 3.5 GHz city scenarios forming the curated training split.
LWM_TEMPORAL_TRAIN_SCENARIOS: List[str] = [
    "city_72_capetown_3p5",
    "city_73_casablanca_3p5",
    "city_74_chiyoda_3p5",
    "city_75_dongcheng_3p5",
    "city_76_houston_3p5",
    "city_77_melbourne_3p5",
    "city_78_newtaipei_3p5",
    "city_79_parow_3p5",
    "city_80_philadelphia_3p5",
    "city_81_seoul_3p5",
]
# Single held-out scenario for evaluation.
LWM_TEMPORAL_TEST_SCENARIOS: List[str] = [
    "city_82_tempe_3p5",
]
# Alias: the full catalogue doubles as the "all v4" family.
DEEPMIMO_V4_SCENARIOS: List[str] = DEFAULT_SCENARIOS
# General DeepMIMO v4 sub-families
# Sionna-generated variants carry the "_s" suffix.
DEEPMIMO_SIONNA_SCENARIOS: List[str] = [
    name for name in DEEPMIMO_V4_SCENARIOS if name.endswith("_s")
]
# Prefixes identifying Wireless InSite city-style scenarios.
_WI_CITY_PREFIXES = ("city", "asu", "boston")
DEEPMIMO_WI_CITY_SCENARIOS: List[str] = [
    name
    for name in DEEPMIMO_V4_SCENARIOS
    if not name.endswith("_s") and name.startswith(_WI_CITY_PREFIXES)
]
# Carrier-frequency splits of the city family ("3p5" vs "28" suffixes).
DEEPMIMO_WI_CITY_3P5_SCENARIOS: List[str] = [
    name for name in DEEPMIMO_WI_CITY_SCENARIOS if name.endswith("3p5")
]
DEEPMIMO_WI_CITY_28_SCENARIOS: List[str] = [
    name for name in DEEPMIMO_WI_CITY_SCENARIOS if name.endswith("28")
]
# Everything else (indoor/outdoor canonical scenarios, drones, ...).
DEEPMIMO_WI_OTHER_SCENARIOS: List[str] = [
    name
    for name in DEEPMIMO_V4_SCENARIOS
    if not name.endswith("_s") and name not in DEEPMIMO_WI_CITY_SCENARIOS
]
# Labels shown in the UI for the sub-family selector.
DEEPMIMO_GENERAL_SUBFAMILIES: List[str] = [
    "All DeepMIMO v4 scenarios",
    "Sionna: City scenarios (3.5 GHz)",
    "Wireless InSite: City scenarios (3.5 GHz)",
    "Wireless InSite: City scenarios (28 GHz)",
    "Wireless InSite: Other scenarios",
]
# Shared magma colormap handle. NOTE(review): its use is not visible in this
# chunk — the name suggests mock/preview imagery; confirm downstream.
MOCK_CMAP = plt.get_cmap("magma")
| # --------------------------------------------------------------------------- | |
| # General helpers | |
| # --------------------------------------------------------------------------- | |
| def _resolve_device(choice: str) -> str: | |
| if choice == "auto": | |
| return "cuda" if torch.cuda.is_available() else "cpu" | |
| return choice | |
| def _ensure_state(state: Optional[Dict[str, Any]]) -> Dict[str, Any]: | |
| if not state: | |
| raise gr.Error("Generate or load a scenario first.") | |
| return state | |
| def _load_payload(full_path: Path) -> Dict[str, Any]: | |
| with full_path.open("rb") as handle: | |
| payload = pickle.load(handle) | |
| if not isinstance(payload, dict): | |
| raise gr.Error("Expected a dictionary payload. Regenerate the scenario.") | |
| return payload | |
| def _summarize_payload(payload: Dict[str, Any]) -> str: | |
| if isinstance(payload, dict): | |
| channel = payload.get("channel_discrete") | |
| if channel is None: | |
| channel = payload.get("channel") | |
| else: | |
| channel = payload | |
| if channel is None: | |
| return "Payload missing channel tensor." | |
| shape = tuple(channel.shape) | |
| los = payload.get("los") if isinstance(payload, dict) else None | |
| los_classes = "" | |
| if los is not None: | |
| unique, counts = torch.unique(torch.as_tensor(los), return_counts=True) | |
| los_classes = ", ".join(f"{int(u.item())}: {int(c.item())}" for u, c in zip(unique, counts)) | |
| keys = ", ".join(sorted(payload.keys())) if isinstance(payload, dict) else "n/a" | |
| return ( | |
| f"Channel tensor shape: {shape}\n" | |
| f"Payload keys: {keys}\n" | |
| + (f"LoS distribution: {los_classes}\n" if los_classes else "") | |
| ) | |
def _angle_delay_paths(run_tag: str) -> Tuple[Path, Path, Path]:
    """Create the per-run angle-delay figure directory and return its file paths.

    Returns (angle-delay GIF path, channel GIF path, bin-evolution PNG path).
    """
    base = ANGLE_DELAY_VIZ / run_tag
    base.mkdir(parents=True, exist_ok=True)
    return (base / "angle_delay.gif", base / "channel.gif", base / "bins.png")
def _save_angle_delay_gif_lab(
    tensor: torch.Tensor,
    output_path: Path,
    fps: int = 6,
    pause_seconds: float = 0.7,
) -> None:
    """Save an angle–delay GIF using the same style as the notebook helper.

    We re-implement the non-interactive branch here so that Gradio receives a
    true multi-frame GIF instead of a single-frame PNG.

    Args:
        tensor: complex channel volume; first dim is the time/frame axis.
            Assumes shape (frames, angle_bins, delay_bins) — TODO confirm.
        output_path: destination .gif file (parent dirs created on demand).
        fps: playback rate for the GIF writer.
        pause_seconds: extra hold on the final frame before the loop restarts.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    tensor = tensor.to(torch.complex64)
    magnitude = tensor.abs().cpu()
    # Fixed color scale across all frames so intensity is comparable over time.
    vmin, vmax = float(magnitude.min()), float(magnitude.max())
    frames: List[np.ndarray] = []
    for frame_idx in range(magnitude.size(0)):
        # Dark-theme styling to match the app's overall look.
        fig, ax = plt.subplots(figsize=(8, 6))
        fig.patch.set_facecolor("#0b0e11")
        ax.set_facecolor("#0b0e11")
        ax.tick_params(colors="#cbd5f5")
        for spine in ax.spines.values():
            spine.set_color("#374151")
        im = ax.imshow(
            magnitude[frame_idx].numpy(),
            cmap="magma",
            origin="lower",
            aspect="auto",
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_xlabel("Delay bins", color="#cbd5f5")
        ax.set_ylabel("Angle bins", color="#cbd5f5")
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.ax.yaxis.set_tick_params(color="#cbd5f5")
        plt.setp(cbar.ax.get_yticklabels(), color="#cbd5f5")
        # NOTE(review): the label says dB but the data plotted is |H| — confirm
        # whether a log transform is expected upstream.
        cbar.set_label("|H| (dB)", color="#cbd5f5")
        ax.set_title(
            f"Angle-Delay Intensity — Frame {frame_idx}",
            color="#f8fafc",
            fontsize=12,
            fontweight="semibold",
        )
        # Rasterize the canvas into an RGBA frame for the GIF writer.
        fig.canvas.draw()
        frames.append(np.asarray(fig.canvas.buffer_rgba()))
        plt.close(fig)
    # Duplicate the last frame so the animation appears to "pause" before looping.
    # This is more robust across different GIF viewers than relying on per-frame durations.
    extra_frames = max(1, int(round(max(pause_seconds, 0.0) * max(fps, 1))))
    if frames and extra_frames > 0:
        frames.extend([frames[-1]] * extra_frames)
    # Ensure the GIF loops indefinitely (loop=0 is "infinite" in GIF metadata).
    imageio.mimsave(output_path, frames, fps=fps, loop=0)
def _add_pause_to_gif(
    gif_path: Path,
    fps: int = 5,
    pause_seconds: float = 0.7,
) -> None:
    """Post-process an existing GIF to add a small pause at the end of each loop.

    Implemented by duplicating the last frame a few times, which works even when
    viewers ignore per-frame duration metadata. Missing or unreadable GIFs are
    silently left untouched (best-effort cosmetic step).
    """
    gif_path = Path(gif_path)
    if not gif_path.exists():
        return
    try:
        loaded = imageio.mimread(gif_path)
    except Exception:
        return
    if not loaded:
        return
    repeat = max(1, int(round(max(pause_seconds, 0.0) * max(fps, 1))))
    loaded = loaded + [loaded[-1]] * repeat
    imageio.mimsave(gif_path, loaded, fps=fps, loop=0)
def _build_dataset(raw_path: Path, keep_percentage: float) -> AngleDelaySequenceDataset:
    """Wrap a raw payload file in an AngleDelaySequenceDataset with the app defaults."""
    dataset_cfg = AngleDelayDatasetConfig(
        raw_path=raw_path,
        keep_percentage=keep_percentage,
        normalize="per_sample_rms",
        cache_dir=ROOT / "cache",
        patch_size=(1, 1),
        phase_mode="real_imag",
    )
    return AngleDelaySequenceDataset(dataset_cfg)
def _refresh_scenarios() -> Tuple[List[str], str]:
    """Return the bundled scenario names plus a status message for the UI."""
    return (DEFAULT_SCENARIOS or ["asu_campus_3p5"]), "Loaded static scenario list."
def _ensure_scenario_available(scenario: str, auto_download: bool) -> None:
    """Make sure the DeepMIMO scenario folder exists locally.

    Downloads it on demand when *auto_download* is set; otherwise raises a
    user-facing gr.Error explaining what to do.
    """
    if (SCENARIO_DIR / scenario).exists():
        return
    if not auto_download:
        raise gr.Error(
            f"Scenario '{scenario}' not found under '{SCENARIO_DIR}'. Enable the download toggle to fetch it automatically."
        )
    if deepmimo_download is None:
        raise gr.Error(
            "DeepMIMO download helper is unavailable. Install DeepMIMO dependencies or place the scenario manually."
        )
    try:
        deepmimo_download(scenario, output_dir=str(SCENARIO_DIR))
    except Exception as exc:
        raise gr.Error(f"Failed to download scenario '{scenario}': {exc}") from exc
| # --------------------------------------------------------------------------- | |
| # Angle-delay helpers (ported from examples/ad_temporal_evolutiton.py) | |
| # --------------------------------------------------------------------------- | |
| # def pick_bins(sequence: torch.Tensor, k: int) -> List[Tuple[int, int]]: | |
| # if sequence.ndim != 3: | |
| # raise ValueError(f"Expected angle-delay tensor (T, H, W); got {tuple(sequence.shape)}") | |
| # _, H, W = sequence.shape | |
| # mag = sequence.abs().mean(dim=0) | |
| # topk = torch.topk(mag.flatten(), k=min(k, H * W)) | |
| # picks: List[Tuple[int, int]] = [] | |
| # for idx in topk.indices.tolist(): | |
| # n = idx // W | |
| # m = idx % W | |
| # picks.append((n, m)) | |
| # return picks | |
| # def plot_curves(sequence: torch.Tensor, picks: List[Tuple[int, int]], out_path: Path, title: str) -> None: | |
| # if not picks: | |
| # raise ValueError("No bins selected for plotting.") | |
| # T = sequence.size(0) | |
| # fig, axes = plt.subplots(len(picks), 2, figsize=(12, 3 * len(picks)), dpi=150) | |
| # axes = np.atleast_2d(axes) | |
| # times = np.arange(T) | |
| # for row, (n, m) in enumerate(picks): | |
| # series = sequence[:, n, m] | |
| # mag = series.abs().cpu().numpy() | |
| # phase = np.unwrap(torch.angle(series).cpu().numpy()) | |
| # ax_mag, ax_phase = axes[row] | |
| # ax_mag.imshow(mag[None, :], aspect="auto", cmap="magma") | |
| # ax_mag.set_title(f"|H| for bin (n={n}, m={m})") | |
| # ax_mag.set_yticks([]) | |
| # ax_mag.set_xlabel("time index") | |
| # ax_phase.plot(times, phase, color="#f87171") | |
| # ax_phase.set_title("Phase (rad)") | |
| # ax_phase.set_xlabel("time index") | |
| # ax_phase.grid(True, alpha=0.3) | |
| # fig.suptitle(title) | |
| # out_path.parent.mkdir(parents=True, exist_ok=True) | |
| # fig.tight_layout() | |
| # fig.savefig(out_path) | |
| # plt.close(fig) | |
| # --------------------------------------------------------------------------- | |
| # Scenario generation + visualization | |
| # --------------------------------------------------------------------------- | |
def generate_scenario_action(
    prev_state: Optional[Dict[str, Any]],
    scenario: str,
    time_steps: int,
    sample_dt: float,
    vehicles: int,
    pedestrians: int,
    turn_prob: float,
    vehicle_speed_max: float,
    road_width: float,
    road_spacing: float,
    keep_percentage: float,
    overwrite: bool,
    auto_download: bool,
) -> Tuple[str, str, Optional[str], Optional[go.Figure], Dict[str, Any]]:
    """Generate a dynamic DeepMIMO scenario and build its UI artifacts.

    Returns the 5-tuple consumed by the Gradio outputs:
    (markdown summary, payload stats text, optional static environment PNG
    path, optional interactive Plotly figure, new app state dict).

    NOTE(review): *prev_state*'s signature and the *overwrite* flag are
    currently ignored — regeneration is always forced (see should_overwrite).
    """
    if not scenario:
        raise gr.Error("Select or enter a scenario name, then try again. Click 'Refresh scenario list' if needed.")
    progress = gr.Progress(track_tqdm=False)
    progress(0.05, desc="Configuring scenario")
    _ensure_scenario_available(scenario, auto_download)
    # Signature of every generation-relevant knob; stored in the returned state
    # so callers could detect parameter changes between runs.
    signature = "|".join(
        map(
            str,
            [
                scenario,
                time_steps,
                sample_dt,
                vehicles,
                pedestrians,
                turn_prob,
                vehicle_speed_max,
                road_width,
                road_spacing,
                keep_percentage,
            ],
        )
    )
    # NOTE(review): computed but unused while regeneration is forced below.
    prev_signature = (prev_state or {}).get("signature")
    # Always regenerate the scenario to avoid reusing cached payloads.
    should_overwrite = True
    env_dir = ENV_FIG_DIR / scenario
    config = ScenarioGenerationConfig(
        scenario=scenario,
        antenna=AntennaArrayConfig(tx_horizontal=32, tx_vertical=1, subcarriers=32),
        sampling=ScenarioSamplingConfig(time_steps=time_steps, sample_dt=sample_dt),
        traffic=TrafficConfig(
            num_vehicles=vehicles,
            num_pedestrians=pedestrians,
            turn_probability=turn_prob,
            # UI speed is presumably km/h; /3.6 converts to m/s — confirm.
            vehicle_speed_range=(0 / 3.6, vehicle_speed_max / 3.6),
        ),
        grid=GridConfig(
            road_width=float(road_width),
            road_center_spacing=float(road_spacing),
        ),
        output_dir=DATA_DIR,
        full_output_dir=FULL_DATA_DIR,
        figures_dir=env_dir,
        scenarios_dir=SCENARIO_DIR,
        deepmimo_max_paths=25,
    )
    generator = DynamicScenarioGenerator(config)
    progress(0.2, desc="Running generator")
    result = generator.generate(overwrite=should_overwrite)
    payload = result.payload
    summary = (
        f"{'Generated' if result.generated else 'Loaded cached'} scenario **{scenario}**\n"
        f"- time steps: {time_steps}\n"
        f"- vehicles/pedestrians: {vehicles}/{pedestrians}\n"
        f"- output: `{result.output_path.name}`\n"
    )
    stats = _summarize_payload(payload)
    # Static PNG environment figure (from the library helper).
    env_path = env_dir / "environment.png"
    env_display_path: Optional[str] = None
    if env_path.exists():
        # Copy under a timestamped name so the browser never shows a stale image.
        stamped = ENV_FIG_DIR / f"{scenario}_env_{int(time.time())}.png"
        try:
            shutil.copy(env_path, stamped)
            env_display_path = str(stamped)
        except Exception:
            env_display_path = str(env_path)
    # Interactive "map-like" environment figure using Plotly (zoom & pan + selection).
    env_plot: Optional[go.Figure] = None
    try:
        pos_cont = payload.get("pos_cont")
        x_vals: Optional[np.ndarray] = None
        y_vals: Optional[np.ndarray] = None
        # Continuous positions
        if pos_cont is not None:
            arr = np.asarray(pos_cont)
            if arr.ndim >= 2 and arr.shape[-1] >= 2:
                x_vals = arr[..., 0].reshape(-1)
                y_vals = arr[..., 1].reshape(-1)
        # Fallback to discrete trajectories
        if x_vals is None or y_vals is None:
            pos_disc = payload.get("pos_discrete")
            if pos_disc is not None:
                all_pts: List[np.ndarray] = []
                for traj in pos_disc:
                    t_arr = np.asarray(traj)
                    if t_arr.ndim >= 2 and t_arr.shape[-1] >= 2:
                        all_pts.append(t_arr[..., :2].reshape(-1, 2))
                if all_pts:
                    stacked = np.concatenate(all_pts, axis=0)
                    x_vals = stacked[:, 0]
                    y_vals = stacked[:, 1]
        if x_vals is not None and y_vals is not None and x_vals.size:
            env_plot = go.Figure()
            env_plot.add_trace(
                go.Scattergl(
                    x=x_vals,
                    y=y_vals,
                    mode="markers",
                    marker=dict(
                        size=4,
                        opacity=0.7,
                        color="#38bdf8",  # cyan-ish blue
                        line=dict(width=0),
                    ),
                    name="Users / tracks",
                    hovertemplate="x: %{x:.2f} m<br>y: %{y:.2f} m<extra></extra>",
                    # Make box/lasso selection visually obvious:
                    selected=dict(
                        marker=dict(
                            size=6,
                            opacity=1.0,
                            color="#f97316",  # orange for selected
                        )
                    ),
                    unselected=dict(
                        marker=dict(
                            opacity=0.15,
                        )
                    ),
                )
            )
            env_plot.update_layout(
                # Title intentionally omitted; kept here for easy re-enabling.
                # title=dict(
                #     text="Environment overview",
                #     x=0.02,  # left-align
                #     xanchor="left",
                #     y=0.96,
                #     yanchor="top",
                #     font=dict(size=18, color="#f9fafb"),
                # ),
                template="plotly_dark",
                font=dict(
                    family="Inter, system-ui, -apple-system, BlinkMacSystemFont, sans-serif",
                    size=14,
                    color="#e5e7eb",
                ),
                paper_bgcolor="#020617",
                plot_bgcolor="#020617",
                margin=dict(l=40, r=20, t=50, b=40),
                height=550,
                hovermode="closest",
                # Use "zoom" (default) so zoom/box-select/lasso all work via toolbar
                dragmode="zoom",
                showlegend=False,
            )
            env_plot.update_xaxes(
                title="x (m)",
                showgrid=True,
                gridcolor="rgba(148, 163, 184, 0.25)",
                zeroline=False,
                showline=True,
                linecolor="rgba(148, 163, 184, 0.6)",
                mirror=True,
                # Lock the aspect ratio so the map is not distorted.
                scaleanchor="y",
                scaleratio=1,
            )
            env_plot.update_yaxes(
                title="y (m)",
                showgrid=True,
                gridcolor="rgba(148, 163, 184, 0.25)",
                zeroline=False,
                showline=True,
                linecolor="rgba(148, 163, 184, 0.6)",
                mirror=True,
            )
    except Exception:
        # Best-effort: the static PNG still renders if the Plotly figure fails.
        env_plot = None
    state = {
        "scenario": scenario,
        "slim_path": str(result.output_path),
        "full_path": str(result.full_output_path),
        "keep_percentage": keep_percentage,
        "time_steps": time_steps,
        "signature": signature,
    }
    return summary, stats, env_display_path, env_plot, state
def visualize_angle_delay_action(
    state: Optional[Dict[str, Any]],
    ue_index: int,
    keep_percentage: float,
) -> Tuple[Optional[str], Optional[str], Optional[str], str]:
    """Render angle-delay/channel animations and the bin-evolution plot for one UE.

    Returns (angle-delay GIF path, channel GIF path, bins PNG path, status text).
    """
    info = _ensure_state(state)
    payload = _load_payload(Path(info["full_path"]))
    num_users = torch.as_tensor(payload["channel_discrete"]).size(0)
    # Clamp the requested UE index into the valid range.
    user = max(0, min(num_users - 1, int(ue_index)))
    processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=keep_percentage))
    gif_path, chan_gif_path, bins_path = _angle_delay_paths(
        f"{info['scenario']}_ue{user}_{int(time.time())}"
    )
    h_discrete = torch.as_tensor(payload["channel_discrete"][user])
    h_cont = torch.as_tensor(payload["channel_cont"][user])
    trimmed, _ = processor.truncate_delay_bins(processor.forward(h_discrete))
    # Explicit GIF writer so Gradio receives a true multi-frame animation.
    _save_angle_delay_gif_lab(trimmed, gif_path, fps=6, pause_seconds=0.7)
    # Channel animation from the package helper, then add a small loop pause.
    processor.save_channel_animation(h_cont, chan_gif_path, show=False)
    _add_pause_to_gif(chan_gif_path, fps=6, pause_seconds=0.7)
    # Dominant-bin evolution plot straight from the library for exact parity.
    processor.save_bin_evolution_plot(h_discrete, bins_path, show=False)
    status = (
        f"Rendered UE {user} / {num_users}.\n"
        f"Angle-delay GIF: `{gif_path.name}` | Channel GIF: `{chan_gif_path.name}`"
    )
    return str(gif_path), str(chan_gif_path), str(bins_path), status
| # --------------------------------------------------------------------------- | |
| # Masked modeling sanity check | |
| # --------------------------------------------------------------------------- | |
def masked_modeling_action(
    state: Optional[Dict[str, Any]],
    mask_ratio: float,
    batch_size: int,
    device_choice: str,
    max_batches: int,
) -> str:
    """Run a masked-reconstruction sanity check of the pretrained LWM backbone.

    Loads the slim dataset referenced by *state*, randomly masks a fraction of
    tokens per sample, reconstructs them with the pretrained backbone, and
    returns an item-weighted average NMSE over at most *max_batches* batches.
    """
    info = _ensure_state(state)
    raw_path = Path(info["slim_path"])
    if not raw_path.exists():
        raise gr.Error(f"Dataset not found at {raw_path}")
    if not CHECKPOINT_PATH.exists():
        raise gr.Error(f"Checkpoint {CHECKPOINT_PATH} missing. Train or copy a backbone first.")
    dataset = _build_dataset(raw_path, info["keep_percentage"])
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    if not len(dataset):
        raise gr.Error("Dataset is empty. Regenerate the scenario.")
    # Derive the token budget (T*H*W) from the first sample's stored shape.
    example_shape = dataset[0]["shape"]
    T, H, W = (int(example_shape[dim].item()) for dim in range(3))
    tokens_per_sample = T * H * W
    device = torch.device(_resolve_device(device_choice))
    # NOTE(review): these hyper-parameters must match the configuration the
    # checkpoint was pretrained with — confirm against the training script.
    cfg = LWMConfig(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        same_frame_window=2,
        temporal_offsets=(-4, -3, -2, -1, 1, 2, 3),
        temporal_spatial_window=2,
        temporal_drift_h=1,
        temporal_drift_w=1,
        routing_topk_enable=True,
        topk_per_head=True,
        max_seq_len=tokens_per_sample,
    )
    backbone = LWMBackbone.from_pretrained(CHECKPOINT_PATH, config=cfg).to(device)
    backbone.eval()
    masker = MaskGenerator(MaskArgs(mask_ratio=mask_ratio, mask_mode="random"))
    total_nmse = 0.0
    total_items = 0
    processed = 0
    with torch.no_grad():
        for batch in loader:
            tokens = batch["tokens"].to(device)
            base_mask = batch["base_mask"].to(device)
            seq = batch["sequence"].to(device)
            # Re-read the per-batch dimensions (shadows the dataset-level T/H/W).
            T, H, W = seq[0].shape
            B, _, _ = tokens.shape
            # One random mask per sample, OR-ed with the dataset's base mask.
            mask = torch.stack([masker(T, H, W, device).view(-1) for _ in range(B)])
            mask = torch.logical_or(mask, base_mask)
            # Zero out masked tokens, then reconstruct them with the backbone.
            corrupted = tokens.masked_fill(mask.unsqueeze(-1), 0.0)
            outputs = backbone.forward_tokens(corrupted, mask, T, H, W, return_cls=False)
            reconstructed = outputs["reconstruction"]
            nmse = masked_nmse_loss(reconstructed, tokens, mask)
            # Item-weighted accumulation so uneven final batches average correctly.
            total_nmse += nmse.item() * B
            total_items += B
            processed += 1
            if processed >= max_batches:
                break
    avg_nmse = total_nmse / max(1, total_items)
    # Clamp before log10 to avoid -inf on a perfect reconstruction.
    avg_nmse_db = 10.0 * math.log10(max(avg_nmse, 1e-12))
    return f"Processed {processed} batch(es). NMSE={avg_nmse:.4e} ({avg_nmse_db:.2f} dB)"
| # --------------------------------------------------------------------------- | |
| # Channel prediction inference | |
| # --------------------------------------------------------------------------- | |
def _evaluate_channel_prediction(trainer: ChannelPredictionTrainer, loader: DataLoader) -> Tuple[float, float]:
    """Evaluate autoregressive channel prediction over every batch in *loader*.

    Runs the trainer's rollout engine in eval/no-grad mode and returns the
    batch-averaged ``(loss, nmse)`` pair (0.0 for an empty loader).
    """
    trainer.model.eval()
    batch_losses: List[float] = []
    batch_nmses: List[float] = []
    with torch.no_grad():
        for sample in loader:
            tokens, _, height, width = trainer._prepare_batch(sample)
            preds, target, mask = trainer.engine.autoregressive_rollout(
                trainer.model,
                tokens,
                trainer.args.prediction.Tpast,
                trainer.args.prediction.horizon,
                height,
                width,
            )
            batch_losses.append(trainer._compute_loss(preds, target, mask).item())
            batch_nmses.append(compute_nmse(preds, target, mask))
    mean_loss = float(sum(batch_losses) / max(1, len(batch_losses)))
    mean_nmse = float(sum(batch_nmses) / max(1, len(batch_nmses)))
    return mean_loss, mean_nmse
def channel_prediction_action(
    state: Optional[Dict[str, Any]],
    tpast: int,
    horizon: int,
    device_choice: str,
    train_limit: int,
    val_limit: int,
) -> Tuple[str, Optional[str]]:
    """Run zero-shot channel-prediction inference on the cached scenario.

    Builds dataset/model/training/prediction configs from the UI inputs,
    constructs a ChannelPredictionTrainer around the pretrained checkpoint,
    evaluates autoregressive rollout on the validation (or train) split, and
    saves one prediction visualization.

    Returns:
        Tuple of (summary string, path to the first visualization PNG or None).

    Raises:
        gr.Error: If the checkpoint is missing or the dataset is too small.
    """
    info = _ensure_state(state)
    if not CHECKPOINT_PATH.exists():
        raise gr.Error(f"Checkpoint {CHECKPOINT_PATH} missing. Train or copy a backbone first.")
    dataset_args = DatasetArgs(
        data_path=Path(info["slim_path"]),
        keep_percentage=info["keep_percentage"],
        normalize="per_sample_rms",
        seed=0,
        train_limit=train_limit,
        val_limit=val_limit,
    )
    # Timestamped per-run output dir so repeated runs never overwrite frames.
    pred_viz_dir = CHANNEL_PRED_VIZ / f"{info['scenario']}_{int(time.time())}"
    model_args = ModelArgs(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        same_frame_window=2,
        # Causal context: only the Tpast frames strictly before the query.
        temporal_offsets=tuple(range(-tpast, 0)),
        temporal_spatial_window=2,
        temporal_drift_h=1,
        temporal_drift_w=1,
        routing_topk_enable=True,
        routing_topk_fraction=0.2,
        routing_topk_max=32,
        pretrained=CHECKPOINT_PATH,
    )
    training_args = TrainingArgs(
        device=_resolve_device(device_choice),
        epochs=1,
        batch_size=1,
        lr=1e-4,
        weight_decay=1e-4,
        # Inference only: no weights are updated in this tab.
        inference_only=True,
        inference_split="val",
        use_dataparallel=False,
        verbose_inference=False,
        save_dir=ROOT / "models",
    )
    prediction_args = PredictionArgs(
        Tpast=tpast,
        horizon=horizon,
        viz_dir=pred_viz_dir,
        num_visual_samples=1,
    )
    logger = setup_logging("lab_channel_prediction", log_dir=ROOT / "logs" / "lab")
    trainer = ChannelPredictionTrainer(
        ChannelPredictionArgs(
            dataset=dataset_args,
            model=model_args,
            training=training_args,
            prediction=prediction_args,
        ),
        logger=logger,
    )
    # Fall back to the train split when the val split is empty/falsy.
    val_loader = trainer.data.val_loader(batch_size=1) or trainer.data.train_loader(batch_size=1)
    if val_loader is None:
        raise gr.Error("Dataset too small for evaluation.")
    loss, nmse = _evaluate_channel_prediction(trainer, val_loader)
    # Render a single qualitative example from the first batch.
    sample_batch = next(iter(val_loader))
    tokens, _, H, W = trainer._prepare_batch(sample_batch)
    trainer.viz.save(trainer.model, tokens, H, W, trainer.args.prediction)
    viz_frames = sorted(pred_viz_dir.glob("sample_*.png"))
    viz_path = str(viz_frames[0]) if viz_frames else None
    # Floor at 1e-12 to keep log10 finite for near-perfect predictions.
    nmse_db = 10.0 * math.log10(max(nmse, 1e-12))
    summary = f"Eval batches={len(val_loader)} | loss={loss:.4e} | NMSE={nmse:.4e} ({nmse_db:.2f} dB)"
    return summary, viz_path
| # --------------------------------------------------------------------------- | |
| # Masking gallery (toy visualization for different modes) | |
| # --------------------------------------------------------------------------- | |
def mask_gallery_action(
    T: int,
    H: int,
    W: int,
    mask_mode: str,
    mask_ratio: float,
    num_examples: int,
) -> Tuple[Optional[str], str]:
    """Render one sampled mask and report average masked-token coverage.

    Draws ``num_examples`` random masks from :class:`MaskGenerator`, averages
    the masked fraction over ALL draws, and visualizes the first draw as a
    per-frame dark-mode grid.

    BUGFIX: the previous version claimed an average "over N samples" but
    computed coverage from the first mask only; coverage is now accumulated
    across every draw.

    Returns:
        Tuple of (path to the saved PNG, textual coverage summary).

    Raises:
        gr.Error: If the mask mode is unknown or no mask could be generated.
    """
    # Clamp inputs to the ranges the UI sliders are expected to deliver.
    T = max(1, min(12, int(T)))
    H = max(1, min(32, int(H)))
    W = max(1, min(32, int(W)))
    if mask_mode not in {"auto", "random", "rect", "tube", "comb"}:
        raise gr.Error("Unknown mask mode. Choose from auto, random, rect, tube, comb.")
    num_examples = max(1, min(8, int(num_examples)))
    device = torch.device("cpu")
    generator = MaskGenerator(MaskArgs(mask_ratio=mask_ratio, mask_mode=mask_mode, random_fraction=0.3))
    # Draw several random samples to estimate coverage, but visualize a single example
    first_mask: Optional[np.ndarray] = None
    coverage_total = 0.0
    for _ in range(num_examples):
        mask_tensor = generator(T, H, W, device=device).cpu()
        mask_np = mask_tensor.numpy().astype(float)  # 1 = masked
        coverage_total += float(mask_np.mean())  # accumulate EVERY draw
        if first_mask is None:
            first_mask = mask_np
    if first_mask is None:
        raise gr.Error("Failed to generate masks for visualization.")
    # Visualization: sequence of angle–delay channels in dark mode with only masked tokens colored
    cmap = ListedColormap(
        [
            "#020617",  # 0: background (unmasked)
            "#1d4ed8",  # 1: masked (deep blue)
        ]
    )
    fig, axes = plt.subplots(1, T, figsize=(3 * T, 3), dpi=180)
    fig.patch.set_facecolor("#020617")
    axes = np.atleast_1d(axes)
    mask_seq = first_mask.astype(float)  # (T, H, W)
    for t in range(T):
        ax = axes[t]
        frame = mask_seq[t]  # (H, W), 1 = masked, 0 = visible/white
        ax.set_facecolor("#020617")
        ax.imshow(
            frame,
            cmap=cmap,
            origin="lower",
            aspect="auto",
            vmin=0.0,
            vmax=1.0,
        )
        # brighter grid to emphasize bins on dark background
        ax.set_xticks(np.arange(-0.5, W, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, H, 1), minor=True)
        ax.grid(which="minor", color="#6b7280", linestyle="-", linewidth=0.8, alpha=0.9)
        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        ax.set_title(f"t = {t}", fontsize=9, pad=6, color="#e5e7eb")
    # True mean over all sampled masks (matches the "over N samples" claim).
    avg_cov = coverage_total / num_examples
    fig.suptitle(
        f"Mask mode = {mask_mode} | T={T}, H={H}, W={W}, target≈{mask_ratio:.2f}, "
        f"observed≈{avg_cov:.2f} over {num_examples} samples",
        fontsize=11,
        color="#f9fafb",
    )
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    out_path = MASK_VIZ_DIR / f"mask_gallery_{int(time.time())}.png"
    fig.savefig(out_path, dpi=180)
    plt.close(fig)
    summary = f"Mode: {mask_mode} | average masked tokens ≈ {avg_cov * 100:.1f}% over {num_examples} samples."
    return str(out_path), summary
| # --------------------------------------------------------------------------- | |
| # SSTA vs full attention visualization | |
| # --------------------------------------------------------------------------- | |
| def _token_to_indices(token_idx: int, H: int, W: int) -> Tuple[int, int, int]: | |
| t = token_idx // (H * W) | |
| rem = token_idx % (H * W) | |
| h = rem // W | |
| w = rem % W | |
| return t, h, w | |
def ssta_attention_action(
    T: int,
    H: int,
    W: int,
    same_window: int,
    temporal_window: int,
    drift: int,
    ssta_mode: str,
    offsets_text: str,
    query_t: int,
    query_h: int,
    query_w: int,
) -> Tuple[Optional[str], Optional[str], str]:
    """Visualize one query token's SSTA neighborhood vs. full attention.

    Builds a one-layer LWM config from the UI controls, asks the
    ``NeighborIndexer`` for the sparse neighbor set of the query token, and
    renders two dark-mode figures: the sparse SSTA neighborhood and the dense
    full-attention baseline.

    Changes vs. previous version: removed an unused synthetic angle–delay
    sequence (``t_axis``/``frames`` were computed but never rendered), removed
    redundant re-clamping of the already-clamped query indices, and factored
    the two identical figure-drawing loops into one shared renderer.

    Returns:
        (SSTA figure path, full-attention figure path, textual summary).

    Raises:
        gr.Error: If ``offsets_text`` is not a comma-separated integer list.
    """
    # Clamp grid sizes and attention windows to the UI slider ranges.
    T = max(1, min(12, int(T)))
    H = max(1, min(32, int(H)))
    W = max(1, min(32, int(W)))
    same_window = max(0, min(3, int(same_window)))
    temporal_window = max(0, min(4, int(temporal_window)))
    drift = max(0, min(4, int(drift)))
    try:
        offsets = [int(x.strip()) for x in offsets_text.split(",") if x.strip()]
    except ValueError as exc:
        raise gr.Error("Offsets must be comma-separated integers.") from exc
    # Normalize attention mode
    mode_str = (ssta_mode or "").lower()
    is_causal = "causal" in mode_str
    if is_causal:
        # Causal: query only attends to current/past frames (non-positive offsets).
        offsets = [o for o in offsets if o <= 0]
        if 0 not in offsets:
            offsets.append(0)
        if not offsets:
            offsets = [0, -1]
    else:
        # Bidirectional: allow both past and future context, but drop pure zero.
        offsets = [o for o in offsets if o != 0]
        if not offsets:
            offsets = [-1, 1]
    offsets = sorted(set(offsets))
    config = LWMConfig(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=1,
        num_heads=1,
        mlp_ratio=2.0,
        same_frame_window=same_window,
        temporal_offsets=tuple(offsets),
        temporal_spatial_window=temporal_window,
        temporal_drift_h=drift,
        temporal_drift_w=drift,
        routing_topk_enable=False,
        spatial_only=False,
    )
    indexer = NeighborIndexer()
    neighbors = indexer.get(T, H, W, False, config, torch.device("cpu"))
    seq_len = T * H * W
    # Clamp the query to a valid grid cell once, up-front.
    qt = max(0, min(T - 1, int(query_t)))
    qh = max(0, min(H - 1, int(query_h)))
    qw = max(0, min(W - 1, int(query_w)))
    q_idx = qt * H * W + qh * W + qw
    local = neighbors[q_idx]
    mask = torch.zeros(seq_len, dtype=torch.bool)
    # Negative entries are padding in the neighbor table; keep valid ones only.
    valid = local[local >= 0]
    mask[valid] = True
    mask_3d = mask.view(T, H, W).numpy().astype(bool)
    # Dark-mode overlay colors: dark background, neighbors cyan, query vivid orange
    neighbor_cmap = ListedColormap(
        [
            "#020617",  # 0: background (very dark)
            "#22d3ee",  # 1: neighbors
            "#f97316",  # 2: query
        ]
    )

    def _render(overlays: List[np.ndarray], title: str, out_path: Path) -> None:
        # Shared renderer for both figures: identical layout, grid, and
        # palette; only the per-frame overlays and the title differ.
        fig, axes = plt.subplots(1, T, figsize=(3 * T, 3), dpi=180)
        fig.patch.set_facecolor("#020617")
        axes = np.atleast_1d(axes)
        for t in range(T):
            ax = axes[t]
            ax.set_facecolor("#020617")
            ax.imshow(
                overlays[t],
                cmap=neighbor_cmap,
                origin="lower",
                aspect="auto",
                vmin=0,
                vmax=2,
            )
            ax.set_xticks(np.arange(-0.5, W, 1), minor=True)
            ax.set_yticks(np.arange(-0.5, H, 1), minor=True)
            # Brighter grid for better visibility on dark background
            ax.grid(which="minor", color="#6b7280", linestyle="-", linewidth=0.8, alpha=0.9)
            ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
            ax.set_title(f"t = {t}", fontsize=10, pad=6, color="#e5e7eb")
        fig.suptitle(title, fontsize=12, color="#f9fafb")
        fig.tight_layout(rect=(0, 0, 1, 0.9))
        fig.savefig(out_path)
        plt.close(fig)

    # Sparse SSTA overlays: 0 = background, 1 = neighbor, 2 = query.
    ssta_overlays: List[np.ndarray] = []
    for t in range(T):
        overlay = np.zeros((H, W), dtype=int)
        overlay[mask_3d[t]] = 1  # neighbors at this frame
        if t == qt:
            overlay[qh, qw] = 2  # query (already clamped above)
        ssta_overlays.append(overlay)
    mode_title = "Causal SSTA (prediction-time)" if is_causal else "Bidirectional SSTA (pretraining)"
    ts = int(time.time())
    ssta_path = ATTN_VIZ_DIR / f"ssta_{ts}.png"
    _render(ssta_overlays, f"{mode_title} — query (t={qt}, h={qh}, w={qw})", ssta_path)
    # Full-attention overlays: every token attended (1), query highlighted (2).
    full_overlays: List[np.ndarray] = []
    for t in range(T):
        overlay_full = np.ones((H, W), dtype=int)
        if t == qt:
            overlay_full[qh, qw] = 2
        full_overlays.append(overlay_full)
    full_path = ATTN_VIZ_DIR / f"full_attn_{ts}.png"
    _render(full_overlays, f"Full attention — query (t={qt}, h={qh}, w={qw})", full_path)
    mode_expl = (
        "Bidirectional SSTA is used during pretraining to learn generic space–frequency–time dependencies."
        if not is_causal
        else "Causal SSTA is used for generative tasks like channel prediction, where future tokens are not visible."
    )
    summary = (
        f"Mode: {'Causal' if is_causal else 'Bidirectional'} SSTA. {mode_expl}\n"
        f"T={T}, H={H}, W={W}. Sparse neighbors: {int(mask.sum())}/{seq_len} tokens.\n"
        f"Offsets={offsets}, same_frame_window={same_window}, "
        f"temporal_window={temporal_window}, drift={drift}."
    )
    return str(ssta_path), str(full_path), summary
| # --------------------------------------------------------------------------- | |
| # UI assembly | |
| # --------------------------------------------------------------------------- | |
def build_lab() -> gr.Blocks:
    """Assemble the Gradio UI for the LWM-Temporal interactive lab.

    Creates one tab per workflow (scenario generation, angle–delay visuals,
    zero-shot estimation, zero-shot prediction, masking gallery, SSTA
    visualization) and wires each tab's controls to its ``*_action`` callback.
    Scenario metadata flows between tabs through a single ``gr.State``.

    Fixes vs. previous version: corrected two user-facing typos
    ("performances" -> "performance", "preformance" -> "performance").

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="LWM-Temporal Lab") as demo:
        gr.Markdown(
            """
            # LWM-Temporal Interactive Lab
            A playground for dynamic DeepMIMO scenarios, angle–delay evolution,
            zero-shot performance, masking strategies, and Sparse Spatio-Temporal Attention (SSTA)
            """
        )
        # Shared scenario metadata (paths, keep percentage, scenario name)
        # produced by the first tab and consumed by the others.
        scenario_state = gr.State()
        with gr.Tab("Scenario & Data"):
            gr.Markdown(
                "Design a **dynamic wireless world** in a few clicks.\n\n"
                "LWM-Temporal extends the DeepMIMO library with a fully-flexible dynamic scenario generator, so you can:\n"
                "- Mix **realistic ray-traced environments** from engines like *Wireless InSite*, *Sionna*, and *AODT* from cities all around the globe with DeepMIMO v4 (https://deepmimo.net/)\n"
                "- Configure the **antenna array** (horizontal / vertical elements) and **number of subcarriers**\n"
                "- Control the **temporal resolution** (time steps and sampling rate) of your channel sequence\n"
                "- Shape the **traffic**: number of vehicles and pedestrians, their speeds, and turn probability\n"
                "- Tune the **road geometry** (road width and center spacing) and the **delay-bin keep percentage** in the angle–delay domain\n\n"
                "These controls let you pretrain LWM-Temporal at scale on diverse, richly-parameterized scenarios before moving to downstream tasks"
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    scenario_family = gr.Dropdown(
                        choices=[
                            "General DeepMIMO v4 scenarios",
                            "LWM-Temporal train scenarios",
                            "LWM-Temporal test scenarios",
                        ],
                        value="General DeepMIMO v4 scenarios",
                        label="Scenario family",
                    )
                    general_subfamily = gr.Dropdown(
                        choices=DEEPMIMO_GENERAL_SUBFAMILIES,
                        value=DEEPMIMO_GENERAL_SUBFAMILIES[0],
                        label="DeepMIMO v4 sub-family",
                        visible=True,
                    )
                    scenario_input = gr.Dropdown(
                        choices=DEEPMIMO_V4_SCENARIOS,
                        value=DEFAULT_SCENARIOS[16],
                        label="Selected DeepMIMO scenario",
                        interactive=True,
                        allow_custom_value=True,
                    )
                    gr.Markdown(
                        "We provide three scenario families:\n"
                        "- **General DeepMIMO v4 scenarios**: Full DeepMIMO v4 catalog\n"
                        "- **LWM-Temporal train scenarios**: Dense scenarios with 10 cm grid spacing\n"
                        "- **LWM-Temporal test scenarios**: Held-out dense scenarios for robust evaluation\n\n"
                        "Pick a family, then choose a specific scenario or type any valid DeepMIMO v4 scenario name"
                    )
                    time_steps_input = gr.Slider(1, 64, value=11, step=1, label="Time steps")
                    sample_dt_input = gr.Slider(1e-3, 1e-1, value=1e-3, step=1e-4, label="Sample Δt (s)")
                    vehicle_input = gr.Slider(1, 1000, value=300, step=10, label="Vehicles")
                    pedestrian_input = gr.Slider(0, 100, value=50, step=5, label="Pedestrians")
                    turn_prob_input = gr.Slider(0.0, 0.4, value=0.1, step=0.01, label="Turn probability")
                    speed_input = gr.Slider(0, 150, value=20, step=5, label="Vehicle speed max (km/h)")
                    road_width_input = gr.Slider(1.0, 6.0, value=2.0, step=0.5, label="Road width (m)")
                    road_spacing_input = gr.Slider(6.0, 40.0, value=8.0, step=1.0, label="Road center spacing (m)")
                    keep_pct_input = gr.Slider(0.1, 1.0, value=0.25, step=0.05, label="Angle-delay keep %")
                    with gr.Row():
                        overwrite_toggle = gr.Checkbox(value=False, label="Overwrite cached files")
                        download_toggle = gr.Checkbox(value=True, label="Download scenario if missing")
                    run_button = gr.Button("Generate scenario", variant="primary")
                with gr.Column(scale=2):
                    env_image = gr.Image(
                        type="filepath",
                        label="Environment & Dynamic UEs",
                        interactive=False,
                        height=400,
                    )
                    env_plot = gr.Plot(
                        label="Dynamic UEs (pan & zoom)",
                    )
                    scenario_summary = gr.Markdown(label="Scenario status")
                    payload_stats = gr.Textbox(lines=5, label="Payload summary", max_lines=6)

            # When the scenario family changes, update the available choices in the scenario dropdown.
            def _update_scenarios(family: str, subfamily: str) -> Any:
                if "train" in family:
                    choices = LWM_TEMPORAL_TRAIN_SCENARIOS
                elif "test" in family:
                    choices = LWM_TEMPORAL_TEST_SCENARIOS
                else:
                    # General DeepMIMO v4 scenarios, further split into sub-families
                    label = (subfamily or "").lower()
                    if "sionna" in label or "_s" in label:
                        choices = DEEPMIMO_SIONNA_SCENARIOS
                    elif "3.5" in label or "3p5" in label:
                        choices = DEEPMIMO_WI_CITY_3P5_SCENARIOS
                    elif "28ghz" in label or "28" in label:
                        choices = DEEPMIMO_WI_CITY_28_SCENARIOS
                    elif "other" in label:
                        choices = DEEPMIMO_WI_OTHER_SCENARIOS
                    else:
                        choices = DEEPMIMO_V4_SCENARIOS
                value = choices[0] if choices else None
                # Use generic gr.update for compatibility with older Gradio versions.
                return gr.update(choices=choices, value=value)

            scenario_family.change(
                fn=_update_scenarios,
                inputs=[scenario_family, general_subfamily],
                outputs=scenario_input,
            )
            general_subfamily.change(
                fn=_update_scenarios,
                inputs=[scenario_family, general_subfamily],
                outputs=scenario_input,
            )

            # Show the sub-family dropdown only when the general DeepMIMO family is active.
            def _toggle_general_subfamily(family: str) -> Any:
                is_general = "general deepmimo v4" in (family or "").lower()
                return gr.update(visible=is_general)

            scenario_family.change(
                fn=_toggle_general_subfamily,
                inputs=scenario_family,
                outputs=general_subfamily,
            )
            run_button.click(
                fn=generate_scenario_action,
                inputs=[
                    scenario_state,
                    scenario_input,
                    time_steps_input,
                    sample_dt_input,
                    vehicle_input,
                    pedestrian_input,
                    turn_prob_input,
                    speed_input,
                    road_width_input,
                    road_spacing_input,
                    keep_pct_input,
                    overwrite_toggle,
                    download_toggle,
                ],
                outputs=[scenario_summary, payload_stats, env_image, env_plot, scenario_state],
            )
        with gr.Tab("Angle–Delay Visuals"):
            gr.Markdown(
                "Explore the **angle–delay domain** where LWM-Temporal actually thinks\n\n"
                "- Represent rich **space–frequency channels** in a compact, sparse grid over angle and delay\n"
                "- Work directly with the **physical multipath structure** (rays and echoes) instead of opaque frequency-domain tensors\n"
                "- Train LWM-Temporal in a **physics-informed** way and interpret its attention maps over angle–delay bins\n\n"
                "Use this tab to compare how a single UE’s channel evolves over time in both continuous and discretized angle–delay form"
            )
            with gr.Column():
                # First row: controls (dropdowns / sliders / button)
                with gr.Row():
                    ue_slider = gr.Slider(0, 512, value=0, step=1, label="UE index")
                    keep_slider = gr.Slider(0.1, 1.0, value=0.25, step=0.05, label="Keep percentage")
                    viz_button = gr.Button("Render visualizations")
                # Second row: three figures
                with gr.Row():
                    angle_delay_gif = gr.Image(
                        type="filepath",
                        label="Angle–delay GIF",
                        interactive=False,
                        height=320,
                    )
                    channel_gif = gr.Image(
                        type="filepath",
                        label="Channel GIF",
                        interactive=False,
                        height=320,
                    )
                    curve_png = gr.Image(
                        type="filepath",
                        label="Dominant bin traces",
                        interactive=False,
                        height=320,
                    )
                # Third row: status / stats under the figures
                viz_summary = gr.Textbox(lines=3, label="Status")
            viz_button.click(
                fn=visualize_angle_delay_action,
                inputs=[scenario_state, ue_slider, keep_slider],
                outputs=[angle_delay_gif, channel_gif, curve_png, viz_summary],
            )

            # Dynamically adapt the UE index slider max to the number of users
            # (vehicles + pedestrians) chosen in the first tab.
            def _update_ue_slider_from_counts(
                vehicles: float,
                pedestrians: float,
                current_value: float,
            ) -> Any:
                total_users = max(0, int(vehicles) + int(pedestrians))
                # UE indices are 0-based, so the maximum valid index is total_users - 1.
                max_idx = max(0, total_users - 1)
                clamped_value = max(0, min(int(current_value), max_idx))
                # Use generic gr.update to stay compatible with the Gradio version
                # in your environment.
                return gr.update(maximum=max_idx, value=clamped_value)

            vehicle_input.change(
                fn=_update_ue_slider_from_counts,
                inputs=[vehicle_input, pedestrian_input, ue_slider],
                outputs=ue_slider,
            )
            pedestrian_input.change(
                fn=_update_ue_slider_from_counts,
                inputs=[vehicle_input, pedestrian_input, ue_slider],
                outputs=ue_slider,
            )
        with gr.Tab("Zero-Shot Channel Estimation"):
            gr.Markdown(
                "Peek under the hood of **physics-informed pretraining** with Masked Channel Modeling (MCM)\n\n"
                "- LWM-Temporal builds on our earlier LWM 1.0 / 1.1 models and extends their **self-supervised MCM** pretraining to dynamic, angle–delay channels\n"
                "- MCM hides parts of the channel tensor and asks the model to reconstruct them, forcing it to internalize the **underlying propagation physics** instead of just memorizing patterns\n"
                "- Here, MCM becomes **physics-informed** via four complementary masking strategies (random, rectangular, tube, comb) that stress-test different aspects of the channel\n\n"
                "Use this tab to run a quick **zero-shot reconstruction check** on a single batch after pretraining with our bidirectional SSTA attention"
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    mask_ratio_input = gr.Slider(0.1, 0.9, value=0.6, step=0.05, label="Mask ratio")
                    batch_size_input = gr.Slider(1, 32, value=8, step=1, label="Batch size")
                    max_batches_input = gr.Slider(1, 10, value=1, step=1, label="Max batches")
                    device_radio = gr.Radio(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
                    masked_button = gr.Button("Run masked modeling")
                with gr.Column(scale=1):
                    masked_output = gr.Textbox(lines=4, label="NMSE summary")
            masked_button.click(
                fn=masked_modeling_action,
                inputs=[scenario_state, mask_ratio_input, batch_size_input, device_radio, max_batches_input],
                outputs=masked_output,
            )
        with gr.Tab("Zero-Shot Channel Prediction"):
            gr.Markdown(
                "Turn LWM-Temporal into a **generative predictor** for dynamic channels\n\n"
                "- Beyond understanding tasks, many downstream applications are **temporal and generative** (e.g., predicting future channels or strongest beams)\n"
                "- This tab showcases **channel prediction**: given T_past observed channels, the model predicts the next T_horizon future channels\n"
                "- For this task, we fine-tune the pretrained backbone with **causal SSTA**, so the model only attends to current and past frames—never to the future\n\n"
                "Use this tab to inspect the **zero-shot prediction performance** of the model on your selected scenario"
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    tpast_slider = gr.Slider(4, 12, value=10, step=1, label="Tpast")
                    horizon_slider = gr.Slider(1, 4, value=1, step=1, label="Horizon")
                    cp_device = gr.Radio(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
                    train_limit_slider = gr.Slider(1, 200, value=1, step=1, label="Context pool (train split)")
                    val_limit_slider = gr.Slider(1, 100, value=1, step=1, label="Context pool (val split)")
                    channel_button = gr.Button("Run channel prediction inference")
                with gr.Column(scale=1):
                    channel_summary = gr.Textbox(lines=4, label="Inference summary")
                    channel_image = gr.Image(
                        type="filepath",
                        label="Prediction visualization",
                        interactive=False,
                        height=260,
                    )
            channel_button.click(
                fn=channel_prediction_action,
                inputs=[scenario_state, tpast_slider, horizon_slider, cp_device, train_limit_slider, val_limit_slider],
                outputs=[channel_summary, channel_image],
            )
        with gr.Tab("Masking Modes"):
            gr.Markdown(
                "Toy example: Visualize the **four physics-informed masking strategies** used during pretraining\n\n"
                "- **Random**: masks tokens randomly across the angle–delay grid\n"
                "- **Rectangular**: masks contiguous rectangular regions of the grid\n"
                "- **Tube**: masks tubes of tokens along the delay axis (capturing delay-local structures)\n"
                "- **Comb**: masks comb-like patterns along the angle axis (capturing angular diversity)\n\n"
                "Use this tab to see how each strategy sculpts the information given to the model and how robustly it can reconstruct the masked channels"
            )
            with gr.Row():
                T_mask = gr.Slider(2, 12, value=6, step=1, label="Frames (T)")
                H_mask = gr.Slider(4, 32, value=16, step=1, label="Angle bins (H)")
                W_mask = gr.Slider(4, 32, value=16, step=1, label="Delay bins (W)")
            with gr.Row():
                mask_mode_dd = gr.Dropdown(
                    choices=["auto", "random", "rect", "tube", "comb"],
                    value="auto",
                    label="Masking mode",
                )
                gallery_ratio = gr.Slider(0.1, 0.9, value=0.6, step=0.05, label="Target mask ratio")
                gallery_examples = gr.Slider(1, 4, value=3, step=1, label="Examples to draw")
            gallery_button = gr.Button("Render mask gallery")
            gallery_image = gr.Image(
                type="filepath",
                label="Mask gallery",
                interactive=False,
                height=260,
            )
            gallery_summary = gr.Textbox(lines=4, label="Coverage summary")
            gallery_button.click(
                fn=mask_gallery_action,
                inputs=[T_mask, H_mask, W_mask, mask_mode_dd, gallery_ratio, gallery_examples],
                outputs=[gallery_image, gallery_summary],
            )
        with gr.Tab("Sparse Spatio-Temporal Attention (SSTA)"):
            gr.Markdown(
                "Toy example: visualize how Sparse Spatio-Temporal Attention (SSTA) selects neighbors for a single query token and how it differs from full attention in efficiency and physics-informedness\n\n"
                "- **Bidirectional SSTA** (pretraining): attends to both past and future frames to learn generic spatial–spectral–temporal dependencies\n"
                "- **Causal SSTA** (generative downstream tasks): attends only to current and past frames (no future leaks) for generative tasks like channel prediction"
            )
            # Controls on top
            with gr.Row(equal_height=True):
                # Default toy grid for the SSTA demo.
                T = 8
                H = 8
                W = 8
                with gr.Column(scale=1):
                    T_slider = gr.Slider(1, 11, value=T, step=1, label="Frames (T)")
                    H_slider = gr.Slider(4, 32, value=H, step=1, label="Angle bins (H)")
                    W_slider = gr.Slider(4, 32, value=W, step=1, label="Delay bins (W)")
                with gr.Column(scale=1):
                    same_window_slider = gr.Slider(0, 3, value=1, step=1, label="Same-frame window")
                    temporal_window_slider = gr.Slider(0, 4, value=1, step=1, label="Temporal spatial window")
                    drift_slider = gr.Slider(0, 4, value=1, step=1, label="Temporal drift")
                    ssta_mode_radio = gr.Radio(
                        choices=[
                            "Bidirectional SSTA (pretraining)",
                            "Causal SSTA (generative downstream tasks)",
                        ],
                        value="Bidirectional SSTA (pretraining)",
                        label="Attention mode",
                    )
                with gr.Column(scale=1):
                    offsets_text = gr.Textbox(value="-2,-1,0,1,2", label="Temporal offsets (comma separated)")
                    query_t_slider = gr.Slider(0, T, value=3, step=1, label="Query t (time index)")
                    query_h_slider = gr.Slider(0, H, value=4, step=1, label="Query angle index (h)")
                    query_w_slider = gr.Slider(0, W, value=5, step=1, label="Query delay index (w)")
            attention_button = gr.Button("Visualize attention neighborhood")
            # Visualization underneath: SSTA on top, full attention below, summary at the end.
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    attention_ssta_image = gr.Image(
                        type="filepath",
                        label="SSTA neighborhood: O(N)",
                        interactive=False,
                        height=200,
                    )
                    attention_full_image = gr.Image(
                        type="filepath",
                        label="Full attention: O(N^2)",
                        interactive=False,
                        height=200,
                    )
            attention_summary = gr.Textbox(lines=4, label="Neighborhood summary")

            # Keep query sliders in-range whenever T, H, or W sliders change.
            def _sync_query_t(T_val: float, current_q: float) -> Any:
                max_idx = max(0, int(T_val))
                clamped = max(0, min(int(current_q), max_idx))
                return gr.update(maximum=max_idx, value=clamped)

            def _sync_query_h(H_val: float, current_q: float) -> Any:
                max_idx = max(0, int(H_val))
                clamped = max(0, min(int(current_q), max_idx))
                return gr.update(maximum=max_idx, value=clamped)

            def _sync_query_w(W_val: float, current_q: float) -> Any:
                max_idx = max(0, int(W_val))
                clamped = max(0, min(int(current_q), max_idx))
                return gr.update(maximum=max_idx, value=clamped)

            T_slider.change(
                fn=_sync_query_t,
                inputs=[T_slider, query_t_slider],
                outputs=query_t_slider,
            )
            H_slider.change(
                fn=_sync_query_h,
                inputs=[H_slider, query_h_slider],
                outputs=query_h_slider,
            )
            W_slider.change(
                fn=_sync_query_w,
                inputs=[W_slider, query_w_slider],
                outputs=query_w_slider,
            )
            attention_button.click(
                fn=ssta_attention_action,
                inputs=[
                    T_slider,
                    H_slider,
                    W_slider,
                    same_window_slider,
                    temporal_window_slider,
                    drift_slider,
                    ssta_mode_radio,
                    offsets_text,
                    query_t_slider,
                    query_h_slider,
                    query_w_slider,
                ],
                outputs=[attention_ssta_image, attention_full_image, attention_summary],
            )
        gr.Markdown(
            """
            ---
            ## Notes
            - Steps reuse cached artifacts on disk so you can iterate without re-running the entire pipeline.
            - Heavy stages (channel prediction) run with conservative defaults. Increase limits as needed.
            - All artifacts are written under `figs/lab`, `examples/data/lab`, and `checkpoints/`.
            """
        )
    return demo
# Build the app at import time so hosting runtimes (e.g. HF Spaces) that look
# for a module-level `demo` object can find it without executing __main__.
demo = build_lab()

if __name__ == "__main__":
    # Launch directly when run as a script. `share=True` opens a public
    # tunnel and 0.0.0.0 binds all interfaces — expected for a hosted Space.
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        # NOTE(review): "/figs" is an absolute path; if the artifacts live in a
        # repo-relative `figs/` directory, this likely needs to be "figs" (or
        # an absolute path derived from the project root) — confirm.
        allowed_paths=["/figs"],
    )