| | from __future__ import annotations |
| |
|
| | import dataclasses |
| | import math |
| | from pathlib import Path |
| | from typing import Iterable, List, Optional, Sequence |
| |
|
| | import imageio.v2 as imageio |
| | import matplotlib.animation as animation |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import torch |
| |
|
| |
|
@dataclasses.dataclass
class AngleDelayConfig:
    """Settings shared by the angle-delay transforms and plotting helpers.

    Construction performs no checking; call :meth:`validate` to reject
    out-of-range values.
    """

    # (low, high) extent of the angle axis in angle-delay intensity plots; the
    # +/- pi/2 default suggests radians (TODO confirm with callers).
    angle_range: tuple[float, float] = (-math.pi / 2, math.pi / 2)
    # (low, high) extent of the delay axis in the same plots.
    delay_range: tuple[float, float] = (0.0, 100.0)
    # Fraction of delay bins retained when truncating; must lie in (0, 1].
    keep_percentage: float = 0.25
    # Frame rate used for animations/GIFs when no explicit fps is given.
    fps: int = 4
    # Resolution (dots per inch) used when saving figures and animations.
    dpi: int = 120
    # Validated as positive; its consumer is not visible in this module chunk.
    num_bins: int = 6
    # Default output directory; its consumer is not visible in this module chunk.
    output_dir: Path = Path("figs")

    def validate(self) -> None:
        """Raise ``ValueError`` describing the first invalid field, if any.

        Checks run in order: ``keep_percentage`` must be in (0, 1], then
        ``fps``, ``dpi`` and ``num_bins`` must each be positive.
        """
        # Each predicate is evaluated lazily, so checking stops at the first
        # violation exactly like a chain of guard clauses would.
        rules = (
            ("keep_percentage must be in (0, 1]", lambda: not 0.0 < self.keep_percentage <= 1.0),
            ("fps must be positive", lambda: self.fps <= 0),
            ("dpi must be positive", lambda: self.dpi <= 0),
            ("num_bins must be positive", lambda: self.num_bins <= 0),
        )
        for message, violated in rules:
            if violated():
                raise ValueError(message)
