lwm-temporal / LWMTemporal /data /angle_delay.py
wi-lab's picture
update
839dea4
from __future__ import annotations
import dataclasses
import math
from pathlib import Path
from typing import Iterable, List, Optional, Sequence
import imageio.v2 as imageio
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
@dataclasses.dataclass
class AngleDelayConfig:
    """Settings shared by the angle-delay transforms and plotting helpers.

    Instances are plain dataclasses; call :meth:`validate` to check that the
    numeric settings are in range.
    """

    # (low, high) extent of the angle axis.
    angle_range: tuple[float, float] = (-math.pi / 2, math.pi / 2)
    # (low, high) extent of the delay axis.
    delay_range: tuple[float, float] = (0.0, 100.0)
    # Fraction of delay bins to retain; must lie in (0, 1].
    keep_percentage: float = 0.25
    # Frame rate for animations; must be positive.
    fps: int = 4
    # Output resolution in dots per inch; must be positive.
    dpi: int = 120
    # Bin count; must be positive.
    num_bins: int = 6
    # Output directory for figures.
    output_dir: Path = Path("figs")

    def validate(self) -> None:
        """Check the settings, raising on the first invalid one.

        The checks run in this order: ``keep_percentage`` must lie in (0, 1]
        (NaN is rejected), then ``fps``, ``dpi`` and ``num_bins`` must each be
        positive.

        Raises:
            ValueError: naming the first setting that is out of range.
        """
        in_range = 0.0 < self.keep_percentage <= 1.0
        if not in_range:
            raise ValueError("keep_percentage must be in (0, 1]")
        for field_name in ("fps", "dpi", "num_bins"):
            # Compare with ``<= 0`` rather than ``> 0`` so exactly the same
            # values are rejected as a direct per-field check would reject.
            if getattr(self, field_name) <= 0:
                raise ValueError(f"{field_name} must be positive")
