| | import argparse |
| | import logging |
| | import sys |
| | from pathlib import Path |
| | from typing import List, Optional, Sequence, Tuple |
| |
|
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import torch |
| |
|
| | |
# Put the repository root (the parent of this script's directory) at the front
# of sys.path so ``LWMTemporal`` resolves when the script runs from a checkout.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# The dataset API is optional at import time; load_sequence() raises an
# ImportError with an explanatory message when these names are None.
try:
    from LWMTemporal.data import AngleDelayDatasetConfig, AngleDelaySequenceDataset
except ImportError:
    AngleDelayDatasetConfig = AngleDelaySequenceDataset = None

# Small constant added to denominators (fit_line's R² total sum of squares).
EPS = 1e-8
logger = logging.getLogger("ad_temporal_evolution")
| |
|
| |
|
def configure_style() -> None:
    """Install the dark matplotlib theme used by every figure in this script.

    The built-in ``dark_background`` style is applied first, then colours, font
    size and legend framing are overridden through the global ``plt.rcParams``.
    """
    background = "#0b0e11"
    foreground = "#e5e7eb"
    tick = "#9ca3af"
    overrides = {
        "figure.facecolor": background,
        "axes.facecolor": background,
        "axes.edgecolor": "#374151",
        "axes.labelcolor": foreground,
        "axes.titleweight": "semibold",
        "text.color": foreground,
        "xtick.color": tick,
        "ytick.color": tick,
        "grid.color": "#1f2937",
        "figure.autolayout": False,
        "font.size": 11,
        "legend.frameon": False,
    }
    plt.style.use("dark_background")
    plt.rcParams.update(overrides)
