"""Plot the temporal evolution of selected angle-delay bins.

The script loads one sample through ``AngleDelaySequenceDataset``, selects a
handful of ``(n, m)`` bins (explicit coordinates first, then the strongest bins
by time-averaged magnitude) and saves a figure with one row per bin: magnitude
on the left, unwrapped phase on the right, each annotated with the slope and
R² of a least-squares line fit over the time index.
"""

import argparse
import logging
import sys
from pathlib import Path
from typing import List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch  # type: ignore

# Make sure the package root is on sys.path when running the example directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

try:
    # Optional HF-style dataset utilities (may be absent in this checkout)
    from LWMTemporal.data import AngleDelayDatasetConfig, AngleDelaySequenceDataset  # type: ignore
except ImportError:  # pragma: no cover - keep script functional without data module
    AngleDelayDatasetConfig = None  # type: ignore
    AngleDelaySequenceDataset = None  # type: ignore

# Added to the R² denominator so a constant series does not divide by zero.
EPS = 1e-8
logger = logging.getLogger("ad_temporal_evolution")


def configure_style() -> None:
    """Apply the dark matplotlib theme used by the generated figure."""
    plt.style.use("dark_background")
    plt.rcParams.update(
        {
            "figure.facecolor": "#0b0e11",
            "axes.facecolor": "#0b0e11",
            "axes.edgecolor": "#374151",
            "axes.labelcolor": "#e5e7eb",
            "axes.titleweight": "semibold",
            "text.color": "#e5e7eb",
            "xtick.color": "#9ca3af",
            "ytick.color": "#9ca3af",
            "grid.color": "#1f2937",
            "figure.autolayout": False,
            "font.size": 11,
            "legend.frameon": False,
        }
    )


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a plain
            ``parse_args()`` call.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Visualise how angle-delay bins evolve over time using the LWM-Temporal preprocessing stack.",
    )
    parser.add_argument(
        "--data_path",
        type=Path,
        default=Path("examples/data/city_12_chiyoda_3p5_20_32_32.p"),
        help="Path to a .p payload (dict with 'channel') or raw tensor file.",
    )
    parser.add_argument(
        "--sample_idx",
        type=int,
        default=0,
        help="Sample index to select when the payload has a batch dimension (S, T, H, W).",
    )
    parser.add_argument(
        "--keep_percentage",
        type=float,
        default=0.25,
        help="Fraction of strongest delay taps to keep when converting to angle-delay.",
    )
    parser.add_argument(
        "--normalize",
        choices=["none", "per_sample_rms", "global_rms"],
        default="global_rms",
        help="Normalization mode applied after the angle-delay transform.",
    )
    parser.add_argument(
        "--bins",
        type=int,
        default=6,
        help="Number of angle-delay bins to visualise (top-K by average magnitude).",
    )
    parser.add_argument(
        "--coords",
        type=int,
        nargs="*",
        help="Optional explicit bin coordinates supplied as n0 m0 n1 m1 ...",
    )
    parser.add_argument(
        "--out_path",
        type=Path,
        default=Path("examples/data/figs/ad_temporal_evolution.png"),
        help="Destination path for the saved figure.",
    )
    parser.add_argument(
        "--max_time_steps",
        type=int,
        default=None,
        help="Optional temporal truncation applied before preprocessing.",
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        default=Path("cache"),
        help="Cache directory used by the dataset API.",
    )
    parser.add_argument(
        "--no_cache",
        action="store_true",
        help="Disable caching when using the dataset API.",
    )
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite cached tensors when using the dataset API.",
    )
    parser.add_argument(
        "--phase_mode",
        choices=["real_imag", "mag_phase"],
        default="real_imag",
        help="Phase representation expected by downstream models (dataset API).",
    )
    parser.add_argument(
        "--patch_height",
        type=int,
        default=1,
        help="Patch height provided to the dataset API (ignored when unavailable).",
    )
    parser.add_argument(
        "--patch_width",
        type=int,
        default=1,
        help="Patch width provided to the dataset API (ignored when unavailable).",
    )
    parser.add_argument(
        "--title",
        type=str,
        default=None,
        help="Optional custom figure title.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable debug logging for troubleshooting.",
    )
    return parser.parse_args(argv)


