File size: 7,220 Bytes
265d187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b854c8e
265d187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b854c8e
265d187
 
b854c8e
265d187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b854c8e
265d187
 
 
 
 
 
 
 
 
b854c8e
265d187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b854c8e
265d187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b854c8e
265d187
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from __future__ import annotations

import dataclasses
import math
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset

from .angle_delay import AngleDelayConfig, AngleDelayProcessor
from ..models.lwm import ComplexPatchTokenizer


@dataclasses.dataclass
class AngleDelayDatasetConfig:
    """Configuration for AngleDelaySequenceDataset preprocessing and caching."""
    # Path to the raw pickled channel file (loaded via pickle.load).
    raw_path: Path
    # Fraction of delay bins retained by the angle-delay truncation step.
    keep_percentage: float = 0.25
    # Normalization mode: "global_rms" (one RMS over all samples) or
    # "per_sample_rms" (each sample divided by its own RMS).
    normalize: str = "global_rms"
    # Directory for processed-tensor caches; None disables caching.
    cache_dir: Optional[Path] = Path("cache")
    # Whether to read/write the cache at all.
    use_cache: bool = True
    # Rebuild samples even if a cache file exists.
    overwrite_cache: bool = False
    # If set, additive complex Gaussian noise at this SNR (dB) is applied
    # after loading/building samples; noise is never written to the cache.
    snr_db: Optional[float] = None
    # Seed for the noise draw (passed to torch.manual_seed when set).
    noise_seed: Optional[int] = None
    # If set, sequences are truncated along the time axis to this length.
    max_time_steps: Optional[int] = None
    # Patch (height, width) handed to the tokenizer in __getitem__.
    patch_size: Tuple[int, int] = (1, 1)
    # Phase representation forwarded to ComplexPatchTokenizer — presumably
    # "real_imag" vs a magnitude/phase variant; confirm against the tokenizer.
    phase_mode: str = "real_imag"