| |
|
| |
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``
            exactly as before; passing an explicit list lets tests and notebooks
            drive the script without touching ``sys.argv``.

    Returns:
        The parsed :class:`argparse.Namespace`.

    Raises:
        SystemExit: via ``parser.error`` (exit code 2) on invalid input,
            including ``--bins`` below 1 or an odd number of ``--coords`` values.
    """
    parser = argparse.ArgumentParser(
        description="Visualise how angle-delay bins evolve over time using the LWM-Temporal preprocessing stack.",
    )
    parser.add_argument(
        "--data_path",
        type=Path,
        default=Path("examples/data/city_12_chiyoda_3p5_20_32_32.p"),
        help="Path to a .p payload (dict with 'channel') or raw tensor file.",
    )
    parser.add_argument(
        "--sample_idx",
        type=int,
        default=0,
        help="Sample index to select when the payload has a batch dimension (S, T, H, W).",
    )
    parser.add_argument(
        "--keep_percentage",
        type=float,
        default=0.25,
        help="Fraction of strongest delay taps to keep when converting to angle-delay.",
    )
    parser.add_argument(
        "--normalize",
        choices=["none", "per_sample_rms", "global_rms"],
        default="global_rms",
        help="Normalization mode applied after the angle-delay transform.",
    )
    parser.add_argument(
        "--bins",
        type=int,
        default=6,
        help="Number of angle-delay bins to visualise (top-K by average magnitude).",
    )
    parser.add_argument(
        "--coords",
        type=int,
        nargs="*",
        help="Optional explicit bin coordinates supplied as n0 m0 n1 m1 ...",
    )
    parser.add_argument(
        "--out_path",
        type=Path,
        default=Path("examples/data/figs/ad_temporal_evolution.png"),
        help="Destination path for the saved figure.",
    )
    parser.add_argument(
        "--max_time_steps",
        type=int,
        default=None,
        help="Optional temporal truncation applied before preprocessing.",
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        default=Path("cache"),
        help="Cache directory used by the dataset API.",
    )
    parser.add_argument(
        "--no_cache",
        action="store_true",
        help="Disable caching when using the dataset API.",
    )
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite cached tensors when using the dataset API.",
    )
    parser.add_argument(
        "--phase_mode",
        choices=["real_imag", "mag_phase"],
        default="real_imag",
        help="Phase representation expected by downstream models (dataset API).",
    )
    parser.add_argument(
        "--patch_height",
        type=int,
        default=1,
        help="Patch height provided to the dataset API (ignored when unavailable).",
    )
    parser.add_argument(
        "--patch_width",
        type=int,
        default=1,
        help="Patch width provided to the dataset API (ignored when unavailable).",
    )
    parser.add_argument(
        "--title",
        type=str,
        default=None,
        help="Optional custom figure title.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable debug logging for troubleshooting.",
    )
    args = parser.parse_args(argv)
    # argparse cannot express these constraints declaratively. Without the
    # checks, --bins < 1 selects no bins (main() then raises RuntimeError), and
    # an odd --coords count escapes later as a bare ValueError traceback.
    if args.bins < 1:
        parser.error("--bins must be a positive integer.")
    if args.coords and len(args.coords) % 2 != 0:
        parser.error("--coords must be provided as pairs: n0 m0 n1 m1 ...")
    return args
| |
|
| |
|
| | def _configure_logging(verbose: bool) -> None: |
| | level = logging.DEBUG if verbose else logging.INFO |
| | logging.basicConfig(level=level, format="[%(levelname)s] %(message)s") |
| |
|
| |
|
| | def _parse_coord_pairs(raw: Optional[Sequence[int]]) -> Optional[List[Tuple[int, int]]]: |
| | if not raw: |
| | return None |
| | if len(raw) % 2 != 0: |
| | raise ValueError("coords must be provided as pairs: n0 m0 n1 m1 ...") |
| | pairs = [] |
| | for i in range(0, len(raw), 2): |
| | pairs.append((int(raw[i]), int(raw[i + 1]))) |
| | return pairs |
| |
|
| |
|
| | def _ensure_complex(tensor: torch.Tensor) -> torch.Tensor: |
| | if tensor.is_complex(): |
| | return tensor.to(torch.complex64) |
| | if tensor.ndim >= 1 and tensor.size(-1) == 2: |
| | real = tensor[..., 0].float() |
| | imag = tensor[..., 1].float() |
| | return torch.complex(real, imag) |
| | return torch.complex(tensor.float(), torch.zeros_like(tensor.float())) |
| |
|
| |
|
def load_sequence(args: argparse.Namespace) -> torch.Tensor:
    """Load one angle-delay sequence through the LWM-Temporal dataset API.

    Builds an ``AngleDelayDatasetConfig`` from the CLI options, instantiates
    ``AngleDelaySequenceDataset`` and returns sample ``args.sample_idx``.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).

    Returns:
        A complex64 tensor of shape ``(T, H, W)``.

    Raises:
        ImportError: if the optional LWMTemporal import failed at module load.
        RuntimeError: if the dataset contains no samples.
        KeyError: if a dict sample has neither ``'sequence'`` nor ``'angle_delay'``.
        ValueError: if the sample is not 3-D after squeezing a leading
            singleton axis.
    """
    if AngleDelayDatasetConfig is None or AngleDelaySequenceDataset is None:
        raise ImportError(
            "LWMTemporal.data.datasets is required. Install the full LWMTemporal package to use this example.",
        )
    cfg = AngleDelayDatasetConfig(
        raw_path=args.data_path,
        keep_percentage=args.keep_percentage,
        normalize=args.normalize,
        cache_dir=args.cache_dir,
        use_cache=not args.no_cache,
        overwrite_cache=args.overwrite_cache,
        snr_db=None,
        noise_seed=None,
        max_time_steps=args.max_time_steps,
        patch_size=(args.patch_height, args.patch_width),
        phase_mode=args.phase_mode,
    )
    dataset = AngleDelaySequenceDataset(cfg)
    num_samples = len(dataset)
    if num_samples == 0:
        raise RuntimeError("AngleDelaySequenceDataset returned zero samples.")
    # Out-of-range indices are clamped into [0, num_samples - 1]; negative
    # values map to 0. Clamping is deliberate best-effort, but it must not be
    # silent, or the figure quietly shows a different sample than requested.
    idx = max(0, min(args.sample_idx, num_samples - 1))
    if idx != args.sample_idx:
        logger.warning(
            "sample_idx=%d is out of range for %d sample(s); using index %d instead.",
            args.sample_idx,
            num_samples,
            idx,
        )
    sample = dataset[idx]
    if isinstance(sample, dict):
        if "sequence" in sample:
            tensor = sample["sequence"]
        elif "angle_delay" in sample:
            tensor = sample["angle_delay"]
        else:
            raise KeyError("Dataset item missing 'sequence' or 'angle_delay' entries.")
    else:
        tensor = sample
    # NOTE(review): with --phase_mode mag_phase the dataset may emit a trailing
    # (magnitude, phase) axis, but _ensure_complex reads any trailing size-2
    # axis as (real, imag). Confirm the dataset's output layout for that mode.
    tensor = _ensure_complex(torch.as_tensor(tensor))
    # Drop a leading singleton axis: (1, T, H, W) -> (T, H, W).
    if tensor.ndim == 4 and tensor.size(0) == 1:
        tensor = tensor.squeeze(0)
    if tensor.ndim != 3:
        raise ValueError(f"Expected dataset sample with shape (T, H, W); received {tuple(tensor.shape)}")
    logger.debug("Loaded sequence via dataset API with shape %s", tuple(tensor.shape))
    return tensor
| |
|
| |
|
def pick_bins(
    sequence: torch.Tensor,
    k: int,
    coords: Optional[List[Tuple[int, int]]],
) -> List[Tuple[int, int]]:
    """Choose up to ``k`` angle-delay bins ``(n, m)`` to plot.

    Explicit ``coords`` are taken first, in the given order, silently skipping
    out-of-range and duplicate entries. Any remaining slots are filled with the
    not-yet-selected bins of largest time-averaged magnitude
    ``mean_t |sequence[t, n, m]|``.

    Args:
        sequence: Tensor of shape ``(T, H, W)`` (real or complex).
        k: Maximum number of bins to return; ``k <= 0`` yields ``[]``.
        coords: Optional explicit ``(n, m)`` pairs to prioritise.

    Returns:
        A list of at most ``k`` distinct ``(n, m)`` tuples.

    Raises:
        ValueError: if ``sequence`` is not 3-D.
    """
    if sequence.ndim != 3:
        raise ValueError(f"Expected angle-delay tensor with shape (T, H, W); got {tuple(sequence.shape)}")
    _, H, W = sequence.shape
    # A negative k would otherwise slice from the end below (picks[:k]).
    k = max(0, k)
    picks: List[Tuple[int, int]] = []
    for n, m in coords or ():
        if 0 <= n < H and 0 <= m < W and (n, m) not in picks:
            picks.append((n, m))
    if len(picks) >= k:
        return picks[:k]
    mag = sequence.abs().mean(dim=0).flatten()
    # At most len(picks) of the strongest bins can collide with explicit
    # coords, so the top-k candidates always contain enough new bins to fill
    # the remaining k - len(picks) slots. Requesting only the "remaining"
    # count, as before, could return fewer than k bins after collisions.
    top = torch.topk(mag, k=min(k, H * W))
    for flat in top.indices.tolist():
        n, m = divmod(flat, W)
        if (n, m) not in picks:
            picks.append((n, m))
            if len(picks) == k:
                break
    return picks
| |
|
| |
|
def fit_line(y: np.ndarray) -> Tuple[float, float, float]:
    """Least-squares fit of ``y ≈ slope * t + intercept`` with ``t = 0 .. len(y) - 1``.

    Returns:
        ``(slope, intercept, r2)`` as Python floats. ``EPS`` is added to the
        total sum of squares so that a constant series does not divide by zero.
    """
    t = np.arange(len(y))
    design = np.column_stack((t, np.ones_like(t)))
    (slope, intercept), *_ = np.linalg.lstsq(design, y, rcond=None)
    residual = y - (slope * t + intercept)
    centred = y - y.mean()
    r2 = 1.0 - np.sum(residual ** 2) / (np.sum(centred ** 2) + EPS)
    return float(slope), float(intercept), float(r2)
| |
|
| |
|
def _style_panel(ax, title: str, ylabel: str, title_color: str, label_color: str) -> None:
    """Apply the shared dark styling, labels, grid and legend to one subplot."""
    ax.set_facecolor("#111827")
    ax.set_title(title, color=title_color)
    ax.set_xlabel("time index", color=label_color)
    ax.set_ylabel(ylabel, color=label_color)
    ax.tick_params(colors=label_color)
    ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.4)
    for spine in ax.spines.values():
        spine.set_color("#1f2937")
    # Build the legend after the curve is plotted so it picks up the label.
    legend = ax.legend(loc="upper left", fontsize=9)
    frame = legend.get_frame()
    frame.set_facecolor("#111827")
    frame.set_alpha(0.6)
    for text in legend.get_texts():
        text.set_color(label_color)


def plot_curves(
    sequence: torch.Tensor,
    picks: List[Tuple[int, int]],
    out_path: Path,
    title: str,
) -> None:
    """Plot magnitude and unwrapped phase over time for each selected bin.

    Creates one row per bin with two panels (magnitude left, phase right).
    Each curve's legend reports the slope and R² of a linear fit (``fit_line``).
    The figure is saved to ``out_path``, whose parent directories are created
    if needed.

    Args:
        sequence: Tensor of shape ``(T, H, W)``.
        picks: ``(n, m)`` bin coordinates to plot, one row each.
        out_path: Destination image path.
        title: Figure super-title.

    Raises:
        ValueError: if ``sequence`` is not 3-D or ``picks`` is empty.
    """
    if sequence.ndim != 3:
        raise ValueError("plot_curves expects a tensor with shape (T, H, W).")
    num_bins = len(picks)
    if num_bins == 0:
        raise ValueError("No bins were selected for plotting.")
    times = np.arange(sequence.shape[0])
    label_color = "#cbd5f5"
    title_color = "#f8fafc"
    fig, axes = plt.subplots(
        num_bins,
        2,
        figsize=(11, 3 * num_bins),
        dpi=150,
        constrained_layout=True,
    )
    try:
        fig.patch.set_facecolor("#0b0e11")
        # plt.subplots returns a 1-D array of axes when num_bins == 1.
        axes = np.atleast_2d(axes)
        for row, (n, m) in enumerate(picks):
            series = sequence[:, n, m].detach().cpu()
            mag = series.abs().numpy()
            phase = np.unwrap(torch.angle(series).numpy())
            slope_mag, _, r2_mag = fit_line(mag)
            slope_phase, _, r2_phase = fit_line(phase)

            ax_mag = axes[row, 0]
            ax_mag.plot(
                times,
                mag,
                label=f"|H| slope={slope_mag:.3g}, R²={r2_mag:.2f}",
                color="#38bdf8",
                linewidth=2.2,
            )
            ax_mag.fill_between(times, mag, color="#38bdf8", alpha=0.08)
            # Pad the limits. A constant series would otherwise give identical
            # limits (matplotlib warns), and the curve's extremes would sit on
            # the frame edge and be clipped.
            lo, hi = float(mag.min()), float(mag.max())
            pad = 0.05 * (hi - lo) if hi > lo else max(0.05 * abs(hi), EPS)
            ax_mag.set_ylim(lo - pad, hi + pad)
            _style_panel(ax_mag, f"Bin (n={n}, m={m}) magnitude", "|H|", title_color, label_color)

            ax_phase = axes[row, 1]
            ax_phase.plot(
                times,
                phase,
                label=f"∠H slope={slope_phase:.3g}, R²={r2_phase:.2f}",
                color="#f87171",
                linewidth=2.2,
            )
            _style_panel(
                ax_phase,
                f"Bin (n={n}, m={m}) phase (unwrapped)",
                "radians",
                title_color,
                label_color,
            )

        fig.suptitle(title, fontsize=12, color=title_color)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Release the figure even when plotting or saving raises.
        plt.close(fig)
| |
|
| |
|
def main() -> None:
    """Script entry point: parse flags, load one sample, pick bins, save the figure.

    Raises:
        RuntimeError: if no angle-delay bins could be selected.
    """
    args = parse_args()
    _configure_logging(args.verbose)
    configure_style()
    explicit_coords = _parse_coord_pairs(args.coords)
    sequence = load_sequence(args)
    selected = pick_bins(sequence, args.bins, explicit_coords)
    if not selected:
        raise RuntimeError("Unable to select any angle-delay bins for visualisation.")
    if args.title:
        figure_title = args.title
    else:
        figure_title = f"Angle-delay temporal curves | keep={args.keep_percentage:.2f} | norm={args.normalize}"
    plot_curves(sequence, selected, args.out_path, figure_title)
    logger.info("Saved figure to %s", args.out_path)
| |
|
| |
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing the module does
    # not call main().
    main()
| |
|
| |
|