# lwm-temporal / examples/ad_temporal_evolution.py
# Author: wi-lab
# Commit: 839dea4 ("update")
import argparse
import logging
import sys
from pathlib import Path
from typing import List, Optional, Sequence, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch # type: ignore
# Make sure the package root is on sys.path when running the example directly.
# ROOT is the parent of the directory holding this file (parents[1]); for a
# script under ``examples/`` that is the checkout root. It is prepended so the
# ``LWMTemporal`` import below resolves against this checkout.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
# Optional import of the LWMTemporal dataset utilities. When the import fails,
# both names are bound to ``None`` so the module still imports; the None check
# at the top of ``load_sequence`` then raises an ImportError when it is called.
try:  # Optional HF-style dataset utilities (may be absent in this checkout)
    from LWMTemporal.data import AngleDelayDatasetConfig, AngleDelaySequenceDataset  # type: ignore
except ImportError:  # pragma: no cover - keep script functional without data module
    AngleDelayDatasetConfig = None  # type: ignore
    AngleDelaySequenceDataset = None  # type: ignore
# Small constant added to the total sum of squares in ``fit_line`` so the R²
# computation never divides by zero for a constant series.
EPS: float = 1e-8
# Logger for this script's debug/info messages (configured by _configure_logging).
logger: logging.Logger = logging.getLogger("ad_temporal_evolution")
def configure_style() -> None:
    """Switch matplotlib to the dark theme shared by every figure in this example.

    Starts from the built-in ``dark_background`` style, then globally overrides
    figure/axes colours, tick and text colours, font size and legend frame via
    ``plt.rcParams``.
    """
    background = "#0b0e11"
    primary_text = "#e5e7eb"
    tick_text = "#9ca3af"
    overrides = {
        "figure.facecolor": background,
        "axes.facecolor": background,
        "axes.edgecolor": "#374151",
        "axes.labelcolor": primary_text,
        "axes.titleweight": "semibold",
        "text.color": primary_text,
        "xtick.color": tick_text,
        "ytick.color": tick_text,
        "grid.color": "#1f2937",
        "figure.autolayout": False,
        "font.size": 11,
        "legend.frameon": False,
    }
    plt.style.use("dark_background")
    plt.rcParams.update(overrides)
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build the command-line parser for this example and parse the arguments.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            exactly as before; passing an explicit list lets other scripts and
            tests drive the parser without touching ``sys.argv``.

    Returns:
        The parsed namespace.

    Invalid values are rejected up front via ``parser.error`` (prints usage and
    exits with status 2) instead of failing later with a traceback:
    ``--bins``, ``--patch_height`` and ``--patch_width`` must be positive,
    ``--max_time_steps`` must be positive when given, ``--keep_percentage``
    must lie in (0, 1], and ``--coords`` must contain an even number of values.
    """
    parser = argparse.ArgumentParser(
        description="Visualise how angle-delay bins evolve over time using the LWM-Temporal preprocessing stack.",
    )
    parser.add_argument(
        "--data_path",
        type=Path,
        default=Path("examples/data/city_12_chiyoda_3p5_20_32_32.p"),
        help="Path to a .p payload (dict with 'channel') or raw tensor file.",
    )
    parser.add_argument(
        "--sample_idx",
        type=int,
        default=0,
        help="Sample index to select when the payload has a batch dimension (S, T, H, W).",
    )
    parser.add_argument(
        "--keep_percentage",
        type=float,
        default=0.25,
        help="Fraction of strongest delay taps to keep when converting to angle-delay.",
    )
    parser.add_argument(
        "--normalize",
        choices=["none", "per_sample_rms", "global_rms"],
        default="global_rms",
        help="Normalization mode applied after the angle-delay transform.",
    )
    parser.add_argument(
        "--bins",
        type=int,
        default=6,
        help="Number of angle-delay bins to visualise (top-K by average magnitude).",
    )
    parser.add_argument(
        "--coords",
        type=int,
        nargs="*",
        help="Optional explicit bin coordinates supplied as n0 m0 n1 m1 ...",
    )
    parser.add_argument(
        "--out_path",
        type=Path,
        default=Path("examples/data/figs/ad_temporal_evolution.png"),
        help="Destination path for the saved figure.",
    )
    parser.add_argument(
        "--max_time_steps",
        type=int,
        default=None,
        help="Optional temporal truncation applied before preprocessing.",
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        default=Path("cache"),
        help="Cache directory used by the dataset API.",
    )
    parser.add_argument(
        "--no_cache",
        action="store_true",
        help="Disable caching when using the dataset API.",
    )
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite cached tensors when using the dataset API.",
    )
    parser.add_argument(
        "--phase_mode",
        choices=["real_imag", "mag_phase"],
        default="real_imag",
        help="Phase representation expected by downstream models (dataset API).",
    )
    parser.add_argument(
        "--patch_height",
        type=int,
        default=1,
        help="Patch height provided to the dataset API (ignored when unavailable).",
    )
    parser.add_argument(
        "--patch_width",
        type=int,
        default=1,
        help="Patch width provided to the dataset API (ignored when unavailable).",
    )
    parser.add_argument(
        "--title",
        type=str,
        default=None,
        help="Optional custom figure title.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable debug logging for troubleshooting.",
    )
    args = parser.parse_args(argv)

    # Validate at the CLI boundary so mistakes surface as a usage message rather
    # than as an exception deep inside dataset loading or bin selection.
    if args.bins < 1:
        parser.error("--bins must be a positive integer")
    # Written as a chained comparison so NaN is rejected as well.
    if not 0.0 < args.keep_percentage <= 1.0:
        parser.error("--keep_percentage must be in the interval (0, 1]")
    if args.coords and len(args.coords) % 2 != 0:
        parser.error("--coords must be provided as pairs: n0 m0 n1 m1 ...")
    if args.max_time_steps is not None and args.max_time_steps < 1:
        parser.error("--max_time_steps must be a positive integer when given")
    if args.patch_height < 1 or args.patch_width < 1:
        parser.error("--patch_height and --patch_width must be positive integers")
    return args
def _configure_logging(verbose: bool) -> None:
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(level=level, format="[%(levelname)s] %(message)s")
def _parse_coord_pairs(raw: Optional[Sequence[int]]) -> Optional[List[Tuple[int, int]]]:
if not raw:
return None
if len(raw) % 2 != 0:
raise ValueError("coords must be provided as pairs: n0 m0 n1 m1 ...")
pairs = []
for i in range(0, len(raw), 2):
pairs.append((int(raw[i]), int(raw[i + 1])))
return pairs
def _ensure_complex(tensor: torch.Tensor) -> torch.Tensor:
if tensor.is_complex():
return tensor.to(torch.complex64)
if tensor.ndim >= 1 and tensor.size(-1) == 2:
real = tensor[..., 0].float()
imag = tensor[..., 1].float()
return torch.complex(real, imag)
return torch.complex(tensor.float(), torch.zeros_like(tensor.float()))
def load_sequence(args: argparse.Namespace) -> torch.Tensor:
    """Load one angle-delay sequence through the LWMTemporal dataset API.

    Builds an ``AngleDelayDatasetConfig`` from the parsed CLI arguments,
    instantiates ``AngleDelaySequenceDataset`` and returns the sample at
    ``args.sample_idx`` converted to a complex64 tensor of shape (T, H, W).

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        Complex tensor with shape (T, H, W).

    Raises:
        ImportError: if the optional ``LWMTemporal.data`` utilities were not importable.
        RuntimeError: if the dataset contains no samples.
        KeyError: if a dict sample has neither a ``"sequence"`` nor an ``"angle_delay"`` entry.
        ValueError: if the sample cannot be reduced to a 3-D (T, H, W) tensor.
    """
    if AngleDelayDatasetConfig is None or AngleDelaySequenceDataset is None:
        raise ImportError(
            "LWMTemporal.data.datasets is required. Install the full LWMTemporal package to use this example.",
        )
    cfg = AngleDelayDatasetConfig(
        raw_path=args.data_path,
        keep_percentage=args.keep_percentage,
        normalize=args.normalize,
        cache_dir=args.cache_dir,
        use_cache=not args.no_cache,
        overwrite_cache=args.overwrite_cache,
        snr_db=None,
        noise_seed=None,
        max_time_steps=args.max_time_steps,
        patch_size=(args.patch_height, args.patch_width),
        phase_mode=args.phase_mode,
    )
    dataset = AngleDelaySequenceDataset(cfg)
    num_samples = len(dataset)
    if num_samples == 0:
        raise RuntimeError("AngleDelaySequenceDataset returned zero samples.")
    # Out-of-range indices are clamped into [0, num_samples - 1] (so negative
    # values map to sample 0, not the last one). Warn so the substitution is
    # visible instead of silently plotting a different sample.
    idx = max(0, min(args.sample_idx, num_samples - 1))
    if idx != args.sample_idx:
        logger.warning(
            "sample_idx=%d is outside [0, %d]; using sample %d instead.",
            args.sample_idx,
            num_samples - 1,
            idx,
        )
    sample = dataset[idx]
    if isinstance(sample, dict):
        if "sequence" in sample:
            tensor = sample["sequence"]
        elif "angle_delay" in sample:
            tensor = sample["angle_delay"]
        else:
            raise KeyError("Dataset item missing 'sequence' or 'angle_delay' entries.")
    else:
        tensor = sample
    tensor = _ensure_complex(torch.as_tensor(tensor))
    # Drop a leading singleton batch axis, e.g. (1, T, H, W) -> (T, H, W).
    if tensor.ndim == 4 and tensor.size(0) == 1:
        tensor = tensor.squeeze(0)
    if tensor.ndim != 3:
        raise ValueError(f"Expected dataset sample with shape (T, H, W); received {tuple(tensor.shape)}")
    logger.debug("Loaded sequence via dataset API with shape %s", tuple(tensor.shape))
    return tensor
def pick_bins(
    sequence: torch.Tensor,
    k: int,
    coords: Optional[List[Tuple[int, int]]],
) -> List[Tuple[int, int]]:
    """Select up to ``k`` distinct (n, m) bins of a (T, H, W) angle-delay tensor.

    Explicit ``coords`` are honoured first, in the order given. Coordinates
    outside the H x W grid, and duplicates, are dropped. Any remaining slots are
    filled with the bins of largest time-averaged magnitude that were not
    already chosen.

    Args:
        sequence: Tensor of shape (T, H, W), complex or real.
        k: Maximum number of bins to return. Values <= 0 yield an empty list.
        coords: Optional user-supplied (n, m) pairs.

    Returns:
        Unique (n, m) pairs, exactly ``min(k, H * W)`` of them when ``k > 0``.

    Raises:
        ValueError: if ``sequence`` is not three-dimensional.
    """
    if sequence.ndim != 3:
        raise ValueError(f"Expected angle-delay tensor with shape (T, H, W); got {tuple(sequence.shape)}")
    _, H, W = sequence.shape
    if k <= 0:
        return []
    picks: List[Tuple[int, int]] = []
    seen = set()
    for n, m in coords or ():
        if 0 <= n < H and 0 <= m < W and (n, m) not in seen:
            seen.add((n, m))
            picks.append((n, m))
            if len(picks) == k:
                return picks
    mag = sequence.abs().mean(dim=0)
    # Request k candidates, not just k - len(picks). Up to len(picks) of the
    # strongest bins may already have been chosen via coords and are skipped
    # below, so asking for fewer could leave the selection short.
    candidates = torch.topk(mag.flatten(), k=min(k, H * W)).indices.tolist()
    for flat_idx in candidates:
        n, m = divmod(flat_idx, W)
        if (n, m) not in seen:
            seen.add((n, m))
            picks.append((n, m))
            if len(picks) == k:
                break
    return picks
def fit_line(y: np.ndarray) -> Tuple[float, float, float]:
    """Fit ``y ≈ slope * t + intercept`` by least squares over t = 0, 1, ..., len(y) - 1.

    Returns:
        ``(slope, intercept, r2)`` as Python floats, where
        ``r2 = 1 - SS_res / (SS_tot + EPS)``. The ``EPS`` term keeps the ratio
        finite for a constant series.
    """
    t = np.arange(len(y))
    design = np.vstack([t, np.ones_like(t)]).T
    (slope, intercept), *_ = np.linalg.lstsq(design, y, rcond=None)
    residual = y - (slope * t + intercept)
    ss_res = np.sum(residual ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2) + EPS
    return float(slope), float(intercept), float(1.0 - ss_res / ss_tot)
def _padded_limits(values: np.ndarray, frac: float = 0.05) -> Optional[Tuple[float, float]]:
    """Return (low, high) axis limits around the finite ``values`` with a small margin.

    Returns ``None`` when there are no finite entries, so the caller keeps
    matplotlib's autoscaling. A constant series gets a margin proportional to
    its magnitude (at least ``frac``) instead of a degenerate low == high range.
    """
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return None
    low = float(finite.min())
    high = float(finite.max())
    span = high - low
    pad = frac * span if span > 0 else frac * max(abs(high), 1.0)
    return low - pad, high + pad


def _style_panel(ax, title: str, ylabel: str, title_color: str, label_color: str) -> None:
    """Apply the shared panel colours, labels, grid, spines and legend styling to ``ax``.

    Must be called after the labelled lines are plotted so the legend has entries.
    """
    ax.set_facecolor("#111827")
    ax.set_title(title, color=title_color)
    ax.set_xlabel("time index", color=label_color)
    ax.set_ylabel(ylabel, color=label_color)
    ax.tick_params(colors=label_color)
    ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.4)
    for spine in ax.spines.values():
        spine.set_color("#1f2937")
    legend = ax.legend(loc="upper left", fontsize=9)
    legend.get_frame().set_facecolor("#111827")
    legend.get_frame().set_alpha(0.6)
    for text in legend.get_texts():
        text.set_color(label_color)


def plot_curves(
    sequence: torch.Tensor,
    picks: List[Tuple[int, int]],
    out_path: Path,
    title: str,
) -> None:
    """Plot magnitude and unwrapped phase over time for each selected bin and save the figure.

    Each bin gets one row with two panels. The left panel shows ``|H|`` with a
    light fill. The right panel shows the unwrapped phase in radians. Each
    legend reports the slope and R² of a least-squares line fit (``fit_line``).

    Args:
        sequence: Tensor of shape (T, H, W); complex values give meaningful phase.
        picks: (n, m) bins to plot, one figure row per bin.
        out_path: Destination image path; parent directories are created.
        title: Figure super-title.

    Raises:
        ValueError: if ``sequence`` is not 3-D or ``picks`` is empty.
    """
    if sequence.ndim != 3:
        raise ValueError("plot_curves expects a tensor with shape (T, H, W).")
    num_bins = len(picks)
    if num_bins == 0:
        raise ValueError("No bins were selected for plotting.")
    times = np.arange(sequence.shape[0])
    label_color = "#cbd5f5"
    title_color = "#f8fafc"

    fig, axes = plt.subplots(
        num_bins,
        2,
        figsize=(11, 3 * num_bins),
        dpi=150,
        constrained_layout=True,
    )
    # Always release the figure, even if plotting or saving fails.
    try:
        fig.patch.set_facecolor("#0b0e11")
        # subplots returns a 1-D array when num_bins == 1; normalise to 2-D indexing.
        axes = np.atleast_2d(axes)
        for row, (n, m) in enumerate(picks):
            series = sequence[:, n, m]
            mag = series.abs().cpu().numpy()
            phase = np.unwrap(torch.angle(series).cpu().numpy())
            slope_mag, _, r2_mag = fit_line(mag)
            slope_phase, _, r2_phase = fit_line(phase)

            ax_mag = axes[row, 0]
            ax_mag.plot(
                times,
                mag,
                label=f"|H| slope={slope_mag:.3g}, R²={r2_mag:.2f}",
                color="#38bdf8",
                linewidth=2.2,
            )
            ax_mag.fill_between(times, mag, color="#38bdf8", alpha=0.08)
            # Padded limits avoid clipping the curve at the frame and avoid a
            # singular (low == high) range for constant series or NaN limits.
            limits = _padded_limits(mag)
            if limits is not None:
                ax_mag.set_ylim(*limits)
            _style_panel(ax_mag, f"Bin (n={n}, m={m}) magnitude", "|H|", title_color, label_color)

            ax_phase = axes[row, 1]
            ax_phase.plot(
                times,
                phase,
                label=f"∠H slope={slope_phase:.3g}, R²={r2_phase:.2f}",
                color="#f87171",
                linewidth=2.2,
            )
            _style_panel(
                ax_phase,
                f"Bin (n={n}, m={m}) phase (unwrapped)",
                "radians",
                title_color,
                label_color,
            )

        fig.suptitle(title, fontsize=12, color=title_color)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        plt.close(fig)
def main() -> None:
    """Command-line entry point.

    Parses arguments, configures logging and plot style, loads one
    angle-delay sequence, chooses the bins to show, and writes the figure.
    """
    args = parse_args()
    _configure_logging(args.verbose)
    configure_style()
    requested = _parse_coord_pairs(args.coords)
    sequence = load_sequence(args)
    chosen = pick_bins(sequence, args.bins, requested)
    if len(chosen) == 0:
        raise RuntimeError("Unable to select any angle-delay bins for visualisation.")
    if args.title:
        title = args.title
    else:
        title = f"Angle-delay temporal curves | keep={args.keep_percentage:.2f} | norm={args.normalize}"
    plot_curves(sequence, chosen, args.out_path, title)
    logger.info("Saved figure to %s", args.out_path)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()