# Upstream header (wi-lab, commit b854c8e): "Organize dataset.py and notebook"
from __future__ import annotations
import dataclasses
import math
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
from .angle_delay import AngleDelayConfig, AngleDelayProcessor
from ..models.lwm import ComplexPatchTokenizer
@dataclasses.dataclass
class AngleDelayDatasetConfig:
    """Configuration for building an ``AngleDelaySequenceDataset``."""

    # Path to the pickled raw channel data.
    raw_path: Path
    # Fraction of delay bins kept by the angle-delay truncation step.
    keep_percentage: float = 0.25
    # Normalization strategy: "global_rms", "per_sample_rms"; any other value skips normalization.
    normalize: str = "global_rms"
    # Directory for cached processed tensors; None disables on-disk caching.
    cache_dir: Optional[Path] = Path("cache")
    # Load/store processed samples from/to ``cache_dir`` when possible.
    use_cache: bool = True
    # Rebuild samples and overwrite an existing cache file.
    overwrite_cache: bool = False
    # When set, add complex AWGN at this SNR (dB) after loading/building samples.
    snr_db: Optional[float] = None
    # Seed for the noise RNG; None leaves the global torch RNG state untouched.
    noise_seed: Optional[int] = None
    # Truncate each raw sequence to at most this many time steps before processing.
    max_time_steps: Optional[int] = None
    # Patch (height, width) used by the tokenizer in ``__getitem__``.
    patch_size: Tuple[int, int] = (1, 1)
    # Complex-to-real representation mode passed to ComplexPatchTokenizer.
    phase_mode: str = "real_imag"