def _configure_logging(verbose: bool) -> None:
    """Initialise root logging at DEBUG when *verbose* is set, INFO otherwise."""
    level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(level=level, format="[%(levelname)s] %(message)s")


def _parse_coord_pairs(raw: Optional[Sequence[int]]) -> Optional[List[Tuple[int, int]]]:
    """Group a flat ``n0 m0 n1 m1 ...`` sequence into ``(n, m)`` tuples.

    Args:
        raw: Flat sequence of integers, or ``None``.

    Returns:
        A list of ``(n, m)`` pairs, or ``None`` when *raw* is ``None`` or empty.

    Raises:
        ValueError: If *raw* holds an odd number of values.
    """
    if not raw:
        return None
    if len(raw) % 2 != 0:
        raise ValueError("coords must be provided as pairs: n0 m0 n1 m1 ...")
    return [(int(n), int(m)) for n, m in zip(raw[::2], raw[1::2])]


def _ensure_complex(tensor: torch.Tensor) -> torch.Tensor:
    """Return *tensor* as a ``complex64`` tensor.

    Complex inputs are cast to ``complex64``. A real input whose last dimension
    has size 2 is read as stacked (real, imag) parts. Any other real input
    becomes the real part with a zero imaginary part.
    """
    if tensor.is_complex():
        return tensor.to(torch.complex64)
    # NOTE(review): the trailing-size-2 test is a heuristic. A genuinely real
    # tensor whose last axis happens to be 2 wide, or (presumably) a
    # magnitude/phase pair produced under phase_mode="mag_phase", would be
    # misread as real/imag -- confirm against the dataset API.
    if tensor.ndim >= 1 and tensor.size(-1) == 2:
        return torch.complex(tensor[..., 0].float(), tensor[..., 1].float())
    real = tensor.float()
    return torch.complex(real, torch.zeros_like(real))


def load_sequence(args: argparse.Namespace) -> torch.Tensor:
    """Load one sample through the dataset API as a complex 3-D tensor.

    Args:
        args: Parsed CLI options; the fields read are ``data_path``,
            ``keep_percentage``, ``normalize``, ``cache_dir``, ``no_cache``,
            ``overwrite_cache``, ``max_time_steps``, ``patch_height``,
            ``patch_width``, ``phase_mode`` and ``sample_idx``.

    Returns:
        A complex tensor with three dimensions, treated as ``(T, H, W)``.

    Raises:
        ImportError: If the optional dataset classes could not be imported.
        RuntimeError: If the dataset is empty.
        KeyError: If a dict sample has neither ``'sequence'`` nor
            ``'angle_delay'``.
        ValueError: If the sample is not 3-D after squeezing a leading
            singleton dimension.
    """
    if AngleDelayDatasetConfig is None or AngleDelaySequenceDataset is None:
        raise ImportError(
            "LWMTemporal.data.datasets is required. Install the full LWMTemporal package to use this example.",
        )

    cfg = AngleDelayDatasetConfig(
        raw_path=args.data_path,
        keep_percentage=args.keep_percentage,
        normalize=args.normalize,
        cache_dir=args.cache_dir,
        use_cache=not args.no_cache,
        overwrite_cache=args.overwrite_cache,
        snr_db=None,
        noise_seed=None,
        max_time_steps=args.max_time_steps,
        patch_size=(args.patch_height, args.patch_width),
        phase_mode=args.phase_mode,
    )
    dataset = AngleDelaySequenceDataset(cfg)
    if len(dataset) == 0:
        raise RuntimeError("AngleDelaySequenceDataset returned zero samples.")

    # Clamp instead of failing, but tell the user when the request was altered.
    last_idx = len(dataset) - 1
    idx = max(0, min(args.sample_idx, last_idx))
    if idx != args.sample_idx:
        logger.warning(
            "sample_idx %d is outside [0, %d]; using %d instead.",
            args.sample_idx,
            last_idx,
            idx,
        )

    sample = dataset[idx]
    if isinstance(sample, dict):
        if "sequence" in sample:
            tensor = sample["sequence"]
        elif "angle_delay" in sample:
            tensor = sample["angle_delay"]
        else:
            raise KeyError("Dataset item missing 'sequence' or 'angle_delay' entries.")
    else:
        tensor = sample

    tensor = _ensure_complex(torch.as_tensor(tensor))
    if tensor.ndim == 4 and tensor.size(0) == 1:
        tensor = tensor.squeeze(0)
    if tensor.ndim != 3:
        raise ValueError(f"Expected dataset sample with shape (T, H, W); received {tuple(tensor.shape)}")
    logger.debug("Loaded sequence via dataset API with shape %s", tuple(tensor.shape))
    return tensor


