from __future__ import annotations import dataclasses import math from pathlib import Path from typing import Iterable, List, Optional, Sequence import imageio.v2 as imageio import matplotlib.animation as animation import matplotlib.pyplot as plt import numpy as np import torch @dataclasses.dataclass class AngleDelayConfig: """Configuration options for angle-delay processing.""" angle_range: tuple[float, float] = (-math.pi / 2, math.pi / 2) delay_range: tuple[float, float] = (0.0, 100.0) keep_percentage: float = 0.25 fps: int = 4 dpi: int = 120 num_bins: int = 6 output_dir: Path = Path("figs") def validate(self) -> None: if not 0.0 < self.keep_percentage <= 1.0: raise ValueError("keep_percentage must be in (0, 1]") if self.fps <= 0: raise ValueError("fps must be positive") if self.dpi <= 0: raise ValueError("dpi must be positive") if self.num_bins <= 0: raise ValueError("num_bins must be positive") class AngleDelayProcessor: """Project complex channels into the angle-delay domain and visualise them.""" def __init__(self, config: AngleDelayConfig | None = None) -> None: self.config = config or AngleDelayConfig() self.config.validate() # ------------------------------------------------------------------ # Core transforms # ------------------------------------------------------------------ @staticmethod def _ensure_complex(tensor: torch.Tensor) -> torch.Tensor: if not torch.is_complex(tensor): raise TypeError("expected complex tensor") return tensor def forward(self, channel: torch.Tensor) -> torch.Tensor: channel = self._ensure_complex(channel) angle_domain = torch.fft.fft(channel, dim=1, norm="ortho") delay_domain = torch.fft.ifft(angle_domain, dim=2, norm="ortho") return delay_domain def inverse(self, angle_delay: torch.Tensor) -> torch.Tensor: angle_delay = self._ensure_complex(angle_delay) subcarrier = torch.fft.fft(angle_delay, dim=2, norm="ortho") antenna = torch.fft.ifft(subcarrier, dim=1, norm="ortho") return antenna # ------------------------------------------------------------------ # Truncation helpers and metrics # ------------------------------------------------------------------ def truncate_delay_bins(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: tensor = self._ensure_complex(tensor) if tensor.ndim != 3: raise ValueError("angle-delay tensor must have shape (T, N, M)") keep = max(1, int(round(tensor.size(-1) * self.config.keep_percentage))) truncated = tensor[..., :keep] padded = torch.zeros_like(tensor) padded[..., :keep] = truncated return truncated, padded @staticmethod def nmse(reference: torch.Tensor, reconstruction: torch.Tensor) -> float: reference = AngleDelayProcessor._ensure_complex(reference) reconstruction = AngleDelayProcessor._ensure_complex(reconstruction) mse = torch.mean(torch.abs(reference - reconstruction) ** 2) power = torch.mean(torch.abs(reference) ** 2).clamp_min(1e-12) return float(10.0 * torch.log10(mse / power)) def reconstruction_nmse(self, channel: torch.Tensor) -> tuple[float, float]: ad_full = self.forward(channel) recon_full = self.inverse(ad_full) nmse_full = self.nmse(channel, recon_full) truncated, padded = self.truncate_delay_bins(ad_full) recon_trunc = self.inverse(padded) nmse_trunc = self.nmse(channel, recon_trunc) return nmse_full, nmse_trunc # ------------------------------------------------------------------ # Visualisation helpers # ------------------------------------------------------------------ def save_angle_delay_gif( self, tensor: torch.Tensor, output_path: Path, fps: Optional[int] = None, show: bool = False, ) -> None: tensor = self._ensure_complex(tensor) output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) magnitude = tensor.abs().cpu() vmin, vmax = float(magnitude.min()), float(magnitude.max()) if show: fig, ax = plt.subplots(figsize=(8, 6)) fig.patch.set_facecolor("#0b0e11") ax.set_facecolor("#0b0e11") ax.tick_params(colors="#cbd5f5") for spine in ax.spines.values(): spine.set_color("#374151") im = ax.imshow( magnitude[0].numpy(), cmap="magma", origin="lower", aspect="auto", extent=[*self.config.delay_range, *self.config.angle_range], vmin=vmin, vmax=vmax, ) ax.set_xlabel("Delay bins", color="#cbd5f5") ax.set_ylabel("Angle bins", color="#cbd5f5") cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) cbar.ax.yaxis.set_tick_params(color="#cbd5f5") plt.setp(cbar.ax.get_yticklabels(), color="#cbd5f5") cbar.set_label("|H| (dB)", color="#cbd5f5") def animate(idx: int): im.set_array(magnitude[idx].numpy()) ax.set_title( f"Angle-Delay Intensity — Frame {idx}", color="#f8fafc", fontsize=12, fontweight="semibold", ) return (im,) # For interactive notebook usage, delegate to the generic animation helper # and return early (no GIF encoding here). self._save_animation(fig, animate, output_path, fps=fps, frames=magnitude.size(0), show=True) return # Non-interactive path: render each frame and encode a GIF on disk. frames: List[np.ndarray] = [] for frame_idx in range(magnitude.size(0)): fig, ax = plt.subplots(figsize=(8, 6)) fig.patch.set_facecolor("#0b0e11") ax.set_facecolor("#0b0e11") ax.tick_params(colors="#cbd5f5") for spine in ax.spines.values(): spine.set_color("#374151") im = ax.imshow( magnitude[frame_idx].numpy(), cmap="magma", # gray_r origin="lower", aspect="auto", extent=[*self.config.delay_range, *self.config.angle_range], vmin=vmin, vmax=vmax, ) ax.set_xlabel("Delay bins", color="#cbd5f5") ax.set_ylabel("Angle bins", color="#cbd5f5") ax.set_title( f"Angle-Delay Intensity — Frame {frame_idx}", color="#f8fafc", fontsize=12, fontweight="semibold", ) cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) cbar.ax.yaxis.set_tick_params(color="#cbd5f5") plt.setp(cbar.ax.get_yticklabels(), color="#cbd5f5") cbar.set_label("|H| (dB)", color="#cbd5f5") fig.canvas.draw() frames.append(np.asarray(fig.canvas.buffer_rgba())) plt.close(fig) imageio.mimsave(output_path, frames, fps=fps or self.config.fps) def _save_animation( self, fig: plt.Figure, animate_fn, output_path: Path, fps: Optional[int] = None, dpi: Optional[int] = None, frames: Optional[int] = None, show: bool = False, ) -> None: anim = animation.FuncAnimation(fig, animate_fn, frames=frames) if show: from IPython.display import HTML, display # type: ignore html = anim.to_jshtml(fps=fps or self.config.fps) plt.close(fig) display(HTML(html)) else: output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) anim.save(output_path, writer="pillow", fps=fps or self.config.fps, dpi=dpi or self.config.dpi) plt.close(fig) def save_channel_animation(self, channel: torch.Tensor, output_path: Path, show: bool = False) -> None: channel = self._ensure_complex(channel) magnitude = channel.abs().cpu() vmin, vmax = float(magnitude.min()), float(magnitude.max()) fig, ax_mag = plt.subplots(figsize=(8, 6)) fig.patch.set_facecolor("#0b0e11") ax_mag.set_facecolor("#0b0e11") ax_mag.tick_params(colors="#cbd5f5") for spine in ax_mag.spines.values(): spine.set_color("#374151") mag_img = ax_mag.imshow( magnitude[0].numpy(), cmap="magma", origin="upper", aspect="auto", vmin=vmin, vmax=vmax, ) ax_mag.set_xlabel("Subcarrier", color="#cbd5f5") ax_mag.set_ylabel("Antenna", color="#cbd5f5") cbar = fig.colorbar(mag_img, ax=ax_mag, fraction=0.046, pad=0.04) cbar.ax.yaxis.set_tick_params(color="#cbd5f5") plt.setp(cbar.ax.get_yticklabels(), color="#cbd5f5") cbar.set_label("|H| (linear)", color="#cbd5f5") def animate(idx: int): mag_img.set_array(magnitude[idx].numpy()) ax_mag.set_title( f"Channel Magnitude — Frame {idx}", color="#f8fafc", fontsize=12, fontweight="semibold", ) return (mag_img,) self._save_animation(fig, animate, output_path, frames=channel.size(0), show=show) def save_angle_delay_animation( self, tensor: torch.Tensor, output_path: Path, keep_percentage: Optional[float] = None, show: bool = False, ) -> None: tensor = self._ensure_complex(tensor) magnitude = tensor.abs().cpu() phase = torch.angle(tensor).cpu() keep_suffix = "" if keep_percentage is None else f" (keep={keep_percentage * 100:.0f}%)" fig, axes = plt.subplots(2, 2, figsize=(18, 10)) mag_ax, phase_ax, mag_line_ax, phase_line_ax = axes.flat mag_img = mag_ax.imshow(magnitude[0].numpy(), cmap="gray_r", origin="upper", aspect="auto") mag_ax.set_xlabel("Delay Bin") mag_ax.set_ylabel("Angle Bin") fig.colorbar(mag_img, ax=mag_ax, label="Magnitude") phase_img = phase_ax.imshow(phase[0].numpy(), cmap="twilight", origin="upper", aspect="auto", vmin=-math.pi, vmax=math.pi) phase_ax.set_xlabel("Delay Bin") phase_ax.set_ylabel("Angle Bin") fig.colorbar(phase_img, ax=phase_ax, label="Phase (rad)") temporal_mag = magnitude.mean(dim=(1, 2)) temporal_phase = np.unwrap(phase.mean(dim=(1, 2)).numpy()) mag_line, = mag_line_ax.plot([], [], "r-o", linewidth=2) phase_line, = phase_line_ax.plot([], [], "b-s", linewidth=2) for axis, label in ((mag_line_ax, "Average Magnitude"), (phase_line_ax, "Average Phase (rad)")): axis.set_xlabel("Frame") axis.set_ylabel(label) axis.set_xlim(0, tensor.size(0) - 1) axis.grid(True, alpha=0.3) def animate(idx: int): mag_img.set_array(magnitude[idx].numpy()) phase_img.set_array(phase[idx].numpy()) mag_ax.set_title(f"AD Magnitude – Frame {idx}{keep_suffix}") phase_ax.set_title(f"AD Phase – Frame {idx}{keep_suffix}") xs = np.arange(idx + 1) mag_line.set_data(xs, temporal_mag[: idx + 1].numpy()) phase_line.set_data(xs, temporal_phase[: idx + 1]) return mag_img, phase_img, mag_line, phase_line self._save_animation(fig, animate, output_path, show=show) def save_dominant_bin_animation( self, tensor: torch.Tensor, output_path: Path, threshold_ratio: float = 0.05, show: bool = False, ) -> None: tensor = self._ensure_complex(tensor) magnitude = tensor.abs().cpu() threshold = float(magnitude.max()) * threshold_ratio dominant_counts = (magnitude > threshold).sum(dim=(1, 2)).numpy() fig, (heat_ax, line_ax) = plt.subplots(1, 2, figsize=(16, 6)) heat_img = heat_ax.imshow(magnitude[0].numpy(), cmap="gray_r", origin="upper", aspect="auto") heat_ax.set_xlabel("Delay Bin") heat_ax.set_ylabel("Angle Bin") fig.colorbar(heat_img, ax=heat_ax, label="Magnitude") count_line, = line_ax.plot([], [], "r-s", linewidth=2) line_ax.set_xlabel("Frame") line_ax.set_ylabel("Dominant Bin Count") line_ax.set_xlim(0, tensor.size(0) - 1) line_ax.set_ylim(0, dominant_counts.max() * 1.1) line_ax.grid(True, alpha=0.3) def animate(idx: int): heat_img.set_array(magnitude[idx].numpy()) heat_ax.set_title(f"Magnitude – Frame {idx}") xs = np.arange(idx + 1) count_line.set_data(xs, dominant_counts[: idx + 1]) return heat_img, count_line self._save_animation(fig, animate, output_path, show=show) def save_bin_evolution_plot(self, tensor: torch.Tensor, output_path: Path, show: bool = False) -> None: tensor = self._ensure_complex(tensor) magnitude = tensor.abs() avg_mag = magnitude.mean(dim=0) flat_mag = avg_mag.flatten() # Use a compact, dark-mode visualization of the top-3 bins, similar to # the style in examples.ad_temporal_evolution.plot_curves. k = min(3, flat_mag.numel()) if k == 0: return _, indices = torch.topk(flat_mag, k) angle_indices = (indices // tensor.size(-1)).tolist() delay_indices = (indices % tensor.size(-1)).tolist() time_axis = np.arange(tensor.size(0)) fig, axes = plt.subplots( k, 2, figsize=(11, 3 * max(1, k)), dpi=150, constrained_layout=True, ) fig.patch.set_facecolor("#0b0e11") axes = np.atleast_2d(axes) label_color = "#cbd5f5" title_color = "#f8fafc" for row in range(k): series = tensor[:, angle_indices[row], delay_indices[row]] mag_series = torch.abs(series).cpu().numpy() phase_series = np.unwrap(torch.angle(series).cpu().numpy()) ax_mag, ax_phase = axes[row] # Magnitude subplot (dark mode) ax_mag.set_facecolor("#111827") ax_mag.plot( time_axis, mag_series, label="|H|", color="#38bdf8", linewidth=2.2, ) ax_mag.fill_between(time_axis, mag_series, color="#38bdf8", alpha=0.08) ax_mag.set_title( f"Bin (angle={angle_indices[row]}, delay={delay_indices[row]}) magnitude", color=title_color, ) ax_mag.set_xlabel("time index", color=label_color) ax_mag.set_ylabel("|H|", color=label_color) ax_mag.tick_params(colors=label_color) ax_mag.grid(True, linestyle="--", linewidth=0.6, alpha=0.4) for spine in ax_mag.spines.values(): spine.set_color("#1f2937") legend_mag = ax_mag.legend(loc="upper left", fontsize=9) legend_mag.get_frame().set_facecolor("#111827") legend_mag.get_frame().set_alpha(0.6) for text in legend_mag.get_texts(): text.set_color(label_color) # Phase subplot (dark mode) ax_phase.set_facecolor("#111827") ax_phase.plot( time_axis, phase_series, label="∠H", color="#f87171", linewidth=2.2, ) ax_phase.set_title( f"Bin (angle={angle_indices[row]}, delay={delay_indices[row]}) phase (unwrapped)", color=title_color, ) ax_phase.set_xlabel("time index", color=label_color) ax_phase.set_ylabel("radians", color=label_color) ax_phase.tick_params(colors=label_color) ax_phase.grid(True, linestyle="--", linewidth=0.6, alpha=0.4) for spine in ax_phase.spines.values(): spine.set_color("#1f2937") legend_phase = ax_phase.legend(loc="upper left", fontsize=9) legend_phase.get_frame().set_facecolor("#111827") legend_phase.get_frame().set_alpha(0.6) for text in legend_phase.get_texts(): text.set_color(label_color) fig.suptitle("Top-3 angle–delay bins over time", fontsize=12, color=title_color) if show: plt.show() else: output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(output_path, dpi=self.config.dpi, bbox_inches="tight") plt.close(fig)