wi-lab commited on
Commit
30d1292
·
verified ·
1 Parent(s): 6a0e1d3

Upload task1/plot_tsne.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. task1/plot_tsne.py +816 -0
task1/plot_tsne.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Visualise how strongly metadata drives the learned embedding space.
3
+
4
+ This script mirrors the functionality of ``task1/plot_mod_tsne.py`` but groups
5
+ spectrograms by their SNR folder name (e.g. ``SNR0dB``) instead of modulation.
6
+ It is useful for checking whether the self-supervised LWM backbone mostly
7
+ captures channel/SNR differences rather than modulation characteristics.
8
+
9
+ Pass ``--label-field modulation`` to reuse the same sampled spectrograms while
10
+ colouring and scoring them by their modulation folder instead of SNR. Use
11
+ ``--label-field mobility`` to highlight link-level mobility categories when
12
+ present in the dataset tree. Saved figures automatically include the detected
13
+ communication profile (e.g. LTE/WiFi/5G) and label mode in the filename when
14
+ those suffixes are not already present.
15
+
16
+ Usage example:
17
+
18
+ ```bash
19
+ python task1/plot_snr_tsne.py \
20
+ --data-root spectrograms/city_1_losangeles/LTE \
21
+ --snrs SNR-5dB,SNR0dB,SNR10dB,SNR15dB,SNR20dB,SNR25dB \
22
+ --save-path task1/snr_separation_plot_latest.png
23
+ ```
24
+ Shortcut presets:
25
+
26
+ ```bash
27
+ python task1/plot_snr_tsne.py --WiFi --report-metrics
28
+ ```
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import argparse
34
+ import glob
35
+ import pickle
36
+ import random
37
+ import re
38
+ from pathlib import Path
39
+ from collections import Counter, defaultdict
40
+ from typing import Dict, Iterable, List, Tuple
41
+
42
+ import matplotlib.pyplot as plt
43
+ import numpy as np
44
+ import torch
45
+ from sklearn.manifold import TSNE
46
+ from sklearn.metrics import silhouette_score
47
+ from sklearn.model_selection import StratifiedKFold
48
+ from sklearn.neighbors import KNeighborsClassifier
49
+ from sklearn.preprocessing import StandardScaler
50
+
51
+ from pretraining.pretrained_model import lwm as lwm_model
52
+ from utils import load_spectrogram_data # support .mat and .pkl uniformly
53
+
54
+
55
+ DEFAULT_DATA_ROOT = "spectrograms/city_1_losangeles/LTE"
56
+ DEFAULT_MODELS_ROOT = "models/LTE_models"
57
+
58
+ PROFILE_PRESETS: Dict[str, Dict[str, str]] = {
59
+ "LTE": {
60
+ "data_root": DEFAULT_DATA_ROOT,
61
+ "models_root": DEFAULT_MODELS_ROOT,
62
+ },
63
+ "WiFi": {
64
+ "data_root": "spectrograms/city_1_losangeles/WiFi",
65
+ "models_root": "models/WiFi_models",
66
+ },
67
+ "5G": {
68
+ "data_root": "spectrograms/city_1_losangeles/5G",
69
+ "models_root": "models/5G_models",
70
+ },
71
+ }
72
+
73
+
74
+ def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
75
+ means = specs.mean(axis=(1, 2), keepdims=True)
76
+ stds = specs.std(axis=(1, 2), keepdims=True)
77
+ stds = np.maximum(stds, eps)
78
+ return ((specs - means) / stds).astype(np.float32, copy=False)
79
+
80
+
81
+ def normalize_dataset(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
82
+ mean = float(specs.mean())
83
+ std = float(specs.std())
84
+ std = max(std, eps)
85
+ return ((specs - mean) / std).astype(np.float32, copy=False)
86
+
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # Utility helpers
90
+ # ---------------------------------------------------------------------------
91
+
92
+ def parse_args() -> argparse.Namespace:
93
+ parser = argparse.ArgumentParser(description=__doc__)
94
+ parser.add_argument(
95
+ "--data-root",
96
+ default=DEFAULT_DATA_ROOT,
97
+ help="Root directory containing modulation folders (default: %(default)s)",
98
+ )
99
+ parser.add_argument(
100
+ "--modulation",
101
+ default="all",
102
+ help="Modulation folder to load (default: %(default)s)",
103
+ )
104
+ parser.add_argument(
105
+ "--snrs",
106
+ default="SNR-5dB,SNR0dB,SNR5dB,SNR10dB,SNR15dB,SNR20dB,SNR25dB",
107
+ help=(
108
+ "Comma-separated list of SNR folder names to include. Pass 'all' "
109
+ "to include every SNR discovered under the modulation (default: %(default)s)"
110
+ ),
111
+ )
112
+ parser.add_argument(
113
+ "--mobility",
114
+ nargs="+",
115
+ default=["all"],
116
+ help=(
117
+ "Mobility folder(s) to filter on. Pass 'all' to include every mobility "
118
+ "(default: %(default)s). Multiple values can be provided either as a "
119
+ "space-separated list (e.g. '--mobility vehicular pedestrian') or a "
120
+ "comma-separated string."
121
+ ),
122
+ )
123
+ parser.add_argument(
124
+ "--fft-folder",
125
+ default="all",
126
+ help=(
127
+ "FFT size folder name to use. Pass 'all' to include every FFT variant "
128
+ "(default: %(default)s)"
129
+ ),
130
+ )
131
+ parser.add_argument(
132
+ "--samples-per-snr",
133
+ type=int,
134
+ default=500,
135
+ help="Maximum number of samples to draw for each SNR label",
136
+ )
137
+ parser.add_argument(
138
+ "--seed",
139
+ type=int,
140
+ default=42,
141
+ help="Random seed for sampling and t-SNE",
142
+ )
143
+ parser.add_argument(
144
+ "--pooling",
145
+ choices=("mean", "cls"),
146
+ default="mean",
147
+ help="How to collapse token embeddings into a single vector",
148
+ )
149
+ parser.add_argument(
150
+ "--save-path",
151
+ default="task1/snr_separation_plot_latest.png",
152
+ help="Location to save the generated figure (default: %(default)s)",
153
+ )
154
+ parser.add_argument(
155
+ "--checkpoint",
156
+ default=None,
157
+ help="Optional explicit checkpoint path; overrides automatic latest selection",
158
+ )
159
+ parser.add_argument(
160
+ "--models-root",
161
+ default=DEFAULT_MODELS_ROOT,
162
+ help=(
163
+ "Directory containing checkpoints. When --checkpoint is not given, "
164
+ "the latest/best checkpoint inside this directory will be used "
165
+ "(default: %(default)s)"
166
+ ),
167
+ )
168
+ preset_group = parser.add_mutually_exclusive_group()
169
+ preset_group.add_argument(
170
+ "--profile",
171
+ dest="profile",
172
+ choices=tuple(PROFILE_PRESETS.keys()),
173
+ help=(
174
+ "Convenience preset that sets --data-root and --models-root when they "
175
+ "are left at their defaults"
176
+ ),
177
+ )
178
+ preset_group.add_argument(
179
+ "--LTE",
180
+ dest="profile",
181
+ action="store_const",
182
+ const="LTE",
183
+ help="Shortcut for --profile LTE",
184
+ )
185
+ preset_group.add_argument(
186
+ "--WiFi",
187
+ dest="profile",
188
+ action="store_const",
189
+ const="WiFi",
190
+ help="Shortcut for --profile WiFi",
191
+ )
192
+ preset_group.add_argument(
193
+ "--5G",
194
+ dest="profile",
195
+ action="store_const",
196
+ const="5G",
197
+ help="Shortcut for --profile 5G",
198
+ )
199
+ parser.add_argument(
200
+ "--report-metrics",
201
+ action="store_true",
202
+ help="Print clustering metrics (silhouette, 5-fold kNN accuracy)",
203
+ )
204
+ parser.add_argument(
205
+ "--metrics-only",
206
+ action="store_true",
207
+ help="Exit after reporting metrics without running t-SNE or saving figures",
208
+ )
209
+ parser.add_argument(
210
+ "--sampling-mode",
211
+ choices=("first", "reservoir"),
212
+ default="first",
213
+ help="How to down-sample each class (default: first)",
214
+ )
215
+ parser.add_argument(
216
+ "--complex-mode",
217
+ choices=("auto", "magnitude", "interleaved"),
218
+ default="auto",
219
+ help=(
220
+ "How to handle complex spectrograms: 'magnitude' (abs), 'interleaved' (real/imag interleaved along width), "
221
+ "or 'auto' (prefer interleaved when complex). Real-valued inputs are unaffected."
222
+ ),
223
+ )
224
+ parser.add_argument(
225
+ "--label-field",
226
+ choices=("snr", "modulation", "mobility"),
227
+ default="snr",
228
+ help="Choose which label to visualise and score (default: %(default)s)",
229
+ )
230
+ parser.add_argument(
231
+ "--normalization",
232
+ choices=("per-sample", "dataset"),
233
+ default="per-sample",
234
+ help="Normalisation strategy applied before embedding extraction",
235
+ )
236
+ return parser.parse_args()
237
+
238
+
239
+ def find_latest_checkpoint(models_root: Path) -> Path:
240
+ """Return a checkpoint path under ``models_root``.
241
+
242
+ Works with either a parent directory that contains multiple run folders,
243
+ or directly with a single run directory containing ``*.pth`` files.
244
+ Chooses the checkpoint with the lowest parsed validation value when
245
+ available, else falls back to most-recent modification time.
246
+ """
247
+
248
+ if not models_root.exists():
249
+ raise FileNotFoundError(f"Models root not found: {models_root}")
250
+
251
+ if models_root.is_file():
252
+ raise FileNotFoundError(f"Expected a directory, got file: {models_root}")
253
+
254
+ # If the provided directory itself contains checkpoints, use it directly.
255
+ checkpoints = list(models_root.glob("*.pth"))
256
+ if not checkpoints:
257
+ # Otherwise, look for subdirectories that contain checkpoints and ignore others (e.g., tensorboard)
258
+ run_dirs = [p for p in models_root.iterdir() if p.is_dir()]
259
+ candidate_runs = [d for d in run_dirs if any(d.glob("*.pth"))]
260
+ if not candidate_runs:
261
+ raise FileNotFoundError(
262
+ f"No checkpoints found under {models_root} (no .pth files in this dir or its run subdirs)"
263
+ )
264
+ latest_run = max(candidate_runs, key=lambda p: p.stat().st_mtime)
265
+ checkpoints = list(latest_run.glob("*.pth"))
266
+
267
+ def parse_val_metric(path: Path) -> float | None:
268
+ match = re.search(r"_val([0-9]+(?:\.[0-9]+)?)", path.name)
269
+ if match:
270
+ try:
271
+ return float(match.group(1))
272
+ except ValueError:
273
+ return None
274
+ return None
275
+
276
+ parsed = [(parse_val_metric(p), p) for p in checkpoints]
277
+ valid = [item for item in parsed if item[0] is not None]
278
+ if valid:
279
+ valid.sort(key=lambda item: item[0])
280
+ return valid[0][1]
281
+
282
+ # Fallback to most recent modification time
283
+ return max(checkpoints, key=lambda p: p.stat().st_mtime)
284
+
285
+
286
+ def parse_snr_list(snr_argument: str | None) -> set[str] | None:
287
+ if snr_argument is None or snr_argument.lower() == "all":
288
+ return None
289
+ values = [item.strip() for item in snr_argument.split(",") if item.strip()]
290
+ return set(values)
291
+
292
+
293
+ def list_snr_samples(
294
+ data_root: Path,
295
+ modulation: str,
296
+ allowed_snrs: set[str] | None,
297
+ mobility_filter: set[str] | None,
298
+ fft_folder: str,
299
+ max_per_class: int,
300
+ rng: random.Random,
301
+ mode: str,
302
+ complex_mode: str,
303
+ ) -> Dict[str, List[Tuple[np.ndarray, str, str]]]:
304
+ """Collect spectrogram samples grouped by SNR label.
305
+
306
+ Supports both legacy PKL layout with a trailing 'spectrograms/' folder and
307
+ MATLAB .mat bundles saved directly under the mobility folder.
308
+
309
+ Returns: mapping from SNR label to list of tuples: (spec, modulation, mobility)
310
+ """
311
+
312
+ class_samples: Dict[str, List[Tuple[np.ndarray, str, str]]] = defaultdict(list)
313
+ seen_counts: Dict[str, int] = defaultdict(int)
314
+
315
+ # Search patterns:
316
+ # - PKL under .../spectrograms/*.pkl
317
+ # - MAT under .../spectrogram_*.mat
318
+ patterns = [
319
+ str(data_root / "**" / "spectrograms" / "*.pkl"),
320
+ str(data_root / "**" / "spectrogram_*.mat"),
321
+ ]
322
+
323
+ mobility_set = {"static", "pedestrian", "vehicular"}
324
+
325
+ def extract_tokens(rel_parts: Tuple[str, ...]) -> Tuple[str, str, str, str] | None:
326
+ # Heuristic extraction to support both layouts
327
+ # modulation: first path segment below data_root
328
+ if not rel_parts:
329
+ return None
330
+ modulation_folder = rel_parts[0]
331
+
332
+ # snr: first segment like SNR(-?)NdB
333
+ snr_folder = next((p for p in rel_parts if re.match(r"^SNR-?\d+dB$", p)), None)
334
+ if snr_folder is None:
335
+ return None
336
+
337
+ # mobility: one of known labels
338
+ mobility_folder = next((p for p in rel_parts if p.lower() in mobility_set), None)
339
+ if mobility_folder is None:
340
+ return None
341
+
342
+ # fft/window folder if present (PKL layout), else fallback for MAT
343
+ fft_folder_name = next((p for p in rel_parts if p.startswith("win") or p.startswith("fft")), "fft_unknown")
344
+
345
+ return modulation_folder, snr_folder, mobility_folder, fft_folder_name
346
+
347
+ for pattern in patterns:
348
+ for path_str in glob.iglob(pattern, recursive=True):
349
+ path = Path(path_str)
350
+ try:
351
+ rel_parts = path.relative_to(data_root).parts
352
+ except ValueError:
353
+ continue
354
+
355
+ tokens = extract_tokens(rel_parts)
356
+ if tokens is None:
357
+ continue
358
+ modulation_folder, snr_folder, mobility_folder, fft_folder_name = tokens
359
+
360
+ # Apply filters
361
+ if modulation.lower() != "all" and modulation_folder != modulation:
362
+ continue
363
+ if allowed_snrs is not None and snr_folder not in allowed_snrs:
364
+ continue
365
+ if mobility_filter is not None and mobility_folder.lower() not in mobility_filter:
366
+ continue
367
+ if fft_folder != "all" and fft_folder_name != fft_folder:
368
+ continue
369
+
370
+ class_label = snr_folder
371
+ if mode == "first" and len(class_samples[class_label]) >= max_per_class:
372
+ continue
373
+
374
+ # Load spectrogram data (supports .pkl and .mat)
375
+ try:
376
+ arr = load_spectrogram_data(str(path))
377
+ except Exception as exc: # pragma: no cover - I/O heavy
378
+ print(f"[WARN] Failed to load {path}: {exc}")
379
+ continue
380
+
381
+ if not isinstance(arr, np.ndarray) or arr.size == 0:
382
+ continue
383
+
384
+ # If loaded spectrograms are complex, convert according to mode
385
+ if np.iscomplexobj(arr):
386
+ if complex_mode == "magnitude":
387
+ arr = np.abs(arr)
388
+ else:
389
+ # Interleave real/imag parts along the width dimension
390
+ if arr.ndim == 4 and arr.shape[1] == 1:
391
+ arr = arr[:, 0]
392
+ if arr.ndim == 3:
393
+ real = arr.real.astype(np.float32, copy=False)
394
+ imag = arr.imag.astype(np.float32, copy=False)
395
+ n, h, w = real.shape
396
+ inter = np.empty((n, h, w * 2), dtype=np.float32)
397
+ inter[:, :, 0::2] = real
398
+ inter[:, :, 1::2] = imag
399
+ arr = inter
400
+ else:
401
+ # Fallback to magnitude for unsupported shapes
402
+ arr = np.abs(arr)
403
+
404
+ # Normalize shapes:
405
+ # - (N, H, W)
406
+ # - (N, C, H, W) -> collapse channels via mean
407
+ if arr.ndim == 4:
408
+ # (N, C, H, W) -> (N, H, W)
409
+ if arr.shape[1] > 1:
410
+ specs = arr.mean(axis=1)
411
+ else:
412
+ specs = arr[:, 0]
413
+ elif arr.ndim == 3:
414
+ specs = arr
415
+ elif arr.ndim == 2:
416
+ specs = arr[None, ...]
417
+ else:
418
+ print(f"[WARN] Unexpected spectrogram shape in {path}: {arr.shape}")
419
+ continue
420
+
421
+ for spec in specs:
422
+ sample = np.asarray(spec, dtype=np.float32)
423
+ bucket = class_samples[class_label]
424
+
425
+ if len(bucket) < max_per_class:
426
+ bucket.append((sample, modulation_folder, mobility_folder))
427
+ seen_counts[class_label] += 1
428
+ elif mode == "reservoir":
429
+ seen_counts[class_label] += 1
430
+ j = rng.randint(0, seen_counts[class_label] - 1)
431
+ if j < max_per_class:
432
+ bucket[j] = (sample, modulation_folder, mobility_folder)
433
+ else: # mode == "first" and already full
434
+ break
435
+
436
+ return class_samples
437
+
438
+
439
+ def sample_balanced_dataset(
440
+ class_samples: Dict[str, List[Tuple[np.ndarray, str, str]]],
441
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[str]]:
442
+ """Stack the sampled spectrograms alongside SNR, modulation, and mobility labels."""
443
+
444
+ features: List[np.ndarray] = []
445
+ snr_labels: List[str] = []
446
+ modulation_labels: List[str] = []
447
+ mobility_labels: List[str] = []
448
+ class_names = sorted(class_samples.keys())
449
+
450
+ for class_name in class_names:
451
+ samples = class_samples[class_name]
452
+ if not samples:
453
+ continue
454
+ for sample, modulation_label, mobility_label in samples:
455
+ features.append(sample)
456
+ snr_labels.append(class_name)
457
+ modulation_labels.append(modulation_label)
458
+ mobility_labels.append(mobility_label)
459
+
460
+ if not features:
461
+ raise RuntimeError("No spectrogram samples collected for the specified filters")
462
+
463
+ stacked = np.stack(features) # [N, 128, 128]
464
+ return (
465
+ stacked,
466
+ np.array(snr_labels),
467
+ np.array(modulation_labels),
468
+ np.array(mobility_labels),
469
+ class_names,
470
+ )
471
+
472
+
473
+ def unfold_patches_square(x: torch.Tensor, patch_size: int = 4) -> torch.Tensor:
474
+ # Input shape: [B, H, W]; extracts (patch_size x patch_size) patches
475
+ patches_h = x.unfold(1, patch_size, patch_size)
476
+ patches = patches_h.unfold(2, patch_size, patch_size)
477
+ return patches.contiguous().view(x.shape[0], -1, patch_size * patch_size)
478
+
479
+
480
+ def unfold_patches_rect(x: torch.Tensor, patch_rows: int = 4, patch_cols: int = 8) -> torch.Tensor:
481
+ # Input shape: [B, H, W]; extracts (patch_rows x patch_cols) patches (for interleaved complex)
482
+ patches_h = x.unfold(1, patch_rows, patch_rows)
483
+ patches = patches_h.unfold(2, patch_cols, patch_cols)
484
+ return patches.contiguous().view(x.shape[0], -1, patch_rows * patch_cols)
485
+
486
+
487
+ def extract_tokens(spec: np.ndarray, device: torch.device, interleaved: bool) -> torch.Tensor:
488
+ tensor = torch.from_numpy(spec).unsqueeze(0).to(device)
489
+ if interleaved:
490
+ # Rectangular patches 4x8 to cover 4x4 complex bins (real+imag)
491
+ return unfold_patches_rect(tensor, 4, 8) # [1, 1024, 32]
492
+ else:
493
+ return unfold_patches_square(tensor, 4) # [1, 1024, 16]
494
+
495
+
496
+ def pool_embeddings(
497
+ tokens: torch.Tensor,
498
+ model: torch.nn.Module,
499
+ pooling: str,
500
+ ) -> np.ndarray:
501
+ # Append CLS token (value 0.2) before passing through the transformer.
502
+ cls_token = torch.full((tokens.size(0), 1, tokens.size(-1)), 0.2, device=tokens.device)
503
+ inputs = torch.cat([cls_token, tokens], dim=1) # [B, 1025, 16]
504
+
505
+ with torch.no_grad():
506
+ outputs = model(inputs) # [B, 1025, 128]
507
+
508
+ if pooling == "cls":
509
+ pooled = outputs[:, 0]
510
+ else: # mean pooling across patch tokens (exclude CLS)
511
+ pooled = outputs[:, 1:].mean(dim=1)
512
+
513
+ return pooled.detach().cpu().numpy()
514
+
515
+
516
+ def sort_snr_labels(labels: List[str]) -> List[str]:
517
+ """Sort SNR labels by numeric value instead of lexicographic order."""
518
+ def extract_snr_value(label: str) -> float:
519
+ """Extract numeric SNR value from label like 'SNR-5dB' -> -5.0"""
520
+ import re
521
+ match = re.search(r'SNR(-?\d+)dB', label)
522
+ if match:
523
+ return float(match.group(1))
524
+ else:
525
+ return float('inf') # Put non-SNR labels at the end
526
+
527
+ return sorted(labels, key=extract_snr_value)
528
+
529
+
530
+ def run_tsne(x: np.ndarray, labels: np.ndarray, title: str, ax: plt.Axes) -> None:
531
+ scaler = StandardScaler()
532
+ x_scaled = scaler.fit_transform(x)
533
+ # Guard against NaN/Inf or extreme values that can break SVD/TSNE
534
+ x_scaled = np.nan_to_num(x_scaled, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
535
+ x_scaled = np.clip(x_scaled, -1e6, 1e6)
536
+ x_scaled = x_scaled.astype(np.float32, copy=False)
537
+ # Use a safe perplexity relative to sample count (sklearn requirement: < n_samples).
538
+ max_perplexity = max(5, min(30, len(x_scaled) // 10))
539
+ perplexity = min(max_perplexity, len(x_scaled) - 1)
540
+ perplexity = max(perplexity, 5)
541
+
542
+ tsne = TSNE(
543
+ n_components=2,
544
+ perplexity=perplexity,
545
+ random_state=42,
546
+ init="random",
547
+ learning_rate="auto",
548
+ )
549
+ try:
550
+ embedding = tsne.fit_transform(x_scaled)
551
+ except Exception as e:
552
+ # Fallback to PCA if TSNE/SVD fails
553
+ print(f"[WARN] t-SNE failed ({e}); falling back to PCA.")
554
+ pca = PCA(n_components=2, svd_solver="full", random_state=42)
555
+ embedding = pca.fit_transform(x_scaled)
556
+
557
+ class_names = sort_snr_labels(list(np.unique(labels)))
558
+ colors = plt.cm.Set3(np.linspace(0, 1, len(class_names)))
559
+ for color, class_name in zip(colors, class_names):
560
+ mask = labels == class_name
561
+ ax.scatter(embedding[mask, 0], embedding[mask, 1], c=[color], s=18, alpha=0.7, label=class_name)
562
+
563
+ # ax.set_title(title, fontsize=14, fontweight="bold") # Title removed for paper
564
+ ax.set_xlabel("t-SNE Component 1", fontsize=16)
565
+ ax.set_ylabel("t-SNE Component 2", fontsize=16)
566
+ ax.tick_params(labelsize=14) # Increase tick label size
567
+ ax.grid(True, alpha=0.3)
568
+ ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=12)
569
+
570
+
571
+ def compute_metrics(name: str, features: np.ndarray, labels: np.ndarray) -> None:
572
+ if len(np.unique(labels)) < 2:
573
+ print(f"[METRIC] {name}: skipped (only one class present)")
574
+ return
575
+
576
+ scaler = StandardScaler()
577
+ features_scaled = scaler.fit_transform(features)
578
+
579
+ silhouette = silhouette_score(features_scaled, labels)
580
+
581
+ skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
582
+ scores: List[float] = []
583
+ for train_idx, test_idx in skf.split(features_scaled, labels):
584
+ clf = KNeighborsClassifier(n_neighbors=5)
585
+ clf.fit(features_scaled[train_idx], labels[train_idx])
586
+ scores.append(clf.score(features_scaled[test_idx], labels[test_idx]))
587
+
588
+ mean_acc = float(np.mean(scores))
589
+ std_acc = float(np.std(scores))
590
+ print(
591
+ f"[METRIC] {name}: silhouette={silhouette:.3f}, "
592
+ f"5-NN accuracy={mean_acc:.3f} ± {std_acc:.3f}"
593
+ )
594
+
595
+
596
+ # ---------------------------------------------------------------------------
597
+ # Main execution
598
+ # ---------------------------------------------------------------------------
599
+
600
+
601
+ def main() -> None:
602
+ args = parse_args()
603
+
604
+ if args.profile:
605
+ preset = PROFILE_PRESETS.get(args.profile)
606
+ if not preset:
607
+ raise ValueError(f"Unknown profile requested: {args.profile}")
608
+ if args.data_root == DEFAULT_DATA_ROOT:
609
+ args.data_root = preset["data_root"]
610
+ if args.models_root == DEFAULT_MODELS_ROOT:
611
+ args.models_root = preset["models_root"]
612
+
613
+ if args.profile:
614
+ print(f"[INFO] Profile preset active: {args.profile}")
615
+
616
+ random.seed(args.seed)
617
+ np.random.seed(args.seed)
618
+ torch.manual_seed(args.seed)
619
+
620
+ data_root = Path(args.data_root)
621
+ if not data_root.exists():
622
+ raise FileNotFoundError(f"Data root not found: {data_root}")
623
+
624
+ allowed_snrs = parse_snr_list(args.snrs)
625
+
626
+ mobility_filter: set[str] | None = None
627
+ if args.mobility:
628
+ mobility_values: List[str] = []
629
+ for value in args.mobility:
630
+ mobility_values.extend([item.strip() for item in value.split(",") if item.strip()])
631
+ mobility_values = [value for value in mobility_values if value]
632
+ if mobility_values and not (len(mobility_values) == 1 and mobility_values[0].lower() == "all"):
633
+ mobility_filter = {value.lower() for value in mobility_values}
634
+ print(
635
+ "[INFO] Mobility filter active: "
636
+ + ", ".join(sorted(mobility_filter))
637
+ )
638
+
639
+ class_samples = list_snr_samples(
640
+ data_root,
641
+ args.modulation,
642
+ allowed_snrs,
643
+ mobility_filter,
644
+ args.fft_folder,
645
+ args.samples_per_snr,
646
+ random,
647
+ args.sampling_mode,
648
+ args.complex_mode,
649
+ )
650
+ samples, snr_labels, modulation_labels, mobility_labels, _ = sample_balanced_dataset(class_samples)
651
+
652
+ if args.label_field == "snr":
653
+ labels = snr_labels
654
+ label_name = "SNR"
655
+ label_display = "SNR"
656
+ elif args.label_field == "modulation":
657
+ labels = modulation_labels
658
+ label_name = "modulation"
659
+ label_display = "Modulation"
660
+ else: # mobility
661
+ labels = mobility_labels
662
+ label_name = "mobility"
663
+ label_display = "Mobility"
664
+
665
+ unique_labels = np.unique(labels)
666
+ print(
667
+ f"[INFO] Loaded {samples.shape[0]} spectrograms across {len(unique_labels)} {label_name} buckets"
668
+ )
669
+ class_counts = Counter(labels)
670
+ print(f"[INFO] Samples per {label_name}:")
671
+ for name, count in sorted(class_counts.items()):
672
+ print(f" {name}: {count}")
673
+
674
+ if args.label_field != "snr":
675
+ snr_counts = Counter(snr_labels)
676
+ print("[INFO] SNR distribution (sampling classes):")
677
+ for name, count in sorted(snr_counts.items()):
678
+ print(f" {name}: {count}")
679
+ if args.label_field == "mobility":
680
+ modulation_counts = Counter(modulation_labels)
681
+ print("[INFO] Modulation distribution:")
682
+ for name, count in sorted(modulation_counts.items()):
683
+ print(f" {name}: {count}")
684
+
685
+ normalization_mode = args.normalization
686
+ if normalization_mode == "per-sample":
687
+ normalized_samples = normalize_per_sample(samples)
688
+ else:
689
+ normalized_samples = normalize_dataset(samples)
690
+ print(f"[INFO] Normalisation mode: {normalization_mode}")
691
+
692
+ # Flatten spectrograms (after optional normalization) for the raw t-SNE view.
693
+ raw_vectors = normalized_samples.reshape(normalized_samples.shape[0], -1)
694
+
695
+ # Prepare LWM model and embeddings for the right subplot.
696
+ if args.checkpoint:
697
+ checkpoint_path = Path(args.checkpoint)
698
+ if not checkpoint_path.exists():
699
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
700
+ else:
701
+ checkpoint_path = find_latest_checkpoint(Path(args.models_root))
702
+ print(f"[INFO] Using checkpoint: {checkpoint_path}")
703
+
704
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
705
+ print(f"[INFO] Using device: {device}")
706
+ print(f"[INFO] Pooling strategy: {args.pooling}")
707
+ # Determine complex handling strategy for model/patching
708
+ use_interleaved = False
709
+ if args.complex_mode == "interleaved":
710
+ use_interleaved = True
711
+ elif args.complex_mode == "auto":
712
+ # Heuristic: if any sample contains width > 128, assume interleaved (e.g., 128x256)
713
+ sample_shape = tuple(normalized_samples.shape[1:])
714
+ if len(sample_shape) == 2 and sample_shape[1] > 128:
715
+ use_interleaved = True
716
+
717
+ element_length = 32 if use_interleaved else 16
718
+
719
+ model = lwm_model(element_length=element_length, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
720
+ state_dict = torch.load(checkpoint_path, map_location=device)
721
+ if any(k.startswith("module.") for k in state_dict):
722
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
723
+ try:
724
+ model.load_state_dict(state_dict, strict=False)
725
+ except RuntimeError as e:
726
+ msg = str(e)
727
+ # Fallback: checkpoint expects element_length=16 (magnitude), but we constructed 32 (interleaved)
728
+ mismatch16 = "[128, 16]" in msg or "[16]" in msg
729
+ mismatch32 = "[128, 32]" in msg or "[32]" in msg
730
+ if mismatch16 and not mismatch32:
731
+ print("[WARN] Checkpoint expects token dimension 16. Falling back to magnitude embedding.")
732
+ use_interleaved = False
733
+ element_length = 16
734
+ # Recreate model and reload
735
+ model = lwm_model(element_length=element_length, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
736
+ model.load_state_dict(state_dict, strict=False)
737
+ else:
738
+ raise
739
+ model = model.to(device).eval()
740
+
741
+ def collapse_interleaved_to_magnitude(spec: np.ndarray) -> np.ndarray:
742
+ # spec: [H, 2W] with interleaved real/imag along width -> [H, W] magnitude
743
+ h, w2 = spec.shape
744
+ if w2 % 2 != 0:
745
+ return spec # cannot collapse; return as-is
746
+ real = spec[:, 0::2]
747
+ imag = spec[:, 1::2]
748
+ return np.sqrt(np.maximum(real * real + imag * imag, 0.0, dtype=np.float32))
749
+
750
+ # If we fell back to magnitude (use_interleaved False) but inputs are interleaved, collapse for embeddings only
751
+ embed_inputs = normalized_samples
752
+ if not use_interleaved and normalized_samples.shape[2] > 128:
753
+ collapsed = []
754
+ for spec in normalized_samples:
755
+ collapsed.append(collapse_interleaved_to_magnitude(spec))
756
+ embed_inputs = np.stack(collapsed).astype(np.float32, copy=False)
757
+
758
+ embeddings: List[np.ndarray] = []
759
+ for spec in embed_inputs:
760
+ tokens = extract_tokens(spec, device, interleaved=use_interleaved)
761
+ embedding = pool_embeddings(tokens, model, args.pooling)
762
+ embeddings.append(embedding.squeeze(0))
763
+
764
+ embeddings_np = np.vstack(embeddings)
765
+ print(f"[INFO] Generated embeddings with shape {embeddings_np.shape}")
766
+
767
+ if args.report_metrics:
768
+ compute_metrics("Raw spectrogram", raw_vectors, labels)
769
+ pool_label = "LWM mean" if args.pooling == "mean" else "LWM CLS"
770
+ compute_metrics(pool_label, embeddings_np, labels)
771
+ if args.metrics_only:
772
+ return
773
+
774
+ # Plot results (two subplots matching the original figure format).
775
+ fig, axes = plt.subplots(1, 2, figsize=(18, 7))
776
+ raw_title = f"Raw Spectrogram t-SNE (by {label_display})"
777
+ pooling_label = "Mean Pool" if args.pooling == "mean" else "CLS Token"
778
+ embedding_title = f"LWM Embedding t-SNE ({pooling_label}, by {label_display})"
779
+ run_tsne(raw_vectors, labels, raw_title, axes[0])
780
+ run_tsne(embeddings_np, labels, embedding_title, axes[1])
781
+
782
+ fig.tight_layout()
783
+ save_path = Path(args.save_path)
784
+
785
+ communication_tag: str | None = None
786
+ if args.profile:
787
+ communication_tag = args.profile
788
+ else:
789
+ root_name = Path(args.data_root).name
790
+ if root_name:
791
+ communication_tag = root_name
792
+
793
+ def ensure_suffix(stem: str, suffix: str) -> str:
794
+ return stem if stem.endswith(suffix) else f"{stem}_{suffix}"
795
+
796
+ updated_stem = save_path.stem
797
+ if communication_tag:
798
+ updated_stem = ensure_suffix(updated_stem, communication_tag)
799
+ if args.label_field != "snr":
800
+ label_suffix = f"by_{args.label_field}"
801
+ updated_stem = ensure_suffix(updated_stem, label_suffix)
802
+
803
+ if updated_stem != save_path.stem:
804
+ save_path = save_path.with_name(f"{updated_stem}{save_path.suffix}")
805
+ save_path.parent.mkdir(parents=True, exist_ok=True)
806
+ plt.savefig(save_path, dpi=600, bbox_inches="tight")
807
+ print(f"[INFO] Figure saved to {save_path}")
808
+
809
+ # Also save PDF version for paper (vector format, no resolution limit)
810
+ pdf_path = save_path.with_suffix('.pdf')
811
+ plt.savefig(pdf_path, format='pdf', bbox_inches="tight")
812
+ print(f"[INFO] PDF version saved to {pdf_path}")
813
+
814
+
815
+ if __name__ == "__main__":
816
+ main()