| |
|
| |
|
class AngleDelayProcessor:
    """Project complex channels into the angle-delay domain and visualise them.

    Tensors handled here are complex. :meth:`forward` transforms dimension 1
    with an FFT and dimension 2 with an inverse FFT. The plotting helpers draw
    dimension 0 as the frame axis, dimension 1 as rows ("Angle Bin") and
    dimension 2 as columns ("Delay Bin").
    """

    # Dark-theme palette shared by the styled figures below.
    _BG_COLOR = "#0b0e11"
    _PANEL_COLOR = "#111827"
    _LABEL_COLOR = "#cbd5f5"
    _TITLE_COLOR = "#f8fafc"
    _SPINE_COLOR = "#374151"
    _PANEL_SPINE_COLOR = "#1f2937"

    def __init__(self, config: "AngleDelayConfig | None" = None) -> None:
        """Store *config* (or a default ``AngleDelayConfig``) and validate it.

        Raises:
            ValueError: if the configuration fails ``validate()``.
        """
        self.config = config or AngleDelayConfig()
        self.config.validate()

    # ------------------------------------------------------------------
    # Transforms
    # ------------------------------------------------------------------

    @staticmethod
    def _ensure_complex(tensor: torch.Tensor) -> torch.Tensor:
        """Return *tensor* unchanged, raising ``TypeError`` if it is not complex."""
        if not torch.is_complex(tensor):
            raise TypeError("expected complex tensor")
        return tensor

    def forward(self, channel: torch.Tensor) -> torch.Tensor:
        """Map a channel tensor into the angle-delay domain.

        Applies an orthonormal FFT along dimension 1 (presumably antenna ->
        angle) and an orthonormal inverse FFT along dimension 2 (presumably
        subcarrier -> delay). The output has the same shape as *channel*;
        :meth:`inverse` undoes the transform exactly (up to rounding).

        Raises:
            TypeError: if *channel* is not complex.
        """
        channel = self._ensure_complex(channel)
        angle_domain = torch.fft.fft(channel, dim=1, norm="ortho")
        return torch.fft.ifft(angle_domain, dim=2, norm="ortho")

    def inverse(self, angle_delay: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`forward`: FFT along dim 2, then inverse FFT along dim 1.

        Raises:
            TypeError: if *angle_delay* is not complex.
        """
        angle_delay = self._ensure_complex(angle_delay)
        subcarrier = torch.fft.fft(angle_delay, dim=2, norm="ortho")
        return torch.fft.ifft(subcarrier, dim=1, norm="ortho")

    def truncate_delay_bins(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Keep only the leading delay bins of a 3-D angle-delay tensor.

        The number of kept bins is ``round(M * config.keep_percentage)``
        (at least one), where ``M`` is the size of the last dimension.

        Returns:
            ``(truncated, padded)``. ``truncated`` is a view of shape
            ``(T, N, keep)``. ``padded`` is a new tensor with the input's
            shape and every discarded bin set to zero.

        Raises:
            TypeError: if *tensor* is not complex.
            ValueError: if *tensor* is not 3-D.
        """
        tensor = self._ensure_complex(tensor)
        if tensor.ndim != 3:
            raise ValueError("angle-delay tensor must have shape (T, N, M)")
        keep = max(1, int(round(tensor.size(-1) * self.config.keep_percentage)))
        truncated = tensor[..., :keep]
        padded = torch.zeros_like(tensor)
        padded[..., :keep] = truncated
        return truncated, padded

    @staticmethod
    def nmse(reference: torch.Tensor, reconstruction: torch.Tensor) -> float:
        """Return the normalised mean squared error of *reconstruction* in dB.

        Computed as ``10 * log10(mean|ref - rec|^2 / mean|ref|^2)``. The
        reference power is clamped to ``1e-12`` so an all-zero reference does
        not divide by zero. An exact reconstruction yields ``-inf``.

        Raises:
            TypeError: if either tensor is not complex.
        """
        reference = AngleDelayProcessor._ensure_complex(reference)
        reconstruction = AngleDelayProcessor._ensure_complex(reconstruction)
        mse = torch.mean(torch.abs(reference - reconstruction) ** 2)
        power = torch.mean(torch.abs(reference) ** 2).clamp_min(1e-12)
        return float(10.0 * torch.log10(mse / power))

    def reconstruction_nmse(self, channel: torch.Tensor) -> tuple[float, float]:
        """Return ``(nmse_full, nmse_truncated)`` in dB for a transform round trip.

        ``nmse_full`` measures the lossless forward/inverse round trip.
        ``nmse_truncated`` measures the round trip after zeroing all but the
        leading delay bins (see :meth:`truncate_delay_bins`).
        """
        ad_full = self.forward(channel)
        nmse_full = self.nmse(channel, self.inverse(ad_full))
        _, padded = self.truncate_delay_bins(ad_full)
        nmse_trunc = self.nmse(channel, self.inverse(padded))
        return nmse_full, nmse_trunc

    # ------------------------------------------------------------------
    # Styling / layout helpers
    # ------------------------------------------------------------------

    @classmethod
    def _style_dark_axes(cls, fig: "plt.Figure", ax: "plt.Axes") -> None:
        """Apply the dark background, tick colour and spine colour to *fig*/*ax*."""
        fig.patch.set_facecolor(cls._BG_COLOR)
        ax.set_facecolor(cls._BG_COLOR)
        ax.tick_params(colors=cls._LABEL_COLOR)
        for spine in ax.spines.values():
            spine.set_color(cls._SPINE_COLOR)

    @classmethod
    def _style_dark_colorbar(cls, cbar, label: str) -> None:
        """Colour a colorbar's ticks, tick labels and *label* for the dark theme."""
        cbar.ax.yaxis.set_tick_params(color=cls._LABEL_COLOR)
        plt.setp(cbar.ax.get_yticklabels(), color=cls._LABEL_COLOR)
        cbar.set_label(label, color=cls._LABEL_COLOR)

    @classmethod
    def _set_dark_title(cls, ax: "plt.Axes", text: str) -> None:
        """Set a bold, light-coloured title on *ax*."""
        ax.set_title(text, color=cls._TITLE_COLOR, fontsize=12, fontweight="semibold")

    @classmethod
    def _style_panel(cls, ax: "plt.Axes", title: str, ylabel: str) -> None:
        """Apply the panel theme, title, labels, grid and legend to a line-plot axis.

        Must be called after the series has been plotted, because it builds
        the legend from the existing artists.
        """
        ax.set_facecolor(cls._PANEL_COLOR)
        ax.set_title(title, color=cls._TITLE_COLOR)
        ax.set_xlabel("time index", color=cls._LABEL_COLOR)
        ax.set_ylabel(ylabel, color=cls._LABEL_COLOR)
        ax.tick_params(colors=cls._LABEL_COLOR)
        ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.4)
        for spine in ax.spines.values():
            spine.set_color(cls._PANEL_SPINE_COLOR)
        legend = ax.legend(loc="upper left", fontsize=9)
        legend.get_frame().set_facecolor(cls._PANEL_COLOR)
        legend.get_frame().set_alpha(0.6)
        for text in legend.get_texts():
            text.set_color(cls._LABEL_COLOR)

    @staticmethod
    def _frame_xlim(n_frames: int) -> tuple[float, float]:
        """Return x-limits covering frame indices ``0..n_frames-1``.

        A single frame would give the degenerate range ``(0, 0)``, so in that
        case the range is widened symmetrically around 0.
        """
        upper = n_frames - 1
        return (0, upper) if upper > 0 else (-0.5, 0.5)

    @staticmethod
    def _padded_ylim(values, margin: float = 0.1) -> tuple[float, float]:
        """Return y-limits spanning *values* with a relative *margin* on each side.

        Falls back to ``(0.0, 1.0)`` for empty or non-finite data, which
        matplotlib would otherwise reject. A constant series still gets a
        non-zero span.
        """
        values = np.asarray(values)
        if values.size == 0:
            return 0.0, 1.0
        lo, hi = float(np.min(values)), float(np.max(values))
        if not (math.isfinite(lo) and math.isfinite(hi)):
            return 0.0, 1.0
        span = hi - lo
        pad = span * margin if span > 0 else max(abs(hi), 1.0) * margin
        return lo - pad, hi + pad

    def _make_intensity_figure(
        self,
        frame: np.ndarray,
        vmin: float,
        vmax: float,
        dpi: Optional[int] = None,
    ):
        """Create a dark-themed angle-delay intensity figure showing *frame*.

        The image extent comes from ``config.delay_range`` (x) and
        ``config.angle_range`` (y). The colour scale is fixed to
        ``[vmin, vmax]``. ``dpi=None`` uses matplotlib's default.

        Returns:
            ``(fig, ax, image)`` so callers can update the image and title.
        """
        fig, ax = plt.subplots(figsize=(8, 6), dpi=dpi)
        self._style_dark_axes(fig, ax)
        image = ax.imshow(
            frame,
            cmap="magma",
            origin="lower",
            aspect="auto",
            extent=[*self.config.delay_range, *self.config.angle_range],
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_xlabel("Delay bins", color=self._LABEL_COLOR)
        ax.set_ylabel("Angle bins", color=self._LABEL_COLOR)
        cbar = fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
        # The plotted values are |H| without any log conversion, so label them linear.
        self._style_dark_colorbar(cbar, "|H| (linear)")
        return fig, ax, image

    # ------------------------------------------------------------------
    # Animations / plots
    # ------------------------------------------------------------------

    def save_angle_delay_gif(
        self,
        tensor: torch.Tensor,
        output_path: Path,
        fps: Optional[int] = None,
        show: bool = False,
    ) -> None:
        """Render per-frame angle-delay magnitude and write it as a GIF.

        With ``show=True``, an inline HTML animation is displayed instead and
        nothing is written. The parent directory of *output_path* is created
        in both cases. All frames share one colour scale (global min/max of
        ``|tensor|``).

        Args:
            tensor: complex tensor indexed as ``(frame, angle, delay)``.
            output_path: destination GIF path.
            fps: frame rate; defaults to ``config.fps``.
            show: display inline (requires IPython) instead of saving.

        Raises:
            TypeError: if *tensor* is not complex.
        """
        tensor = self._ensure_complex(tensor)
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        magnitude = tensor.abs().cpu()
        vmin, vmax = float(magnitude.min()), float(magnitude.max())
        n_frames = magnitude.size(0)

        if show:
            fig, ax, image = self._make_intensity_figure(magnitude[0].numpy(), vmin, vmax)

            def animate(idx: int):
                image.set_array(magnitude[idx].numpy())
                self._set_dark_title(ax, f"Angle-Delay Intensity — Frame {idx}")
                return (image,)

            self._save_animation(fig, animate, output_path, fps=fps, frames=n_frames, show=True)
            return

        frames: List[np.ndarray] = []
        for frame_idx in range(n_frames):
            fig, ax, _ = self._make_intensity_figure(
                magnitude[frame_idx].numpy(), vmin, vmax, dpi=self.config.dpi
            )
            try:
                self._set_dark_title(ax, f"Angle-Delay Intensity — Frame {frame_idx}")
                fig.canvas.draw()
                # buffer_rgba() is a view into the renderer's memory; copy it so
                # the stored frame does not pin (or alias) the closed figure.
                frames.append(np.array(fig.canvas.buffer_rgba()))
            finally:
                plt.close(fig)

        imageio.mimsave(output_path, frames, fps=fps or self.config.fps)

    def _save_animation(
        self,
        fig: "plt.Figure",
        animate_fn,
        output_path: Path,
        fps: Optional[int] = None,
        dpi: Optional[int] = None,
        frames: Optional[int] = None,
        show: bool = False,
    ) -> None:
        """Drive *animate_fn* over *frames* and either display or save the result.

        *fps* and *dpi* default to ``config.fps`` and ``config.dpi``. With
        ``show=True``, the animation is rendered to HTML and displayed through
        IPython. Otherwise it is saved with the ``pillow`` writer. The figure is
        always closed, even when rendering or saving fails.

        Callers should pass *frames* explicitly. With ``frames=None``,
        ``FuncAnimation`` iterates an unbounded counter and ``animate_fn``
        would be asked for indices past the data.
        """
        anim = animation.FuncAnimation(fig, animate_fn, frames=frames)
        fps = fps or self.config.fps
        try:
            if show:
                from IPython.display import HTML, display

                html = anim.to_jshtml(fps=fps)
            else:
                output_path = Path(output_path)
                output_path.parent.mkdir(parents=True, exist_ok=True)
                anim.save(output_path, writer="pillow", fps=fps, dpi=dpi or self.config.dpi)
        finally:
            plt.close(fig)
        if show:
            display(HTML(html))

    def save_channel_animation(self, channel: torch.Tensor, output_path: Path, show: bool = False) -> None:
        """Animate per-frame channel magnitude (rows "Antenna", columns "Subcarrier").

        All frames share one colour scale (global min/max of ``|channel|``).

        Raises:
            TypeError: if *channel* is not complex.
        """
        channel = self._ensure_complex(channel)
        magnitude = channel.abs().cpu()
        vmin, vmax = float(magnitude.min()), float(magnitude.max())

        fig, ax_mag = plt.subplots(figsize=(8, 6))
        self._style_dark_axes(fig, ax_mag)
        mag_img = ax_mag.imshow(
            magnitude[0].numpy(),
            cmap="magma",
            origin="upper",
            aspect="auto",
            vmin=vmin,
            vmax=vmax,
        )
        ax_mag.set_xlabel("Subcarrier", color=self._LABEL_COLOR)
        ax_mag.set_ylabel("Antenna", color=self._LABEL_COLOR)
        cbar = fig.colorbar(mag_img, ax=ax_mag, fraction=0.046, pad=0.04)
        self._style_dark_colorbar(cbar, "|H| (linear)")

        def animate(idx: int):
            mag_img.set_array(magnitude[idx].numpy())
            self._set_dark_title(ax_mag, f"Channel Magnitude — Frame {idx}")
            return (mag_img,)

        self._save_animation(fig, animate, output_path, frames=channel.size(0), show=show)

    def save_angle_delay_animation(
        self,
        tensor: torch.Tensor,
        output_path: Path,
        keep_percentage: Optional[float] = None,
        show: bool = False,
    ) -> None:
        """Animate angle-delay magnitude/phase plus their frame-averaged curves.

        Layout (2x2): magnitude heatmap, phase heatmap, then line plots of the
        per-frame mean magnitude and the unwrapped per-frame mean phase, grown
        one frame at a time.

        Args:
            tensor: complex tensor indexed as ``(frame, angle, delay)``.
            output_path: destination file (ignored when ``show=True``).
            keep_percentage: only used to annotate titles as ``keep=NN%``;
                the tensor itself is not truncated here.
            show: display inline (requires IPython) instead of saving.

        Raises:
            TypeError: if *tensor* is not complex.
        """
        tensor = self._ensure_complex(tensor)
        magnitude = tensor.abs().cpu()
        phase = torch.angle(tensor).cpu()
        n_frames = tensor.size(0)
        keep_suffix = "" if keep_percentage is None else f" (keep={keep_percentage * 100:.0f}%)"

        fig, axes = plt.subplots(2, 2, figsize=(18, 10))
        mag_ax, phase_ax, mag_line_ax, phase_line_ax = axes.flat

        # Fix the colour scale over all frames. set_array() never renormalises,
        # so without this every frame would be drawn on frame 0's scale.
        mag_img = mag_ax.imshow(
            magnitude[0].numpy(),
            cmap="gray_r",
            origin="upper",
            aspect="auto",
            vmin=float(magnitude.min()),
            vmax=float(magnitude.max()),
        )
        mag_ax.set_xlabel("Delay Bin")
        mag_ax.set_ylabel("Angle Bin")
        fig.colorbar(mag_img, ax=mag_ax, label="Magnitude")

        phase_img = phase_ax.imshow(
            phase[0].numpy(),
            cmap="twilight",
            origin="upper",
            aspect="auto",
            vmin=-math.pi,
            vmax=math.pi,
        )
        phase_ax.set_xlabel("Delay Bin")
        phase_ax.set_ylabel("Angle Bin")
        fig.colorbar(phase_img, ax=phase_ax, label="Phase (rad)")

        temporal_mag = magnitude.mean(dim=(1, 2)).numpy()
        temporal_phase = np.unwrap(phase.mean(dim=(1, 2)).numpy())
        mag_line, = mag_line_ax.plot([], [], "r-o", linewidth=2)
        phase_line, = phase_line_ax.plot([], [], "b-s", linewidth=2)

        for axis, label, series in (
            (mag_line_ax, "Average Magnitude", temporal_mag),
            (phase_line_ax, "Average Phase (rad)", temporal_phase),
        ):
            axis.set_xlabel("Frame")
            axis.set_ylabel(label)
            axis.set_xlim(*self._frame_xlim(n_frames))
            # set_data() does not autoscale the view, so fix the y-range up front.
            axis.set_ylim(*self._padded_ylim(series))
            axis.grid(True, alpha=0.3)

        def animate(idx: int):
            mag_img.set_array(magnitude[idx].numpy())
            phase_img.set_array(phase[idx].numpy())
            mag_ax.set_title(f"AD Magnitude – Frame {idx}{keep_suffix}")
            phase_ax.set_title(f"AD Phase – Frame {idx}{keep_suffix}")
            xs = np.arange(idx + 1)
            mag_line.set_data(xs, temporal_mag[: idx + 1])
            phase_line.set_data(xs, temporal_phase[: idx + 1])
            return mag_img, phase_img, mag_line, phase_line

        self._save_animation(fig, animate, output_path, frames=n_frames, show=show)

    def save_dominant_bin_animation(
        self,
        tensor: torch.Tensor,
        output_path: Path,
        threshold_ratio: float = 0.05,
        show: bool = False,
    ) -> None:
        """Animate the magnitude heatmap next to a running count of dominant bins.

        A bin is "dominant" in a frame when its magnitude exceeds
        ``threshold_ratio`` times the global maximum magnitude.

        Raises:
            TypeError: if *tensor* is not complex.
        """
        tensor = self._ensure_complex(tensor)
        magnitude = tensor.abs().cpu()
        n_frames = tensor.size(0)
        peak = float(magnitude.max())
        threshold = peak * threshold_ratio
        dominant_counts = (magnitude > threshold).sum(dim=(1, 2)).numpy()

        fig, (heat_ax, line_ax) = plt.subplots(1, 2, figsize=(16, 6))
        # Global colour scale so frames are comparable (set_array keeps the norm).
        heat_img = heat_ax.imshow(
            magnitude[0].numpy(),
            cmap="gray_r",
            origin="upper",
            aspect="auto",
            vmin=float(magnitude.min()),
            vmax=peak,
        )
        heat_ax.set_xlabel("Delay Bin")
        heat_ax.set_ylabel("Angle Bin")
        fig.colorbar(heat_img, ax=heat_ax, label="Magnitude")

        count_line, = line_ax.plot([], [], "r-s", linewidth=2)
        line_ax.set_xlabel("Frame")
        line_ax.set_ylabel("Dominant Bin Count")
        line_ax.set_xlim(*self._frame_xlim(n_frames))
        # Keep a non-degenerate range even when no bin crosses the threshold.
        line_ax.set_ylim(0, max(1, int(dominant_counts.max())) * 1.1)
        line_ax.grid(True, alpha=0.3)

        def animate(idx: int):
            heat_img.set_array(magnitude[idx].numpy())
            heat_ax.set_title(f"Magnitude – Frame {idx}")
            xs = np.arange(idx + 1)
            count_line.set_data(xs, dominant_counts[: idx + 1])
            return heat_img, count_line

        self._save_animation(fig, animate, output_path, frames=n_frames, show=show)

    def save_bin_evolution_plot(
        self,
        tensor: torch.Tensor,
        output_path: Path,
        show: bool = False,
        *,
        top_k: int = 3,
    ) -> None:
        """Plot magnitude and unwrapped phase over time for the strongest bins.

        Bins are ranked by their time-averaged magnitude and the ``top_k``
        strongest (at most the number of bins) get one row each: magnitude on
        the left, unwrapped phase on the right. Nothing is drawn when no bin
        is selected.

        Args:
            tensor: complex tensor indexed as ``(frame, angle, delay)``.
            output_path: destination image (ignored when ``show=True``).
            show: call ``plt.show()`` instead of saving.
            top_k: number of bins to plot (default 3).

        Raises:
            TypeError: if *tensor* is not complex.
        """
        tensor = self._ensure_complex(tensor)
        avg_mag = tensor.abs().mean(dim=0)
        k = min(top_k, avg_mag.numel())
        if k <= 0:
            return
        _, indices = torch.topk(avg_mag.flatten(), k)
        n_delay = tensor.size(-1)
        # Recover (angle, delay) from the row-major flattened index.
        bins = [divmod(int(flat_idx), n_delay) for flat_idx in indices.tolist()]

        time_axis = np.arange(tensor.size(0))
        fig, axes = plt.subplots(k, 2, figsize=(11, 3 * k), dpi=150, constrained_layout=True)
        fig.patch.set_facecolor(self._BG_COLOR)
        # subplots() returns a 1-D array when k == 1; normalise to rows of pairs.
        axes = np.atleast_2d(axes)

        for (ax_mag, ax_phase), (angle_idx, delay_idx) in zip(axes, bins):
            series = tensor[:, angle_idx, delay_idx]
            mag_series = torch.abs(series).cpu().numpy()
            phase_series = np.unwrap(torch.angle(series).cpu().numpy())
            tag = f"Bin (angle={angle_idx}, delay={delay_idx})"

            ax_mag.plot(time_axis, mag_series, label="|H|", color="#38bdf8", linewidth=2.2)
            ax_mag.fill_between(time_axis, mag_series, color="#38bdf8", alpha=0.08)
            self._style_panel(ax_mag, f"{tag} magnitude", "|H|")

            ax_phase.plot(time_axis, phase_series, label="∠H", color="#f87171", linewidth=2.2)
            self._style_panel(ax_phase, f"{tag} phase (unwrapped)", "radians")

        fig.suptitle(f"Top-{k} angle–delay bins over time", fontsize=12, color=self._TITLE_COLOR)
        if show:
            plt.show()
            return
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            fig.savefig(output_path, dpi=self.config.dpi, bbox_inches="tight")
        finally:
            plt.close(fig)
| |
|