from __future__ import annotations
import dataclasses
import math
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
from .angle_delay import AngleDelayConfig, AngleDelayProcessor
from ..models.lwm import ComplexPatchTokenizer
@dataclasses.dataclass
class AngleDelayDatasetConfig:
    """Configuration for AngleDelaySequenceDataset: raw input location, preprocessing, caching, and noise injection."""

    raw_path: Path  # pickle file holding the raw channel payload (a dict with a "channel" key, or the array itself)
    keep_percentage: float = 0.25  # forwarded to AngleDelayConfig; fraction used when truncating delay bins — see AngleDelayProcessor for exact semantics
    normalize: str = "global_rms"  # "global_rms" (one RMS over all samples), "per_sample_rms", or anything else for no normalization
    cache_dir: Optional[Path] = Path("cache")  # directory for processed-tensor caches; None disables caching entirely
    use_cache: bool = True  # if False, always rebuild samples from raw_path
    overwrite_cache: bool = False  # if True, ignore an existing cache file and rebuild (the rebuild is re-saved)
    snr_db: Optional[float] = None  # when set, complex AWGN at this SNR (dB) is added after loading/building
    noise_seed: Optional[int] = None  # seeds torch's global RNG before noise generation, for reproducible noise
    max_time_steps: Optional[int] = None  # truncate each sequence's time dimension to at most this many steps
    patch_size: Tuple[int, int] = (1, 1)  # (height, width) of tokenizer patches over the (N, M) angle-delay grid
    phase_mode: str = "real_imag"  # forwarded to ComplexPatchTokenizer — presumably selects the complex-to-real encoding; confirm against that class
