from __future__ import annotations import dataclasses import math import pickle from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch from torch.utils.data import Dataset from .angle_delay import AngleDelayConfig, AngleDelayProcessor from ..models.lwm import ComplexPatchTokenizer @dataclasses.dataclass class AngleDelayDatasetConfig: raw_path: Path keep_percentage: float = 0.25 normalize: str = "global_rms" cache_dir: Optional[Path] = Path("cache") use_cache: bool = True overwrite_cache: bool = False snr_db: Optional[float] = None noise_seed: Optional[int] = None max_time_steps: Optional[int] = None patch_size: Tuple[int, int] = (1, 1) phase_mode: str = "real_imag" class AngleDelaySequenceDataset(Dataset): """Angle-delay dataset that tokenizes sequences and caches the processed tensors.""" def __init__(self, config: AngleDelayDatasetConfig, logger: Optional[Any] = None) -> None: super().__init__() self.config = config self.logger = logger self.tokenizer = ComplexPatchTokenizer(config.phase_mode) self.samples: List[torch.Tensor] cache_hit = False cache_path = self._cache_path() if config.use_cache and config.cache_dir is not None else None if cache_path and cache_path.exists() and not config.overwrite_cache: try: payload = torch.load(cache_path, map_location="cpu") if isinstance(payload, dict) and "samples" in payload: self.samples = payload["samples"] else: self.samples = payload cache_hit = True except Exception: cache_path.unlink(missing_ok=True) cache_hit = False if not cache_hit: self.samples = self._build_samples() if cache_path is not None: cache_path.parent.mkdir(parents=True, exist_ok=True) torch.save({"samples": self.samples}, cache_path) if self.config.snr_db is not None: self._apply_noise() def _cache_path(self) -> Path: cfg = self.config name = cfg.raw_path.stem # Include patch_size and phase_mode in cache name to ensure cache invalidation # when these parameters change. Also add 'v2' to invalidate old caches with wrong normalization. ph, pw = cfg.patch_size cache_name = f"adseq_{name}_keep{int(cfg.keep_percentage * 100)}_{cfg.normalize}_p{ph}x{pw}_{cfg.phase_mode}_v2.pt" return cfg.cache_dir / cache_name # type: ignore[operator] def _load_raw(self) -> Any: with self.config.raw_path.open("rb") as handle: return pickle.load(handle) def _normalize_sample(self, tensor: torch.Tensor) -> torch.Tensor: """Normalize a single sample by its own RMS.""" rms = torch.sqrt((tensor.real.float() ** 2 + tensor.imag.float() ** 2).mean()).clamp_min(1e-8) return tensor / rms.to(tensor.dtype) def _build_samples(self) -> List[torch.Tensor]: payload = self._load_raw() channel = payload["channel"] if isinstance(payload, dict) and "channel" in payload else payload channel_tensor = torch.as_tensor(channel, dtype=torch.complex64) if channel_tensor.ndim == 3: channel_tensor = channel_tensor.unsqueeze(0) if self.config.max_time_steps is not None and channel_tensor.size(1) > self.config.max_time_steps: channel_tensor = channel_tensor[:, : self.config.max_time_steps] processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=self.config.keep_percentage)) samples: List[torch.Tensor] = [] for seq in channel_tensor: ad = processor.forward(seq) truncated, _ = processor.truncate_delay_bins(ad) samples.append(truncated) # Apply normalization after collecting all samples if self.config.normalize == "per_sample_rms": samples = [self._normalize_sample(s) for s in samples] elif self.config.normalize == "global_rms": # Compute global RMS across all samples total_sum_sq = 0.0 total_count = 0 for s in samples: s_real = s.real.float() s_imag = s.imag.float() total_sum_sq += (s_real ** 2 + s_imag ** 2).sum().item() total_count += s_real.numel() if total_count > 0: global_rms = math.sqrt(total_sum_sq / total_count) global_rms = max(global_rms, 1e-8) samples = [s / torch.tensor(global_rms, dtype=torch.float32).to(s.dtype) for s in samples] return samples def _apply_noise(self) -> None: if self.config.noise_seed is not None: torch.manual_seed(int(self.config.noise_seed)) noisy: List[torch.Tensor] = [] snr_lin = 10.0 ** (float(self.config.snr_db) / 10.0) for sample in self.samples: real = sample.real.float() imag = sample.imag.float() power = (real.square() + imag.square()).mean().item() if power <= 0: noisy.append(sample) continue noise_var = power / snr_lin std = math.sqrt(noise_var / 2.0) noise_real = torch.randn_like(real) * std noise_imag = torch.randn_like(imag) * std noise = torch.complex(noise_real.to(sample.dtype), noise_imag.to(sample.dtype)) noisy.append((sample + noise).to(sample.dtype)) self.samples = noisy def __len__(self) -> int: return len(self.samples) def __getitem__(self, index: int) -> Dict[str, Any]: sample = self.samples[index] tokens, base_mask = self.tokenizer(sample.unsqueeze(0), self.config.patch_size) tokens = tokens.squeeze(0) base_mask = base_mask.squeeze(0) T, N, M = sample.shape ph, pw = self.config.patch_size H = N // ph W = M // pw shape = torch.tensor([T, H, W], dtype=torch.long) payload: Dict[str, Any] = { "sequence": sample, "tokens": tokens, "base_mask": base_mask, "shape": shape, } return payload def load_adseq_dataset( data_path: str | Path, keep_percentage: float = 0.25, normalize: str = "global_rms", cache_dir: Optional[str | Path] = "cache", use_cache: bool = True, overwrite_cache: bool = False, logger: Optional[Any] = None, snr_db: Optional[float] = None, noise_seed: Optional[int] = None, max_time_steps: Optional[int] = None, ) -> "AngleDelaySequenceDataset": cfg = AngleDelayDatasetConfig( raw_path=Path(data_path), keep_percentage=keep_percentage, normalize=normalize, cache_dir=None if cache_dir is None else Path(cache_dir), use_cache=use_cache, overwrite_cache=overwrite_cache, snr_db=snr_db, noise_seed=noise_seed, max_time_steps=max_time_steps, ) return AngleDelaySequenceDataset(cfg, logger=logger)