class AngleDelayProcessor:
    """Project complex channels into the angle-delay domain and visualise them.

    Tensors are handled as ``(T, N, M)``. Dim 0 enumerates frames, which the
    plotting helpers animate one by one. Dim 1 is the antenna axis, which
    becomes angle after :meth:`forward`. Dim 2 is the subcarrier axis, which
    becomes delay after :meth:`forward`.
    """

    # Dark-theme palette shared by the styled plots.
    _BG_COLOR = "#0b0e11"
    _TEXT_COLOR = "#cbd5f5"
    _TITLE_COLOR = "#f8fafc"
    _SPINE_COLOR = "#374151"
    _PANEL_COLOR = "#111827"
    _PANEL_SPINE_COLOR = "#1f2937"

    def __init__(self, config: AngleDelayConfig | None = None) -> None:
        """Store ``config`` (or a default ``AngleDelayConfig``) and validate it.

        Raises:
            ValueError: if the configuration fails ``validate()``.
        """
        self.config = config or AngleDelayConfig()
        self.config.validate()

    # ------------------------------------------------------------------
    # Core transforms
    # ------------------------------------------------------------------
    @staticmethod
    def _ensure_complex(tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` unchanged; raise ``TypeError`` if it is not complex."""
        if not torch.is_complex(tensor):
            raise TypeError("expected complex tensor")
        return tensor

    def forward(self, channel: torch.Tensor) -> torch.Tensor:
        """Map an antenna x subcarrier channel into the angle-delay domain.

        The transform applies an orthonormal FFT over dim 1 (antenna -> angle),
        then an orthonormal inverse FFT over dim 2 (subcarrier -> delay).
        :meth:`inverse` undoes it up to floating-point error.

        Raises:
            TypeError: if ``channel`` is not complex.
        """
        channel = self._ensure_complex(channel)
        angle_domain = torch.fft.fft(channel, dim=1, norm="ortho")
        return torch.fft.ifft(angle_domain, dim=2, norm="ortho")

    def inverse(self, angle_delay: torch.Tensor) -> torch.Tensor:
        """Map an angle-delay tensor back to the antenna x subcarrier domain.

        Raises:
            TypeError: if ``angle_delay`` is not complex.
        """
        angle_delay = self._ensure_complex(angle_delay)
        subcarrier = torch.fft.fft(angle_delay, dim=2, norm="ortho")
        return torch.fft.ifft(subcarrier, dim=1, norm="ortho")

    # ------------------------------------------------------------------
    # Truncation helpers and metrics
    # ------------------------------------------------------------------
    def truncate_delay_bins(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Keep only the leading ``config.keep_percentage`` share of delay bins.

        Returns:
            ``(truncated, padded)``. ``truncated`` is a slice (view) of ``tensor``
            holding the first ``keep`` delay bins, where ``keep`` is at least 1.
            ``padded`` has the full shape, with every bin from ``keep`` onward set
            to zero, so it can be passed straight to :meth:`inverse`.

        Raises:
            TypeError: if ``tensor`` is not complex.
            ValueError: if ``tensor`` is not 3-D.
        """
        tensor = self._ensure_complex(tensor)
        if tensor.ndim != 3:
            raise ValueError("angle-delay tensor must have shape (T, N, M)")
        keep = max(1, int(round(tensor.size(-1) * self.config.keep_percentage)))
        truncated = tensor[..., :keep]
        padded = torch.zeros_like(tensor)
        padded[..., :keep] = truncated
        return truncated, padded

    @staticmethod
    def nmse(reference: torch.Tensor, reconstruction: torch.Tensor) -> float:
        """Return the normalised mean-squared error of ``reconstruction`` in dB.

        The reference power is clamped to 1e-12 to avoid division by zero. A
        perfect reconstruction gives ``log10(0)``, i.e. ``-inf``.

        Raises:
            TypeError: if either tensor is not complex.
        """
        reference = AngleDelayProcessor._ensure_complex(reference)
        reconstruction = AngleDelayProcessor._ensure_complex(reconstruction)
        mse = torch.mean(torch.abs(reference - reconstruction) ** 2)
        power = torch.mean(torch.abs(reference) ** 2).clamp_min(1e-12)
        return float(10.0 * torch.log10(mse / power))

    def reconstruction_nmse(self, channel: torch.Tensor) -> tuple[float, float]:
        """Return ``(nmse_full, nmse_truncated)`` in dB for a forward/inverse round trip.

        ``nmse_full`` uses every angle-delay bin. ``nmse_truncated`` reconstructs
        from the zero-padded output of :meth:`truncate_delay_bins`.
        """
        ad_full = self.forward(channel)
        nmse_full = self.nmse(channel, self.inverse(ad_full))
        _, padded = self.truncate_delay_bins(ad_full)
        nmse_trunc = self.nmse(channel, self.inverse(padded))
        return nmse_full, nmse_trunc

    # ------------------------------------------------------------------
    # Styling helpers
    # ------------------------------------------------------------------
    @classmethod
    def _style_dark_axes(cls, fig: plt.Figure, ax: plt.Axes) -> None:
        """Apply the dark background, tick and spine colours to ``fig``/``ax``."""
        fig.patch.set_facecolor(cls._BG_COLOR)
        ax.set_facecolor(cls._BG_COLOR)
        ax.tick_params(colors=cls._TEXT_COLOR)
        for spine in ax.spines.values():
            spine.set_color(cls._SPINE_COLOR)

    @classmethod
    def _add_dark_colorbar(cls, fig: plt.Figure, ax: plt.Axes, image, label: str) -> None:
        """Attach a colorbar for ``image`` to ``ax``, styled for the dark theme."""
        cbar = fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
        cbar.ax.yaxis.set_tick_params(color=cls._TEXT_COLOR)
        plt.setp(cbar.ax.get_yticklabels(), color=cls._TEXT_COLOR)
        cbar.set_label(label, color=cls._TEXT_COLOR)

    @classmethod
    def _set_dark_title(cls, ax: plt.Axes, text: str) -> None:
        """Set a bold, light-coloured title on ``ax``."""
        ax.set_title(text, color=cls._TITLE_COLOR, fontsize=12, fontweight="semibold")

    @classmethod
    def _style_evolution_axis(cls, ax: plt.Axes, title: str, ylabel: str) -> None:
        """Apply the panel theme, labels and legend to one bin-evolution subplot.

        Call this after plotting so the legend picks up the labelled series.
        """
        ax.set_facecolor(cls._PANEL_COLOR)
        ax.set_title(title, color=cls._TITLE_COLOR)
        ax.set_xlabel("time index", color=cls._TEXT_COLOR)
        ax.set_ylabel(ylabel, color=cls._TEXT_COLOR)
        ax.tick_params(colors=cls._TEXT_COLOR)
        ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.4)
        for spine in ax.spines.values():
            spine.set_color(cls._PANEL_SPINE_COLOR)
        legend = ax.legend(loc="upper left", fontsize=9)
        legend.get_frame().set_facecolor(cls._PANEL_COLOR)
        legend.get_frame().set_alpha(0.6)
        for text in legend.get_texts():
            text.set_color(cls._TEXT_COLOR)

    @staticmethod
    def _padded_limits(values, pad_fraction: float = 0.05) -> tuple[float, float]:
        """Return ``(low, high)`` axis limits that enclose the finite ``values``, plus a margin.

        A constant series gets a margin proportional to its magnitude (or to 1
        when it is zero), so the limits never collapse. If there are no finite
        values, ``(0.0, 1.0)`` is returned.
        """
        data = np.asarray(values, dtype=float)
        data = data[np.isfinite(data)]
        if data.size == 0:
            return 0.0, 1.0
        low, high = float(data.min()), float(data.max())
        span = high - low
        pad = span * pad_fraction if span > 0 else max(abs(high), 1.0) * pad_fraction
        return low - pad, high + pad

    # ------------------------------------------------------------------
    # Visualisation helpers
    # ------------------------------------------------------------------
    def _angle_delay_figure(self, first_frame: np.ndarray, vmin: float, vmax: float, dpi: Optional[int] = None):
        """Build the dark-themed angle-delay heatmap figure used by ``save_angle_delay_gif``.

        Returns:
            ``(fig, ax, image)``. The image is drawn with the config's delay/angle
            extent and the shared ``vmin``/``vmax`` colour scale.
        """
        fig, ax = plt.subplots(figsize=(8, 6), dpi=dpi)
        self._style_dark_axes(fig, ax)
        image = ax.imshow(
            first_frame,
            cmap="magma",
            origin="lower",
            aspect="auto",
            extent=[*self.config.delay_range, *self.config.angle_range],
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_xlabel("Delay bins", color=self._TEXT_COLOR)
        ax.set_ylabel("Angle bins", color=self._TEXT_COLOR)
        # The plotted values are linear |H| (tensor.abs()), not dB.
        self._add_dark_colorbar(fig, ax, image, "|H| (linear)")
        return fig, ax, image

    @staticmethod
    def _write_gif(output_path: Path, frames: List[np.ndarray], fps: float) -> None:
        """Encode ``frames`` as a GIF at ``fps`` frames per second."""
        try:
            imageio.mimsave(output_path, frames, fps=fps)
        except TypeError:
            # NOTE(review): newer imageio releases may route GIFs through the
            # pillow plugin, which rejects ``fps`` and expects ``duration``
            # (milliseconds per frame) instead.
            imageio.mimsave(output_path, frames, duration=1000.0 / fps)

    def save_angle_delay_gif(
        self,
        tensor: torch.Tensor,
        output_path: Path,
        fps: Optional[int] = None,
        show: bool = False,
    ) -> None:
        """Render per-frame angle-delay magnitude heatmaps as a GIF or an inline animation.

        All frames share one colour scale: the global min/max of ``|tensor|``.

        Args:
            tensor: complex tensor; frames are taken along dim 0.
            output_path: GIF destination. Its parent directory is created in
                both modes.
            fps: frame rate; defaults to ``config.fps``.
            show: if True, display an inline HTML animation via IPython instead
                of writing the GIF.

        Raises:
            TypeError: if ``tensor`` is not complex.
        """
        tensor = self._ensure_complex(tensor)
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        magnitude = tensor.abs().cpu()
        vmin, vmax = float(magnitude.min()), float(magnitude.max())
        num_frames = magnitude.size(0)

        if show:
            fig, ax, image = self._angle_delay_figure(magnitude[0].numpy(), vmin, vmax)

            def animate(idx: int):
                image.set_array(magnitude[idx].numpy())
                self._set_dark_title(ax, f"Angle-Delay Intensity — Frame {idx}")
                return (image,)

            # Interactive (notebook) path: delegate to the generic animation
            # helper and return early; no GIF is encoded here.
            self._save_animation(fig, animate, output_path, fps=fps, frames=num_frames, show=True)
            return

        # Non-interactive path: render each frame to an RGBA array, then encode a GIF.
        frames: List[np.ndarray] = []
        for frame_idx in range(num_frames):
            fig, ax, _ = self._angle_delay_figure(
                magnitude[frame_idx].numpy(), vmin, vmax, dpi=self.config.dpi
            )
            self._set_dark_title(ax, f"Angle-Delay Intensity — Frame {frame_idx}")
            fig.canvas.draw()
            # Copy the pixels so each frame owns its data instead of viewing the
            # canvas buffer of a figure that is about to be closed.
            frames.append(np.array(fig.canvas.buffer_rgba()))
            plt.close(fig)
        self._write_gif(output_path, frames, fps or self.config.fps)

    def _save_animation(
        self,
        fig: plt.Figure,
        animate_fn,
        output_path: Path,
        fps: Optional[int] = None,
        dpi: Optional[int] = None,
        frames: Optional[int] = None,
        show: bool = False,
    ) -> None:
        """Drive ``animate_fn`` with ``FuncAnimation``, then display or save the result.

        Args:
            fig: the figure being animated. It is always closed afterwards,
                even if rendering fails.
            animate_fn: callback that receives the frame index.
            output_path: destination file when ``show`` is False (written with
                the Pillow writer).
            fps: frame rate; defaults to ``config.fps``.
            dpi: save resolution; defaults to ``config.dpi``.
            frames: number of frames. Callers should always pass it, because
                ``None`` makes matplotlib use an unbounded frame counter.
            show: if True, render to JS/HTML and display it via IPython instead
                of saving.
        """
        anim = animation.FuncAnimation(fig, animate_fn, frames=frames)
        rate = fps or self.config.fps
        html = None
        try:
            if show:
                from IPython.display import HTML, display  # type: ignore

                html = anim.to_jshtml(fps=rate)
            else:
                output_path = Path(output_path)
                output_path.parent.mkdir(parents=True, exist_ok=True)
                anim.save(output_path, writer="pillow", fps=rate, dpi=dpi or self.config.dpi)
        finally:
            plt.close(fig)
        if show:
            display(HTML(html))

    def save_channel_animation(self, channel: torch.Tensor, output_path: Path, show: bool = False) -> None:
        """Animate ``|channel|`` frame by frame (antenna x subcarrier heatmaps).

        All frames share the global min/max colour scale.

        Raises:
            TypeError: if ``channel`` is not complex.
        """
        channel = self._ensure_complex(channel)
        magnitude = channel.abs().cpu()
        vmin, vmax = float(magnitude.min()), float(magnitude.max())

        fig, ax_mag = plt.subplots(figsize=(8, 6))
        self._style_dark_axes(fig, ax_mag)
        mag_img = ax_mag.imshow(
            magnitude[0].numpy(),
            cmap="magma",
            origin="upper",
            aspect="auto",
            vmin=vmin,
            vmax=vmax,
        )
        ax_mag.set_xlabel("Subcarrier", color=self._TEXT_COLOR)
        ax_mag.set_ylabel("Antenna", color=self._TEXT_COLOR)
        self._add_dark_colorbar(fig, ax_mag, mag_img, "|H| (linear)")

        def animate(idx: int):
            mag_img.set_array(magnitude[idx].numpy())
            self._set_dark_title(ax_mag, f"Channel Magnitude — Frame {idx}")
            return (mag_img,)

        self._save_animation(fig, animate, output_path, frames=magnitude.size(0), show=show)

    def save_angle_delay_animation(
        self,
        tensor: torch.Tensor,
        output_path: Path,
        keep_percentage: Optional[float] = None,
        show: bool = False,
    ) -> None:
        """Animate angle-delay magnitude and phase together with their frame-averaged traces.

        Args:
            tensor: complex tensor; frames are taken along dim 0.
            output_path: destination file when ``show`` is False.
            keep_percentage: used only to annotate titles (e.g. "keep=25%").
            show: display inline via IPython instead of saving.

        Raises:
            TypeError: if ``tensor`` is not complex.
        """
        tensor = self._ensure_complex(tensor)
        magnitude = tensor.abs().cpu()
        phase = torch.angle(tensor).cpu()
        num_frames = magnitude.size(0)
        keep_suffix = "" if keep_percentage is None else f" (keep={keep_percentage * 100:.0f}%)"

        fig, axes = plt.subplots(2, 2, figsize=(18, 10))
        mag_ax, phase_ax, mag_line_ax, phase_line_ax = axes.flat

        # Use one colour scale for all frames so later frames are not clipped
        # to frame 0's range.
        mag_img = mag_ax.imshow(
            magnitude[0].numpy(),
            cmap="gray_r",
            origin="upper",
            aspect="auto",
            vmin=float(magnitude.min()),
            vmax=float(magnitude.max()),
        )
        mag_ax.set_xlabel("Delay Bin")
        mag_ax.set_ylabel("Angle Bin")
        fig.colorbar(mag_img, ax=mag_ax, label="Magnitude")

        phase_img = phase_ax.imshow(
            phase[0].numpy(), cmap="twilight", origin="upper", aspect="auto", vmin=-math.pi, vmax=math.pi
        )
        phase_ax.set_xlabel("Delay Bin")
        phase_ax.set_ylabel("Angle Bin")
        fig.colorbar(phase_img, ax=phase_ax, label="Phase (rad)")

        temporal_mag = magnitude.mean(dim=(1, 2)).numpy()
        temporal_phase = np.unwrap(phase.mean(dim=(1, 2)).numpy())
        mag_line, = mag_line_ax.plot([], [], "r-o", linewidth=2)
        phase_line, = phase_line_ax.plot([], [], "b-s", linewidth=2)
        for axis, label, series in (
            (mag_line_ax, "Average Magnitude", temporal_mag),
            (phase_line_ax, "Average Phase (rad)", temporal_phase),
        ):
            axis.set_xlabel("Frame")
            axis.set_ylabel(label)
            # A single frame would give identical x-limits; keep a unit-wide window.
            axis.set_xlim(0, max(1, num_frames - 1))
            # set_data() in animate() does not autoscale, so size the y-axis to
            # the full series up front.
            axis.set_ylim(*self._padded_limits(series))
            axis.grid(True, alpha=0.3)

        def animate(idx: int):
            mag_img.set_array(magnitude[idx].numpy())
            phase_img.set_array(phase[idx].numpy())
            mag_ax.set_title(f"AD Magnitude – Frame {idx}{keep_suffix}")
            phase_ax.set_title(f"AD Phase – Frame {idx}{keep_suffix}")
            xs = np.arange(idx + 1)
            mag_line.set_data(xs, temporal_mag[: idx + 1])
            phase_line.set_data(xs, temporal_phase[: idx + 1])
            return mag_img, phase_img, mag_line, phase_line

        # frames must be bounded by the data, otherwise animate() indexes past dim 0.
        self._save_animation(fig, animate, output_path, frames=num_frames, show=show)

    def save_dominant_bin_animation(
        self,
        tensor: torch.Tensor,
        output_path: Path,
        threshold_ratio: float = 0.05,
        show: bool = False,
    ) -> None:
        """Animate ``|tensor|`` alongside the per-frame count of "dominant" bins.

        A bin counts as dominant when its magnitude exceeds
        ``threshold_ratio * max(|tensor|)``, with the maximum taken over all frames.

        Raises:
            TypeError: if ``tensor`` is not complex.
        """
        tensor = self._ensure_complex(tensor)
        magnitude = tensor.abs().cpu()
        threshold = float(magnitude.max()) * threshold_ratio
        dominant_counts = (magnitude > threshold).sum(dim=(1, 2)).numpy()
        num_frames = magnitude.size(0)

        fig, (heat_ax, line_ax) = plt.subplots(1, 2, figsize=(16, 6))
        # Use a shared colour scale so every frame is comparable.
        heat_img = heat_ax.imshow(
            magnitude[0].numpy(),
            cmap="gray_r",
            origin="upper",
            aspect="auto",
            vmin=float(magnitude.min()),
            vmax=float(magnitude.max()),
        )
        heat_ax.set_xlabel("Delay Bin")
        heat_ax.set_ylabel("Angle Bin")
        fig.colorbar(heat_img, ax=heat_ax, label="Magnitude")

        count_line, = line_ax.plot([], [], "r-s", linewidth=2)
        line_ax.set_xlabel("Frame")
        line_ax.set_ylabel("Dominant Bin Count")
        line_ax.set_xlim(0, max(1, num_frames - 1))
        # Keep a non-degenerate y-range even when no bin exceeds the threshold.
        line_ax.set_ylim(0, max(float(dominant_counts.max()), 1.0) * 1.1)
        line_ax.grid(True, alpha=0.3)

        def animate(idx: int):
            heat_img.set_array(magnitude[idx].numpy())
            heat_ax.set_title(f"Magnitude – Frame {idx}")
            xs = np.arange(idx + 1)
            count_line.set_data(xs, dominant_counts[: idx + 1])
            return heat_img, count_line

        # frames must be bounded by the data, otherwise animate() indexes past dim 0.
        self._save_animation(fig, animate, output_path, frames=num_frames, show=show)

    def save_bin_evolution_plot(self, tensor: torch.Tensor, output_path: Path, show: bool = False) -> None:
        """Plot magnitude and unwrapped phase over time for the strongest bins.

        Bins are ranked by magnitude averaged over dim 0, and up to three are
        shown. Nothing is drawn or written when the tensor has no bins.

        Raises:
            TypeError: if ``tensor`` is not complex.
        """
        tensor = self._ensure_complex(tensor)
        flat_mag = tensor.abs().mean(dim=0).flatten()
        k = min(3, flat_mag.numel())
        if k == 0:
            return
        _, indices = torch.topk(flat_mag, k)
        num_delay = tensor.size(-1)
        # Flattened index -> (angle index, delay index) in the (N, M) plane.
        bins = [divmod(int(flat_idx), num_delay) for flat_idx in indices.tolist()]
        time_axis = np.arange(tensor.size(0))

        fig, axes = plt.subplots(
            k,
            2,
            figsize=(11, 3 * max(1, k)),
            dpi=150,
            constrained_layout=True,
        )
        fig.patch.set_facecolor(self._BG_COLOR)
        # With k == 1, subplots returns a 1-D array; normalise to rows of (mag, phase).
        axes = np.atleast_2d(axes)

        for (ax_mag, ax_phase), (angle_idx, delay_idx) in zip(axes, bins):
            series = tensor[:, angle_idx, delay_idx]
            mag_series = torch.abs(series).cpu().numpy()
            phase_series = np.unwrap(torch.angle(series).cpu().numpy())

            ax_mag.plot(time_axis, mag_series, label="|H|", color="#38bdf8", linewidth=2.2)
            ax_mag.fill_between(time_axis, mag_series, color="#38bdf8", alpha=0.08)
            self._style_evolution_axis(
                ax_mag, f"Bin (angle={angle_idx}, delay={delay_idx}) magnitude", "|H|"
            )

            ax_phase.plot(time_axis, phase_series, label="∠H", color="#f87171", linewidth=2.2)
            self._style_evolution_axis(
                ax_phase, f"Bin (angle={angle_idx}, delay={delay_idx}) phase (unwrapped)", "radians"
            )

        fig.suptitle(f"Top-{k} angle–delay bins over time", fontsize=12, color=self._TITLE_COLOR)
        if show:
            plt.show()
            plt.close(fig)
            return
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            fig.savefig(output_path, dpi=self.config.dpi, bbox_inches="tight")
        finally:
            plt.close(fig)