class AngleDelaySequenceDataset(Dataset):
    """Angle-delay dataset that tokenizes sequences and caches the processed tensors.

    Each sample is a complex tensor of shape (T, N, M) — time steps by an
    angle-delay grid.  ``__getitem__`` also returns the tokenized view produced
    by ``ComplexPatchTokenizer`` together with the token-grid shape.
    """

    def __init__(self, config: AngleDelayDatasetConfig, logger: Optional[Any] = None) -> None:
        """Load samples from the cache when possible, otherwise build and cache them.

        Args:
            config: Dataset configuration (paths, normalization, caching, noise).
            logger: Optional logger-like object; stored for subclasses/callers.
        """
        super().__init__()
        self.config = config
        self.logger = logger
        self.tokenizer = ComplexPatchTokenizer(config.phase_mode)
        self.samples: List[torch.Tensor]
        cache_hit = False
        cache_path = self._cache_path() if config.use_cache and config.cache_dir is not None else None
        if cache_path and cache_path.exists() and not config.overwrite_cache:
            try:
                payload = torch.load(cache_path, map_location="cpu")
                # Accept both the current {"samples": [...]} layout and a bare
                # list written by an older version of this module.
                if isinstance(payload, dict) and "samples" in payload:
                    self.samples = payload["samples"]
                else:
                    self.samples = payload
                cache_hit = True
            except Exception:
                # A corrupt or incompatible cache file is not fatal: drop it
                # and fall through to a full rebuild (which re-saves it).
                cache_path.unlink(missing_ok=True)
                cache_hit = False
        if not cache_hit:
            self.samples = self._build_samples()
            if cache_path is not None:
                cache_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({"samples": self.samples}, cache_path)
        # Noise is applied after caching so the cached tensors stay clean and
        # a different SNR can reuse the same cache file.
        if self.config.snr_db is not None:
            self._apply_noise()

    def _cache_path(self) -> Path:
        """Return the cache file path, encoding every parameter that affects the tensors."""
        cfg = self.config
        name = cfg.raw_path.stem
        # Include patch_size and phase_mode in cache name to ensure cache invalidation
        # when these parameters change. Also add 'v2' to invalidate old caches with wrong normalization.
        # NOTE: int() truncates (e.g. 0.29 * 100 -> 28); kept as-is so existing
        # cache files remain addressable.
        ph, pw = cfg.patch_size
        cache_name = f"adseq_{name}_keep{int(cfg.keep_percentage * 100)}_{cfg.normalize}_p{ph}x{pw}_{cfg.phase_mode}_v2.pt"
        return cfg.cache_dir / cache_name  # type: ignore[operator]

    def _load_raw(self) -> Any:
        """Unpickle and return the raw payload at config.raw_path.

        NOTE(review): pickle.load on an untrusted file can execute arbitrary
        code — only point raw_path at trusted data.
        """
        with self.config.raw_path.open("rb") as handle:
            return pickle.load(handle)

    def _normalize_sample(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalize a single complex sample by its own RMS (clamped to avoid div-by-zero)."""
        rms = torch.sqrt((tensor.real.float() ** 2 + tensor.imag.float() ** 2).mean()).clamp_min(1e-8)
        return tensor / rms.to(tensor.dtype)

    def _build_samples(self) -> List[torch.Tensor]:
        """Build the list of angle-delay samples from the raw payload.

        Steps: unpickle -> extract "channel" (or use the payload directly) ->
        cast to complex64 -> ensure a leading batch dim -> optional time
        truncation -> angle-delay transform + delay-bin truncation per
        sequence -> optional normalization.
        """
        payload = self._load_raw()
        channel = payload["channel"] if isinstance(payload, dict) and "channel" in payload else payload
        channel_tensor = torch.as_tensor(channel, dtype=torch.complex64)
        if channel_tensor.ndim == 3:
            # A single sequence was provided; treat it as a batch of one.
            channel_tensor = channel_tensor.unsqueeze(0)
        if self.config.max_time_steps is not None and channel_tensor.size(1) > self.config.max_time_steps:
            channel_tensor = channel_tensor[:, : self.config.max_time_steps]
        processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=self.config.keep_percentage))
        samples: List[torch.Tensor] = []
        for seq in channel_tensor:
            ad = processor.forward(seq)
            truncated, _ = processor.truncate_delay_bins(ad)
            samples.append(truncated)
        # Apply normalization after collecting all samples so "global_rms" can
        # see the full dataset statistics.
        if self.config.normalize == "per_sample_rms":
            samples = [self._normalize_sample(s) for s in samples]
        elif self.config.normalize == "global_rms":
            # Compute one RMS over the concatenation of all samples.
            total_sum_sq = 0.0
            total_count = 0
            for s in samples:
                s_real = s.real.float()
                s_imag = s.imag.float()
                total_sum_sq += (s_real ** 2 + s_imag ** 2).sum().item()
                total_count += s_real.numel()
            if total_count > 0:
                global_rms = math.sqrt(total_sum_sq / total_count)
                global_rms = max(global_rms, 1e-8)
                samples = [s / torch.tensor(global_rms, dtype=torch.float32).to(s.dtype) for s in samples]
        return samples

    def _apply_noise(self) -> None:
        """Add complex AWGN at config.snr_db to every sample, replacing self.samples.

        The noise power is set per-sample from that sample's measured power, so
        every sample ends up at the same SNR regardless of its scale.
        """
        if self.config.noise_seed is not None:
            # Seeds the *global* torch RNG — callers relying on other random
            # state should be aware of this side effect.
            torch.manual_seed(int(self.config.noise_seed))
        noisy: List[torch.Tensor] = []
        snr_lin = 10.0 ** (float(self.config.snr_db) / 10.0)
        for sample in self.samples:
            real = sample.real.float()
            imag = sample.imag.float()
            power = (real.square() + imag.square()).mean().item()
            if power <= 0:
                # All-zero sample: SNR is undefined; leave it untouched.
                noisy.append(sample)
                continue
            noise_var = power / snr_lin
            std = math.sqrt(noise_var / 2.0)  # split the variance across real/imag parts
            noise_real = torch.randn_like(real) * std
            noise_imag = torch.randn_like(imag) * std
            # BUG FIX: torch.complex() requires floating-point inputs. The old
            # code cast the parts to sample.dtype (complex64) first, which made
            # torch.complex raise at runtime whenever noise was enabled.
            noise = torch.complex(noise_real, noise_imag).to(sample.dtype)
            noisy.append((sample + noise).to(sample.dtype))
        self.samples = noisy

    def __len__(self) -> int:
        """Number of sequences in the dataset."""
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the raw sequence plus its tokenized view and token-grid shape.

        Keys: "sequence" (T, N, M) complex tensor, "tokens"/"base_mask" from
        the tokenizer (batch dim stripped), "shape" = [T, H, W] token grid.
        """
        sample = self.samples[index]
        tokens, base_mask = self.tokenizer(sample.unsqueeze(0), self.config.patch_size)
        tokens = tokens.squeeze(0)
        base_mask = base_mask.squeeze(0)
        T, N, M = sample.shape
        ph, pw = self.config.patch_size
        # Token grid: one token per (ph x pw) patch of the (N, M) plane.
        H = N // ph
        W = M // pw
        shape = torch.tensor([T, H, W], dtype=torch.long)
        payload: Dict[str, Any] = {
            "sequence": sample,
            "tokens": tokens,
            "base_mask": base_mask,
            "shape": shape,
        }
        return payload
def load_adseq_dataset(
    data_path: str | Path,
    keep_percentage: float = 0.25,
    normalize: str = "global_rms",
    cache_dir: Optional[str | Path] = "cache",
    use_cache: bool = True,
    overwrite_cache: bool = False,
    logger: Optional[Any] = None,
    snr_db: Optional[float] = None,
    noise_seed: Optional[int] = None,
    max_time_steps: Optional[int] = None,
) -> "AngleDelaySequenceDataset":
    """Convenience factory: build an AngleDelaySequenceDataset from loose arguments.

    All arguments mirror the fields of AngleDelayDatasetConfig; path-like
    arguments are coerced to Path (cache_dir=None disables caching).
    """
    resolved_cache = None if cache_dir is None else Path(cache_dir)
    config = AngleDelayDatasetConfig(
        raw_path=Path(data_path),
        keep_percentage=keep_percentage,
        normalize=normalize,
        cache_dir=resolved_cache,
        use_cache=use_cache,
        overwrite_cache=overwrite_cache,
        snr_db=snr_db,
        noise_seed=noise_seed,
        max_time_steps=max_time_steps,
    )
    return AngleDelaySequenceDataset(config, logger=logger)