| | from __future__ import annotations |
| |
|
| | import dataclasses |
| | import math |
| | import pickle |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | from torch.utils.data import Dataset |
| |
|
| | from .angle_delay import AngleDelayConfig, AngleDelayProcessor |
| | from ..models.lwm import ComplexPatchTokenizer |
| |
|
| |
|
@dataclasses.dataclass
class AngleDelayDatasetConfig:
    """Options controlling loading, caching, normalization, noise injection,
    and tokenization of an angle-delay sequence dataset."""

    # Path to the raw pickled channel dump.
    raw_path: Path
    # Forwarded to AngleDelayConfig.keep_percentage (delay-bin truncation).
    keep_percentage: float = 0.25
    # "per_sample_rms" | "global_rms"; any other value means no normalization.
    normalize: str = "global_rms"
    # Directory for processed-tensor caches; None disables caching entirely.
    cache_dir: Optional[Path] = Path("cache")
    use_cache: bool = True
    # When True an existing cache file is ignored and rebuilt.
    overwrite_cache: bool = False
    # When set, complex Gaussian noise at this SNR (dB) is added after
    # loading/building the samples (so caches stay noise-free).
    snr_db: Optional[float] = None
    # Seeds the torch RNG before noise injection, if given.
    noise_seed: Optional[int] = None
    # Truncate each sequence to at most this many leading time steps.
    max_time_steps: Optional[int] = None
    # (height, width) of tokenizer patches.
    patch_size: Tuple[int, int] = (1, 1)
    # Passed through to ComplexPatchTokenizer.
    phase_mode: str = "real_imag"