def pick_bins(
    sequence: torch.Tensor,
    k: int,
    coords: Optional[List[Tuple[int, int]]],
) -> List[Tuple[int, int]]:
    """Choose up to *k* distinct ``(n, m)`` bins to plot.

    Valid explicit *coords* come first, in the order given. Remaining slots
    are filled with the bins of largest time-averaged magnitude.

    Args:
        sequence: Tensor of shape ``(T, H, W)``.
        k: Maximum number of bins to return; non-positive values yield ``[]``.
        coords: Optional explicit bins. Out-of-range entries are skipped with
            a warning; duplicates are skipped silently.

    Returns:
        A list of at most *k* unique ``(n, m)`` tuples (fewer only when the
        grid has fewer than *k* bins).

    Raises:
        ValueError: If *sequence* is not 3-D.
    """
    if sequence.ndim != 3:
        raise ValueError(f"Expected angle-delay tensor with shape (T, H, W); got {tuple(sequence.shape)}")
    if k <= 0:
        return []

    _, H, W = sequence.shape
    picks: List[Tuple[int, int]] = []
    seen = set()
    for n, m in coords or ():
        if not (0 <= n < H and 0 <= m < W):
            logger.warning("Ignoring coordinate (%d, %d): outside the %d x %d grid.", n, m, H, W)
            continue
        if (n, m) not in seen:
            seen.add((n, m))
            picks.append((n, m))
    if len(picks) >= k:
        return picks[:k]

    mag = sequence.abs().mean(dim=0)
    # Rank the top min(k, H*W) bins rather than only the number still missing:
    # up to len(picks) of the strongest bins may already be selected, and
    # asking for just the shortfall would then return fewer than k bins.
    ranked = torch.topk(mag.flatten(), k=min(k, H * W)).indices.tolist()
    for flat_idx in ranked:
        candidate = divmod(flat_idx, W)
        if candidate in seen:
            continue
        seen.add(candidate)
        picks.append(candidate)
        if len(picks) == k:
            break
    return picks


def fit_line(y: np.ndarray) -> Tuple[float, float, float]:
    """Least-squares fit of ``y ≈ slope * x + intercept`` with ``x = 0..len(y)-1``.

    Args:
        y: 1-D array of samples.

    Returns:
        ``(slope, intercept, r2)`` where ``r2`` is the coefficient of
        determination, with ``EPS`` added to the total sum of squares.
    """
    x = np.arange(len(y))
    A = np.vstack([x, np.ones_like(x)]).T
    sol, *_ = np.linalg.lstsq(A, y, rcond=None)
    slope, intercept = sol
    y_pred = slope * x + intercept
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2) + EPS
    r2 = 1.0 - ss_res / ss_tot
    return float(slope), float(intercept), float(r2)


def _style_axis(ax: "plt.Axes", title: str, ylabel: str, label_color: str, title_color: str) -> None:
    """Apply the shared panel styling (background, labels, grid, spines, legend)."""
    ax.set_facecolor("#111827")
    ax.set_title(title, color=title_color)
    ax.set_xlabel("time index", color=label_color)
    ax.set_ylabel(ylabel, color=label_color)
    ax.tick_params(colors=label_color)
    ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.4)
    for spine in ax.spines.values():
        spine.set_color("#1f2937")
    legend = ax.legend(loc="upper left", fontsize=9)
    frame = legend.get_frame()
    frame.set_facecolor("#111827")
    frame.set_alpha(0.6)
    for text in legend.get_texts():
        text.set_color(label_color)