class AngleDelaySequenceDataset(Dataset):
    """Angle-delay dataset that tokenizes sequences and caches the processed tensors.

    On construction the dataset either loads preprocessed samples from a cache
    file or builds them from the raw pickle (angle-delay transform, delay-bin
    truncation, normalization) and writes the cache. Additive complex Gaussian
    noise (when ``snr_db`` is set) is applied *after* the cache step so noisy
    variants never pollute the cache.
    """

    def __init__(self, config: AngleDelayDatasetConfig, logger: Optional[Any] = None) -> None:
        super().__init__()
        self.config = config
        self.logger = logger
        self.tokenizer = ComplexPatchTokenizer(config.phase_mode)
        self.samples: List[torch.Tensor]
        cache_path = self._cache_path() if config.use_cache and config.cache_dir is not None else None
        cache_hit = False
        if cache_path is not None and cache_path.exists() and not config.overwrite_cache:
            try:
                payload = torch.load(cache_path, map_location="cpu")
                # Support both the current {"samples": [...]} layout and a bare list.
                if isinstance(payload, dict) and "samples" in payload:
                    self.samples = payload["samples"]
                else:
                    self.samples = payload
                cache_hit = True
            except Exception:
                # Corrupt or unreadable cache: best-effort delete, then rebuild.
                cache_path.unlink(missing_ok=True)
                cache_hit = False
        if not cache_hit:
            self.samples = self._build_samples()
            if cache_path is not None:
                cache_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({"samples": self.samples}, cache_path)
        # Noise is applied post-cache so the cache stays noise-free.
        if self.config.snr_db is not None:
            self._apply_noise()

    def _cache_path(self) -> Path:
        """Build a cache filename encoding every parameter that shapes the cached tensors.

        Any config field that influences what `_build_samples` produces must
        appear in the name so a parameter change invalidates the cache; this
        includes ``max_time_steps`` (time truncation happens before caching).
        ``snr_db``/``noise_seed`` are deliberately excluded — noise is applied
        after loading. 'v3' bumps past old caches whose name omitted
        ``max_time_steps``.
        """
        cfg = self.config
        ph, pw = cfg.patch_size
        # round(), not int(): int() truncates, so e.g. 0.29 * 100 -> 28.
        keep_pct = round(cfg.keep_percentage * 100)
        cache_name = (
            f"adseq_{cfg.raw_path.stem}_keep{keep_pct}_{cfg.normalize}"
            f"_p{ph}x{pw}_{cfg.phase_mode}_mt{cfg.max_time_steps}_v3.pt"
        )
        return cfg.cache_dir / cache_name  # type: ignore[operator]

    def _load_raw(self) -> Any:
        """Unpickle the raw channel payload.

        NOTE(security): pickle executes arbitrary code on load — only point
        ``raw_path`` at trusted dataset files.
        """
        with self.config.raw_path.open("rb") as handle:
            return pickle.load(handle)

    def _normalize_sample(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalize a single complex sample by its own RMS magnitude."""
        power = tensor.real.float() ** 2 + tensor.imag.float() ** 2
        rms = torch.sqrt(power.mean()).clamp_min(1e-8)  # clamp guards all-zero samples
        return tensor / rms.to(tensor.dtype)

    def _build_samples(self) -> List[torch.Tensor]:
        """Run the angle-delay pipeline over every sequence in the raw file.

        Returns the list of truncated, normalized complex tensors that gets
        cached. Accepts either a dict payload with a "channel" key or a bare
        array-like of shape (batch, time, ...) / (time, ...).
        """
        payload = self._load_raw()
        channel = payload["channel"] if isinstance(payload, dict) and "channel" in payload else payload
        channel_tensor = torch.as_tensor(channel, dtype=torch.complex64)
        if channel_tensor.ndim == 3:
            # Single sequence: promote to a batch of one.
            channel_tensor = channel_tensor.unsqueeze(0)
        if self.config.max_time_steps is not None and channel_tensor.size(1) > self.config.max_time_steps:
            channel_tensor = channel_tensor[:, : self.config.max_time_steps]
        processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=self.config.keep_percentage))
        samples: List[torch.Tensor] = []
        for seq in channel_tensor:
            ad = processor.forward(seq)
            truncated, _ = processor.truncate_delay_bins(ad)
            samples.append(truncated)
        return self._normalize_samples(samples)

    def _normalize_samples(self, samples: List[torch.Tensor]) -> List[torch.Tensor]:
        """Apply the configured normalization mode to the full sample list."""
        mode = self.config.normalize
        if mode == "per_sample_rms":
            return [self._normalize_sample(s) for s in samples]
        if mode == "global_rms":
            # One RMS computed over every element of every sample.
            total_sum_sq = 0.0
            total_count = 0
            for s in samples:
                total_sum_sq += (s.real.float() ** 2 + s.imag.float() ** 2).sum().item()
                total_count += s.real.numel()
            if total_count > 0:
                global_rms = max(math.sqrt(total_sum_sq / total_count), 1e-8)
                return [s / global_rms for s in samples]
        # Unknown mode (or empty dataset): leave samples untouched.
        return samples

    def _apply_noise(self) -> None:
        """Add complex Gaussian noise at ``config.snr_db`` to every sample in place.

        NOTE(review): seeding goes through torch.manual_seed, which mutates the
        global RNG state — acceptable here since it runs once at construction.
        """
        if self.config.noise_seed is not None:
            torch.manual_seed(int(self.config.noise_seed))
        snr_lin = 10.0 ** (float(self.config.snr_db) / 10.0)
        noisy: List[torch.Tensor] = []
        for sample in self.samples:
            real = sample.real.float()
            imag = sample.imag.float()
            power = (real.square() + imag.square()).mean().item()
            if power <= 0:
                # All-zero sample: SNR is undefined, keep it as-is.
                noisy.append(sample)
                continue
            noise_var = power / snr_lin
            # Split variance evenly across real/imag components.
            std = math.sqrt(noise_var / 2.0)
            noise_real = torch.randn_like(real) * std
            noise_imag = torch.randn_like(imag) * std
            noise = torch.complex(noise_real.to(sample.dtype), noise_imag.to(sample.dtype))
            noisy.append((sample + noise).to(sample.dtype))
        self.samples = noisy

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the raw complex sequence plus its tokenized form and grid shape."""
        sample = self.samples[index]
        # Tokenizer expects a batch dim; add and strip it around the call.
        tokens, base_mask = self.tokenizer(sample.unsqueeze(0), self.config.patch_size)
        tokens = tokens.squeeze(0)
        base_mask = base_mask.squeeze(0)
        T, N, M = sample.shape
        ph, pw = self.config.patch_size
        # Patch-grid dims; assumes N % ph == 0 and M % pw == 0 — TODO confirm upstream.
        H = N // ph
        W = M // pw
        shape = torch.tensor([T, H, W], dtype=torch.long)
        payload: Dict[str, Any] = {
            "sequence": sample,
            "tokens": tokens,
            "base_mask": base_mask,
            "shape": shape,
        }
        return payload


def load_adseq_dataset(
    data_path: str | Path,
    keep_percentage: float = 0.25,
    normalize: str = "global_rms",
    cache_dir: Optional[str | Path] = "cache",
    use_cache: bool = True,
    overwrite_cache: bool = False,
    logger: Optional[Any] = None,
    snr_db: Optional[float] = None,
    noise_seed: Optional[int] = None,
    max_time_steps: Optional[int] = None,
    patch_size: Tuple[int, int] = (1, 1),
    phase_mode: str = "real_imag",
) -> "AngleDelaySequenceDataset":
    """Convenience constructor for :class:`AngleDelaySequenceDataset`.

    Builds an :class:`AngleDelayDatasetConfig` from keyword arguments and
    returns the constructed dataset. ``patch_size`` and ``phase_mode`` are
    forwarded to the config (previously unreachable through this helper);
    their defaults match the config's, so existing callers are unaffected.

    Args:
        data_path: Raw pickled channel file.
        keep_percentage: Fraction of delay bins kept after truncation.
        normalize: "global_rms" or "per_sample_rms".
        cache_dir: Cache directory, or None to disable caching.
        use_cache: Read/write the processed-tensor cache.
        overwrite_cache: Force a rebuild even when a cache file exists.
        logger: Optional logger passed through to the dataset.
        snr_db: Optional SNR (dB) for post-cache additive noise.
        noise_seed: Optional seed for the noise draw.
        max_time_steps: Optional time-axis truncation length.
        patch_size: Patch (height, width) used by the tokenizer.
        phase_mode: Phase representation used by the tokenizer.
    """
    cfg = AngleDelayDatasetConfig(
        raw_path=Path(data_path),
        keep_percentage=keep_percentage,
        normalize=normalize,
        cache_dir=None if cache_dir is None else Path(cache_dir),
        use_cache=use_cache,
        overwrite_cache=overwrite_cache,
        snr_db=snr_db,
        noise_seed=noise_seed,
        max_time_steps=max_time_steps,
        patch_size=patch_size,
        phase_mode=phase_mode,
    )
    return AngleDelaySequenceDataset(cfg, logger=logger)