| |
|
| |
|
class AngleDelaySequenceDataset(Dataset):
    """Angle-delay dataset that tokenizes sequences and caches the processed tensors.

    Each sample is a complex (T, N, M) angle-delay tensor; tokenization into
    patches happens lazily in ``__getitem__``.
    """

    def __init__(self, config: AngleDelayDatasetConfig, logger: Optional[Any] = None) -> None:
        """Build the sample list, or load it from the on-disk cache.

        Noise injection (``config.snr_db``) is applied *after* the cache
        round-trip so a single noise-free cache serves every SNR setting.
        """
        super().__init__()
        self.config = config
        self.logger = logger
        self.tokenizer = ComplexPatchTokenizer(config.phase_mode)
        self.samples: List[torch.Tensor]
        cache_path = self._cache_path() if config.use_cache and config.cache_dir is not None else None
        if not self._try_load_cache(cache_path):
            self.samples = self._build_samples()
            if cache_path is not None:
                cache_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({"samples": self.samples}, cache_path)
        if self.config.snr_db is not None:
            self._apply_noise()

    def _try_load_cache(self, cache_path: Optional[Path]) -> bool:
        """Populate ``self.samples`` from ``cache_path``; return True on success.

        A cache file that fails to deserialize is deleted so the next run
        rebuilds it instead of failing again.
        """
        if cache_path is None or not cache_path.exists() or self.config.overwrite_cache:
            return False
        try:
            payload = torch.load(cache_path, map_location="cpu")
        except Exception:
            # Corrupt or incompatible cache: discard and rebuild from raw data.
            cache_path.unlink(missing_ok=True)
            return False
        # Current caches are {"samples": [...]}; bare lists are a legacy layout.
        if isinstance(payload, dict) and "samples" in payload:
            self.samples = payload["samples"]
        else:
            self.samples = payload
        return True

    def _cache_path(self) -> Path:
        """Derive a cache filename encoding every option that affects the cached tensors.

        ``snr_db``/``noise_seed`` are deliberately excluded: noise is injected
        after the cache is read, so one cache serves all SNR settings.
        """
        cfg = self.config
        name = cfg.raw_path.stem
        ph, pw = cfg.patch_size
        # BUG FIX: max_time_steps changes what _build_samples produces, so it
        # must be part of the cache key. The suffix is added only when the
        # option is set, keeping existing default-config cache files valid.
        trunc = "" if cfg.max_time_steps is None else f"_T{int(cfg.max_time_steps)}"
        cache_name = (
            f"adseq_{name}_keep{int(cfg.keep_percentage * 100)}_{cfg.normalize}"
            f"_p{ph}x{pw}_{cfg.phase_mode}{trunc}_v2.pt"
        )
        return cfg.cache_dir / cache_name

    def _load_raw(self) -> Any:
        """Unpickle the raw channel dump from ``config.raw_path``.

        SECURITY: ``pickle.load`` executes arbitrary code from the file —
        only point ``raw_path`` at trusted data.
        """
        with self.config.raw_path.open("rb") as handle:
            return pickle.load(handle)

    def _normalize_sample(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalize a single complex sample by its own RMS (clamped to 1e-8)."""
        rms = torch.sqrt((tensor.real.float() ** 2 + tensor.imag.float() ** 2).mean()).clamp_min(1e-8)
        return tensor / rms.to(tensor.dtype)

    def _build_samples(self) -> List[torch.Tensor]:
        """Transform the raw channel dump into normalized angle-delay tensors."""
        payload = self._load_raw()
        # The raw pickle is either {"channel": array} or the channel array itself.
        channel = payload["channel"] if isinstance(payload, dict) and "channel" in payload else payload
        channel_tensor = torch.as_tensor(channel, dtype=torch.complex64)
        if channel_tensor.ndim == 3:
            # Single sequence: add a leading batch dimension.
            channel_tensor = channel_tensor.unsqueeze(0)
        if self.config.max_time_steps is not None and channel_tensor.size(1) > self.config.max_time_steps:
            channel_tensor = channel_tensor[:, : self.config.max_time_steps]
        processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=self.config.keep_percentage))
        samples: List[torch.Tensor] = []
        for seq in channel_tensor:
            ad = processor.forward(seq)
            truncated, _ = processor.truncate_delay_bins(ad)
            samples.append(truncated)

        if self.config.normalize == "per_sample_rms":
            samples = [self._normalize_sample(s) for s in samples]
        elif self.config.normalize == "global_rms":
            # Accumulate |x|^2 over every sample to get one dataset-wide RMS.
            total_sum_sq = 0.0
            total_count = 0
            for s in samples:
                s_real = s.real.float()
                s_imag = s.imag.float()
                total_sum_sq += (s_real ** 2 + s_imag ** 2).sum().item()
                total_count += s_real.numel()
            if total_count > 0:
                global_rms = max(math.sqrt(total_sum_sq / total_count), 1e-8)
                # Dividing a complex tensor by a Python float keeps its dtype.
                samples = [s / global_rms for s in samples]

        return samples

    def _apply_noise(self) -> None:
        """Add circularly-symmetric complex Gaussian noise at ``config.snr_db``.

        NOTE(review): seeding mutates the *global* torch RNG; kept as-is for
        reproducibility parity with existing runs.
        """
        if self.config.noise_seed is not None:
            torch.manual_seed(int(self.config.noise_seed))
        noisy: List[torch.Tensor] = []
        snr_lin = 10.0 ** (float(self.config.snr_db) / 10.0)
        for sample in self.samples:
            real = sample.real.float()
            imag = sample.imag.float()
            power = (real.square() + imag.square()).mean().item()
            if power <= 0:
                # All-zero sample: SNR is undefined, leave it untouched.
                noisy.append(sample)
                continue
            noise_var = power / snr_lin
            # Half the noise power goes into each of the real/imag components.
            std = math.sqrt(noise_var / 2.0)
            noise_real = torch.randn_like(real) * std
            noise_imag = torch.randn_like(imag) * std
            # BUG FIX: torch.complex requires *float* inputs. The previous code
            # cast the components to the complex sample dtype first, which makes
            # torch.complex raise at runtime; build from floats, then cast.
            noise = torch.complex(noise_real, noise_imag).to(sample.dtype)
            noisy.append((sample + noise).to(sample.dtype))
        self.samples = noisy

    def __len__(self) -> int:
        """Number of sequences in the dataset."""
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the raw sequence plus its tokenized form and patch-grid shape."""
        sample = self.samples[index]
        # The tokenizer works on batches; add/remove a singleton batch dim.
        tokens, base_mask = self.tokenizer(sample.unsqueeze(0), self.config.patch_size)
        tokens = tokens.squeeze(0)
        base_mask = base_mask.squeeze(0)
        T, N, M = sample.shape
        ph, pw = self.config.patch_size
        # Patch-grid extents; assumes N % ph == 0 and M % pw == 0 — TODO confirm.
        H = N // ph
        W = M // pw
        shape = torch.tensor([T, H, W], dtype=torch.long)
        payload: Dict[str, Any] = {
            "sequence": sample,
            "tokens": tokens,
            "base_mask": base_mask,
            "shape": shape,
        }
        return payload
| |
|
| |
|
def load_adseq_dataset(
    data_path: str | Path,
    keep_percentage: float = 0.25,
    normalize: str = "global_rms",
    cache_dir: Optional[str | Path] = "cache",
    use_cache: bool = True,
    overwrite_cache: bool = False,
    logger: Optional[Any] = None,
    snr_db: Optional[float] = None,
    noise_seed: Optional[int] = None,
    max_time_steps: Optional[int] = None,
    patch_size: Tuple[int, int] = (1, 1),
    phase_mode: str = "real_imag",
) -> "AngleDelaySequenceDataset":
    """Convenience constructor for :class:`AngleDelaySequenceDataset`.

    Arguments mirror :class:`AngleDelayDatasetConfig`. ``patch_size`` and
    ``phase_mode`` were previously not exposed here; they are appended with
    the config's own defaults, so existing callers are unaffected.
    """
    cfg = AngleDelayDatasetConfig(
        raw_path=Path(data_path),
        keep_percentage=keep_percentage,
        normalize=normalize,
        # None disables caching; strings are normalized to Path.
        cache_dir=None if cache_dir is None else Path(cache_dir),
        use_cache=use_cache,
        overwrite_cache=overwrite_cache,
        snr_db=snr_db,
        noise_seed=noise_seed,
        max_time_steps=max_time_steps,
        patch_size=patch_size,
        phase_mode=phase_mode,
    )
    return AngleDelaySequenceDataset(cfg, logger=logger)
| |
|