class AngleDelaySequenceDataset(Dataset):
    """Angle-delay dataset that tokenizes sequences and caches the processed tensors.

    Raw complex channel sequences are loaded from ``config.raw_path`` (pickle),
    transformed to the angle-delay domain, truncated, normalized, and cached to
    disk so later runs skip the expensive preprocessing.  Optional complex AWGN
    is injected *after* loading/building, so the cache always stays noise-free.
    """

    def __init__(self, config: AngleDelayDatasetConfig, logger: Optional[Any] = None) -> None:
        super().__init__()
        self.config = config
        self.logger = logger
        self.tokenizer = ComplexPatchTokenizer(config.phase_mode)
        # List of complex (T, N, M) tensors, one entry per raw sequence.
        self.samples: List[torch.Tensor]
        cache_hit = False
        cache_path = self._cache_path() if config.use_cache and config.cache_dir is not None else None
        if cache_path and cache_path.exists() and not config.overwrite_cache:
            try:
                payload = torch.load(cache_path, map_location="cpu")
                # Accept both the current dict payload and legacy bare-list caches.
                if isinstance(payload, dict) and "samples" in payload:
                    self.samples = payload["samples"]
                else:
                    self.samples = payload
                cache_hit = True
            except Exception:
                # Corrupt or incompatible cache file: drop it and rebuild from raw data.
                cache_path.unlink(missing_ok=True)
                cache_hit = False
        if not cache_hit:
            self.samples = self._build_samples()
            if cache_path is not None:
                cache_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({"samples": self.samples}, cache_path)
        # Noise is applied after caching so cached tensors remain clean.
        if self.config.snr_db is not None:
            self._apply_noise()

    def _cache_path(self) -> Path:
        """Build a cache filename encoding every parameter that affects ``_build_samples``."""
        cfg = self.config
        name = cfg.raw_path.stem
        # Include patch_size and phase_mode in cache name to ensure cache invalidation
        # when these parameters change. 'v2' invalidates old caches with wrong normalization.
        ph, pw = cfg.patch_size
        # round() (not int()) avoids float truncation, e.g. 0.29 * 100 == 28.999... -> 28.
        keep_pct = round(cfg.keep_percentage * 100)
        # max_time_steps changes the built samples, so it must be part of the key;
        # the suffix is omitted when unset so existing default-config caches stay valid.
        trunc = "" if cfg.max_time_steps is None else f"_T{int(cfg.max_time_steps)}"
        cache_name = (
            f"adseq_{name}_keep{keep_pct}_{cfg.normalize}"
            f"_p{ph}x{pw}_{cfg.phase_mode}{trunc}_v2.pt"
        )
        return cfg.cache_dir / cache_name  # type: ignore[operator]

    def _load_raw(self) -> Any:
        """Deserialize the raw payload. NOTE: pickle — only load trusted, locally generated files."""
        with self.config.raw_path.open("rb") as handle:
            return pickle.load(handle)

    def _normalize_sample(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalize a single complex sample by its own RMS magnitude."""
        rms = torch.sqrt((tensor.real.float() ** 2 + tensor.imag.float() ** 2).mean()).clamp_min(1e-8)
        return tensor / rms.to(tensor.dtype)

    def _build_samples(self) -> List[torch.Tensor]:
        """Convert raw channels into truncated angle-delay samples, then normalize."""
        payload = self._load_raw()
        # Raw payload is either a dict holding a 'channel' entry or the array itself.
        channel = payload["channel"] if isinstance(payload, dict) and "channel" in payload else payload
        channel_tensor = torch.as_tensor(channel, dtype=torch.complex64)
        # Promote a single sequence (T, N, M) to a batch of one.
        if channel_tensor.ndim == 3:
            channel_tensor = channel_tensor.unsqueeze(0)
        if self.config.max_time_steps is not None and channel_tensor.size(1) > self.config.max_time_steps:
            channel_tensor = channel_tensor[:, : self.config.max_time_steps]
        processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=self.config.keep_percentage))
        samples: List[torch.Tensor] = []
        for seq in channel_tensor:
            ad = processor.forward(seq)
            truncated, _ = processor.truncate_delay_bins(ad)
            samples.append(truncated)
        # Normalization happens after all samples are built so 'global_rms' can
        # pool statistics over the entire dataset.
        if self.config.normalize == "per_sample_rms":
            samples = [self._normalize_sample(s) for s in samples]
        elif self.config.normalize == "global_rms":
            total_sum_sq = 0.0
            total_count = 0
            for s in samples:
                s_real = s.real.float()
                s_imag = s.imag.float()
                total_sum_sq += (s_real ** 2 + s_imag ** 2).sum().item()
                total_count += s_real.numel()
            if total_count > 0:
                # Floor at 1e-8 to guard against division by ~zero for all-zero data.
                global_rms = max(math.sqrt(total_sum_sq / total_count), 1e-8)
                samples = [s / torch.tensor(global_rms, dtype=torch.float32).to(s.dtype) for s in samples]
        return samples

    def _apply_noise(self) -> None:
        """Add complex AWGN at ``config.snr_db`` to every sample.

        Noise power is set per sample from its own mean power, split evenly
        between the real and imaginary components.
        """
        if self.config.noise_seed is not None:
            # NOTE(review): seeds the *global* torch RNG, affecting other consumers.
            torch.manual_seed(int(self.config.noise_seed))
        noisy: List[torch.Tensor] = []
        snr_lin = 10.0 ** (float(self.config.snr_db) / 10.0)
        for sample in self.samples:
            real = sample.real.float()
            imag = sample.imag.float()
            power = (real.square() + imag.square()).mean().item()
            if power <= 0:
                # All-zero sample: SNR is undefined, leave it untouched.
                noisy.append(sample)
                continue
            noise_var = power / snr_lin
            std = math.sqrt(noise_var / 2.0)
            noise_real = torch.randn_like(real) * std
            noise_imag = torch.randn_like(imag) * std
            # Build the complex tensor from float parts first (torch.complex
            # requires float inputs), then cast to the sample's complex dtype.
            # The original cast the parts to sample.dtype before torch.complex,
            # which raises for complex-typed samples.
            noise = torch.complex(noise_real, noise_imag).to(sample.dtype)
            noisy.append((sample + noise).to(sample.dtype))
        self.samples = noisy

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return one sample plus its tokenized representation.

        Keys:
            sequence: complex (T, N, M) angle-delay tensor.
            tokens: tokenizer output for the sequence (batch dim squeezed).
            base_mask: tokenizer mask aligned with ``tokens``.
            shape: LongTensor [T, H, W] with the patch-grid dimensions.
        """
        sample = self.samples[index]
        tokens, base_mask = self.tokenizer(sample.unsqueeze(0), self.config.patch_size)
        tokens = tokens.squeeze(0)
        base_mask = base_mask.squeeze(0)
        T, N, M = sample.shape
        ph, pw = self.config.patch_size
        # Patch-grid size (assumes N % ph == 0 and M % pw == 0 — TODO confirm upstream).
        H = N // ph
        W = M // pw
        shape = torch.tensor([T, H, W], dtype=torch.long)
        payload: Dict[str, Any] = {
            "sequence": sample,
            "tokens": tokens,
            "base_mask": base_mask,
            "shape": shape,
        }
        return payload
def load_adseq_dataset(
    data_path: str | Path,
    keep_percentage: float = 0.25,
    normalize: str = "global_rms",
    cache_dir: Optional[str | Path] = "cache",
    use_cache: bool = True,
    overwrite_cache: bool = False,
    logger: Optional[Any] = None,
    snr_db: Optional[float] = None,
    noise_seed: Optional[int] = None,
    max_time_steps: Optional[int] = None,
    patch_size: Tuple[int, int] = (1, 1),
    phase_mode: str = "real_imag",
) -> "AngleDelaySequenceDataset":
    """Build an ``AngleDelaySequenceDataset`` from a raw pickle file.

    Thin convenience wrapper around ``AngleDelayDatasetConfig``; all arguments
    map one-to-one onto config fields.  ``patch_size`` and ``phase_mode`` are
    new trailing parameters (defaults match the config), added so the tokenizer
    settings can be configured through this loader as well.

    Args:
        data_path: Path to the pickled raw channel data.
        keep_percentage: Fraction of delay bins kept after truncation.
        normalize: "global_rms", "per_sample_rms", or other value to skip.
        cache_dir: Cache directory; ``None`` disables on-disk caching.
        use_cache: Reuse an existing cache file when present.
        overwrite_cache: Force a rebuild even if a cache file exists.
        logger: Optional logger forwarded to the dataset.
        snr_db: Optional SNR (dB) for additive complex noise.
        noise_seed: Optional seed for the noise RNG.
        max_time_steps: Optional cap on time steps per sequence.
        patch_size: Patch (height, width) used by the tokenizer.
        phase_mode: Complex-to-real representation mode for the tokenizer.

    Returns:
        The constructed ``AngleDelaySequenceDataset``.
    """
    cfg = AngleDelayDatasetConfig(
        raw_path=Path(data_path),
        keep_percentage=keep_percentage,
        normalize=normalize,
        cache_dir=None if cache_dir is None else Path(cache_dir),
        use_cache=use_cache,
        overwrite_cache=overwrite_cache,
        snr_db=snr_db,
        noise_seed=noise_seed,
        max_time_steps=max_time_steps,
        patch_size=patch_size,
        phase_mode=phase_mode,
    )
    return AngleDelaySequenceDataset(cfg, logger=logger)