def plot_curves(
    sequence: torch.Tensor,
    picks: List[Tuple[int, int]],
    out_path: Path,
    title: str,
) -> None:
    """Plot magnitude and unwrapped phase over time for each bin and save it.

    Args:
        sequence: Complex tensor of shape ``(T, H, W)``.
        picks: ``(n, m)`` bins to plot, one figure row per bin.
        out_path: Destination file; missing parent directories are created.
        title: Figure super-title.

    Raises:
        ValueError: If *sequence* is not 3-D, has no time steps, or *picks*
            is empty.
    """
    if sequence.ndim != 3:
        raise ValueError("plot_curves expects a tensor with shape (T, H, W).")
    T = sequence.shape[0]
    if T == 0:
        raise ValueError("plot_curves expects at least one time step.")
    times = np.arange(T)
    num_bins = len(picks)
    if num_bins == 0:
        raise ValueError("No bins were selected for plotting.")

    fig, axes = plt.subplots(
        num_bins,
        2,
        figsize=(11, 3 * max(1, num_bins)),
        dpi=150,
        constrained_layout=True,
    )
    fig.patch.set_facecolor("#0b0e11")
    # A single row comes back as a 1-D array; normalise to [row, col] indexing.
    axes = np.atleast_2d(axes)

    label_color = "#cbd5f5"
    title_color = "#f8fafc"

    for row, (n, m) in enumerate(picks):
        series = sequence[:, n, m].detach()
        mag = series.abs().cpu().numpy()
        phase = np.unwrap(torch.angle(series).cpu().numpy())

        slope_mag, _, r2_mag = fit_line(mag)
        slope_phase, _, r2_phase = fit_line(phase)

        ax_mag = axes[row, 0]
        ax_mag.plot(
            times,
            mag,
            label=f"|H| slope={slope_mag:.3g}, R²={r2_mag:.2f}",
            color="#38bdf8",
            linewidth=2.2,
        )
        ax_mag.fill_between(times, mag, color="#38bdf8", alpha=0.08)
        lo, hi = float(mag.min()), float(mag.max())
        if hi > lo:
            ax_mag.set_ylim(lo, hi)
        else:
            # Constant series: identical limits make matplotlib warn, so pad.
            pad = max(abs(lo) * 0.05, EPS)
            ax_mag.set_ylim(lo - pad, hi + pad)
        _style_axis(ax_mag, f"Bin (n={n}, m={m}) magnitude", "|H|", label_color, title_color)

        ax_phase = axes[row, 1]
        ax_phase.plot(
            times,
            phase,
            label=f"∠H slope={slope_phase:.3g}, R²={r2_phase:.2f}",
            color="#f87171",
            linewidth=2.2,
        )
        _style_axis(ax_phase, f"Bin (n={n}, m={m}) phase (unwrapped)", "radians", label_color, title_color)

    fig.suptitle(title, fontsize=12, color=title_color)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Script entry point: parse options, load a sample, pick bins, save the figure.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`; ``None``
            reads ``sys.argv[1:]``.

    Raises:
        RuntimeError: If no bin could be selected.
    """
    args = parse_args(argv)
    _configure_logging(args.verbose)
    configure_style()

    coords = _parse_coord_pairs(args.coords)
    sequence = load_sequence(args)
    picks = pick_bins(sequence, args.bins, coords)
    if not picks:
        raise RuntimeError("Unable to select any angle-delay bins for visualisation.")

    title = args.title or f"Angle-delay temporal curves | keep={args.keep_percentage:.2f} | norm={args.normalize}"
    plot_curves(sequence, picks, args.out_path, title)
    logger.info("Saved figure to %s", args.out_path)


if __name__ == "__main__":
    main()