“Namhyun-Kim” committed on
Commit
aebafe2
·
1 Parent(s): 24c4d80

Update demo with MoE centroid evaluation

Browse files
.gitattributes CHANGED
@@ -1,24 +1,2 @@
1
- # Git LFS configuration for large model files
2
- *.pth filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.pt filter=lfs diff=lfs merge=lfs -text
5
- *.safetensors filter=lfs diff=lfs merge=lfs -text
6
- *.h5 filter=lfs diff=lfs merge=lfs -text
7
- *.onnx filter=lfs diff=lfs merge=lfs -text
8
- *.pkl filter=lfs diff=lfs merge=lfs -text
9
- *.pickle filter=lfs diff=lfs merge=lfs -text
10
- *.pb filter=lfs diff=lfs merge=lfs -text
11
- *.msgpack filter=lfs diff=lfs merge=lfs -text
12
- *.ckpt filter=lfs diff=lfs merge=lfs -text
13
 
14
- # Large data files
15
- *.zip filter=lfs diff=lfs merge=lfs -text
16
- *.tar filter=lfs diff=lfs merge=lfs -text
17
- *.tar.gz filter=lfs diff=lfs merge=lfs -text
18
- *.npy filter=lfs diff=lfs merge=lfs -text
19
- *.npz filter=lfs diff=lfs merge=lfs -text
20
-
21
- # Large image files (if needed)
22
- *.png filter=lfs diff=lfs merge=lfs -text
23
- *.jpg filter=lfs diff=lfs merge=lfs -text
24
- *.jpeg filter=lfs diff=lfs merge=lfs -text
 
 
 
 
1
  *.pt filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
README.md DELETED
@@ -1,10 +0,0 @@
1
- ---
2
- title: LWM Spectro Demo
3
- emoji: 🔬
4
- colorFrom: blue
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 5.5.0
8
- app_file: app.py
9
- pinned: false
10
- ---
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,412 +1,481 @@
1
 
2
- import inspect
3
- import random
4
  import sys
5
  from pathlib import Path
6
-
7
- import huggingface_hub as hf_hub
8
-
9
- # Gradio imports HfFolder; add shim before importing gradio.
10
- if not hasattr(hf_hub, "HfFolder"):
11
- class _HfFolderShim:
12
- @staticmethod
13
- def get_token():
14
- return None
15
-
16
- @staticmethod
17
- def save_token(token):
18
- return None
19
-
20
- hf_hub.HfFolder = _HfFolderShim # type: ignore[attr-defined]
21
 
22
  import gradio as gr
23
- import torch
24
  import numpy as np
25
  import pandas as pd
26
- from sklearn.manifold import TSNE
 
 
27
  from sklearn.decomposition import PCA
28
- from sklearn.preprocessing import StandardScaler
29
- from sklearn.metrics import confusion_matrix, f1_score
30
- import matplotlib.pyplot as plt
 
 
 
 
 
 
31
 
32
- # Repo root for local imports
33
- REPO_ROOT = Path(__file__).resolve().parent
34
  if str(REPO_ROOT) not in sys.path:
35
  sys.path.append(str(REPO_ROOT))
36
 
37
- from mixture.train_embedding_router import MoEPredictor # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # ------------------------------------------------------------------------------
40
- # Data loading (t-SNE + evaluation)
41
- # ------------------------------------------------------------------------------
42
 
43
- # Load data
44
- def load_data():
45
- print("Loading data...")
46
- data = torch.load("demo_data.pt")
 
 
 
 
 
 
 
47
  records = []
48
- for i, d in enumerate(data):
49
- records.append({
50
- "index": i,
51
- "tech": d['tech'],
52
- "snr": d['snr'],
53
- "mod": d['mod'],
54
- "mob": d['mob'],
55
- "embedding": d['embedding'].numpy(),
56
- "spectrogram": d['data'].numpy().flatten()
57
- })
58
- df = pd.DataFrame(records)
59
- print(f"Loaded {len(df)} samples.")
60
- return df, data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- df, raw_samples = load_data()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- # Get unique values for filters
65
- tech_choices = sorted(list(df['tech'].unique()))
66
- snr_choices = sorted(list(df['snr'].unique()))
67
- mod_choices = sorted(list(df['mod'].unique()))
68
- mob_choices = sorted(list(df['mob'].unique()))
69
 
70
  def plot_tsne(tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter):
71
- # Filter data
72
- filtered_df = df.copy()
73
- if not tech_filter:
74
- return None, "Select at least one technology."
75
-
76
- if tech_filter and len(tech_filter) > 0:
77
- filtered_df = filtered_df[filtered_df['tech'].isin(tech_filter)]
78
-
79
- if snr_filter and len(snr_filter) > 0:
80
- filtered_df = filtered_df[filtered_df['snr'].isin(snr_filter)]
81
-
82
- if mod_filter and len(mod_filter) > 0:
83
- filtered_df = filtered_df[filtered_df['mod'].isin(mod_filter)]
84
-
85
- if mob_filter and len(mob_filter) > 0:
86
- filtered_df = filtered_df[filtered_df['mob'].isin(mob_filter)]
87
-
88
  if len(filtered_df) < 5:
89
  return None, f"Not enough data points ({len(filtered_df)}). Need at least 5."
90
-
91
- # Select features
92
  if representation == "LWM Embedding":
93
- features = np.stack(filtered_df['embedding'].values)
94
  else:
95
- features = np.stack(filtered_df['spectrogram'].values)
96
- # PCA for raw spectrograms to speed up t-SNE
97
  if features.shape[1] > 50:
98
  pca = PCA(n_components=50, random_state=42)
99
  features = pca.fit_transform(features)
100
-
101
- # Clean up NaNs/Infs that can blank out t-SNE plots
102
- features = np.nan_to_num(features, copy=False)
103
- # Match task1/plot_tsne.py preprocessing: standardize, clamp, float32
104
- scaler = StandardScaler()
105
- features = scaler.fit_transform(features)
106
- features = np.nan_to_num(features, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
107
- features = np.clip(features, -1e6, 1e6).astype(np.float32, copy=False)
108
-
109
- # Run t-SNE
110
- # Adjust perplexity if N is small; cap similarly to task1/plot_tsne.py
111
- max_perplexity = max(5, min(30, len(filtered_df) // 10 if len(filtered_df) > 10 else len(filtered_df) - 1))
112
- eff_perplexity = min(perplexity, len(filtered_df) - 1, max_perplexity)
113
- eff_perplexity = max(eff_perplexity, 5)
114
-
115
- tsne_kwargs = {"n_components": 2, "perplexity": eff_perplexity, "random_state": 42}
116
- sig = inspect.signature(TSNE.__init__)
117
- if "init" in sig.parameters:
118
- tsne_kwargs["init"] = "random"
119
- if "learning_rate" in sig.parameters:
120
- tsne_kwargs["learning_rate"] = "auto"
121
- if "n_iter" in sig.parameters:
122
- tsne_kwargs["n_iter"] = n_iter
123
- elif "max_iter" in sig.parameters:
124
- tsne_kwargs["max_iter"] = n_iter
125
 
126
- try:
127
- tsne = TSNE(**tsne_kwargs)
128
- projections = tsne.fit_transform(features)
129
- if not np.isfinite(projections).all():
130
- raise ValueError("t-SNE produced NaN/Inf projections")
131
- status_msg = f"t-SNE ok ({len(filtered_df)} samples)."
132
- except Exception as e:
133
- # Fallback to 2D PCA so we always show something
134
- pca_fallback = PCA(n_components=2, random_state=42, svd_solver="full")
135
- projections = pca_fallback.fit_transform(features)
136
- status_msg = f"t-SNE failed ({e}); showing 2D PCA instead. Samples: {len(filtered_df)}"
137
-
138
- filtered_df['x'] = projections[:, 0]
139
- filtered_df['y'] = projections[:, 1]
140
- # If t-SNE collapses to a line/point, add tiny jitter so points are visible.
141
- x_span = filtered_df['x'].max() - filtered_df['x'].min()
142
- y_span = filtered_df['y'].max() - filtered_df['y'].min()
143
- if x_span < 1e-6:
144
- filtered_df['x'] += np.random.normal(scale=1e-3, size=len(filtered_df))
145
- if y_span < 1e-6:
146
- filtered_df['y'] += np.random.normal(scale=1e-3, size=len(filtered_df))
147
- x_min, x_max = filtered_df['x'].min(), filtered_df['x'].max()
148
- y_min, y_max = filtered_df['y'].min(), filtered_df['y'].max()
149
- x_pad = max(1e-3, (x_max - x_min) * 0.05)
150
- y_pad = max(1e-3, (y_max - y_min) * 0.05)
151
-
152
- # Plot using matplotlib for maximum reliability in Spaces
153
- fig, ax = plt.subplots(figsize=(7, 6))
154
- colors = plt.cm.tab20(np.linspace(0, 1, len(filtered_df[color_by].unique())))
155
- for c, cls in zip(colors, sorted(filtered_df[color_by].unique())):
156
- mask = filtered_df[color_by] == cls
157
- ax.scatter(
158
- filtered_df.loc[mask, 'x'],
159
- filtered_df.loc[mask, 'y'],
160
- s=18,
161
- alpha=0.8,
162
- label=str(cls),
163
- color=c,
164
- edgecolors='none',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
- ax.set_xlim(x_min - x_pad, x_max + x_pad)
167
- ax.set_ylim(y_min - y_pad, y_max + y_pad)
168
- ax.set_xlabel("t-SNE x")
169
- ax.set_ylabel("t-SNE y")
170
- ax.set_title(f"t-SNE of {representation} ({len(filtered_df)} samples)")
171
- ax.grid(True, alpha=0.3)
172
- ax.legend(title=color_by, fontsize=9, title_fontsize=10, loc='best')
173
- fig.tight_layout()
174
-
175
- coord_info = f"x[{x_min:.3f},{x_max:.3f}] y[{y_min:.3f},{y_max:.3f}]"
176
- trace_info = f"traces: {len(filtered_df[color_by].unique())}"
177
- return fig, f"{status_msg} | filtered samples: {len(filtered_df)} | {coord_info} | {trace_info}"
178
-
179
- # ------------------------------------------------------------------------------
180
- # Evaluation utilities (confusion matrix, F1) using the MoE checkpoint
181
- # ------------------------------------------------------------------------------
182
-
183
- _predictor: MoEPredictor | None = None
184
-
185
-
186
- def load_predictor() -> MoEPredictor:
187
- global _predictor
188
- if _predictor is not None:
189
- return _predictor
190
-
191
- # Prefer local checkpoint if present; otherwise pull from Hub
192
- candidates = [
193
- REPO_ROOT / "mixture" / "runs" / "embedding_router" / "moe_checkpoint.pth",
194
- REPO_ROOT / "moe_checkpoint.pth",
195
- ]
196
- ckpt_path = None
197
- for cand in candidates:
198
- if cand.exists():
199
- ckpt_path = cand
200
- break
201
- if ckpt_path is None:
202
- ckpt_path = Path(
203
- hf_hub.hf_hub_download(repo_id="wi-lab/lwm-spectro", filename="moe_checkpoint.pth")
204
- )
205
-
206
- # Ensure expert checkpoints are resolvable in the Space (paths inside ckpt are absolute)
207
- def ensure_expert(name: str, comm: str) -> Path:
208
- """Return a local path to the expert checkpoint, downloading if needed."""
209
- fname = Path(name).name
210
- comm_tag = comm.replace("/", "_")
211
- local_candidates = [
212
- REPO_ROOT / "experts" / fname,
213
- REPO_ROOT / fname,
214
- REPO_ROOT / "experts" / f"{comm_tag}_expert.pth",
215
- REPO_ROOT / f"{comm_tag}_expert.pth",
216
- ]
217
- for cand in local_candidates:
218
- if cand.exists():
219
- return cand
220
- # Download from model repo with multiple filename guesses
221
- download_candidates = [
222
- f"experts/{fname}",
223
- f"experts/{comm_tag}_expert.pth",
224
- fname,
225
- ]
226
- last_err = None
227
- for rel in download_candidates:
228
- try:
229
- downloaded = hf_hub.hf_hub_download(
230
- repo_id="wi-lab/lwm-spectro",
231
- filename=rel,
232
- )
233
- return Path(downloaded)
234
- except Exception as exc: # pragma: no cover - network/permissions issues
235
- last_err = exc
236
- continue
237
- raise RuntimeError(f"Could not resolve expert checkpoint for {comm} ({fname}): {last_err}")
238
-
239
- # Rewrite expert paths into a temp checkpoint so MoEPredictor loads cleanly
240
- import torch # local import to keep top import list compact
241
-
242
- raw_ckpt = torch.load(ckpt_path, map_location="cpu")
243
- experts = raw_ckpt.get("experts", [])
244
- if experts:
245
- patched = False
246
- for expert in experts:
247
- ckpt_field = expert.get("checkpoint")
248
- if not ckpt_field:
249
- continue
250
- fname = Path(ckpt_field).name
251
- comm = expert.get("comm", "unknown")
252
- local_path = ensure_expert(fname, comm)
253
- if str(local_path) != ckpt_field:
254
- expert["checkpoint"] = str(local_path)
255
- patched = True
256
- if patched:
257
- tmp_path = Path("/tmp/moe_checkpoint_patched.pth")
258
- torch.save(raw_ckpt, tmp_path)
259
- ckpt_path = tmp_path
260
-
261
- _predictor = MoEPredictor.from_checkpoint(ckpt_path)
262
- return _predictor
263
-
264
-
265
- def _to_tensor(spec) -> torch.Tensor:
266
- t = spec
267
- if not isinstance(t, torch.Tensor):
268
- t = torch.as_tensor(t)
269
- if t.dim() == 2:
270
- t = t.unsqueeze(0)
271
- return t
272
-
273
-
274
- def _normalize_label(val):
275
- """Convert labels to a simple string for metrics."""
276
- if isinstance(val, (list, tuple)):
277
- return " | ".join(str(v) for v in val)
278
- return str(val)
279
-
280
-
281
- def compute_eval(task: str):
282
- """Compute confusion matrix + macro F1 with balanced sampling per class."""
283
- predictor = load_predictor()
284
- y_true, y_pred = [], []
285
-
286
- # Balanced sampling per class
287
- rng = random.Random(42)
288
- per_class_target = 100
289
-
290
- def class_key(sample):
291
- if task == "comm":
292
- return _normalize_label(sample["tech"])
293
- return _normalize_label((sample["snr"], sample["mob"]))
294
-
295
- buckets = {}
296
- for s in raw_samples:
297
- key = class_key(s)
298
- buckets.setdefault(key, []).append(s)
299
-
300
- selected = []
301
- for key, items in buckets.items():
302
- rng.shuffle(items)
303
- take = min(per_class_target, len(items))
304
- selected.extend(items[:take])
305
-
306
- rng.shuffle(selected)
307
-
308
- for sample in selected:
309
- spec = _to_tensor(sample["data"])
310
- try:
311
- res = predictor.predict(spec, return_routing=True)
312
- except Exception as exc:
313
- print(f"[WARN] predict failed: {exc}")
314
- continue
315
-
316
- if task == "comm":
317
- routing = res.get("routing") or []
318
- pred = _normalize_label(routing[0]["comm"]) if routing else "Unknown"
319
- true = _normalize_label(sample["tech"])
320
- else: # snr_mobility
321
- pred_raw = res.get("label", res["predicted_class"])
322
- pred = _normalize_label(pred_raw)
323
- true = _normalize_label((sample["snr"], sample["mob"]))
324
- y_true.append(true)
325
- y_pred.append(pred)
326
-
327
- if not y_true or not y_pred:
328
- raise RuntimeError("No samples were evaluated; check data or predictions.")
329
-
330
- labels = sorted(list({*y_true, *y_pred}))
331
- cm = confusion_matrix(y_true, y_pred, labels=labels)
332
- f1 = f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0)
333
- acc = (np.array(y_true) == np.array(y_pred)).mean()
334
- return cm, labels, f1, acc, len(y_true)
335
-
336
-
337
- def plot_confusion(cm: np.ndarray, labels):
338
- fig, ax = plt.subplots(figsize=(6, 5))
339
- im = ax.imshow(cm, cmap="Blues")
340
- ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
341
- ax.set_xticks(np.arange(len(labels)), labels=labels, rotation=45, ha="right")
342
- ax.set_yticks(np.arange(len(labels)), labels=labels)
343
- ax.set_xlabel("Predicted")
344
- ax.set_ylabel("True")
345
- for i in range(cm.shape[0]):
346
- for j in range(cm.shape[1]):
347
- ax.text(j, i, int(cm[i, j]), ha="center", va="center", color="black")
348
- fig.tight_layout()
349
  return fig
350
 
351
 
352
- def run_eval(task):
353
- cm, labels, f1, acc, n = compute_eval(task)
354
- fig = plot_confusion(cm, labels)
355
- summary = f"Task: {task} | Samples: {n} | Accuracy: {acc:.4f} | Macro F1: {f1:.4f}"
356
- return fig, summary
 
 
 
 
 
 
357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
 
359
- # ------------------------------------------------------------------------------
360
- # UI
361
- # ------------------------------------------------------------------------------
362
  with gr.Blocks(title="LWM-Spectro Demo") as demo:
363
  gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
364
- gr.Markdown("Compare embeddings vs raw for t-SNE, and view quick metrics from the latest MoE checkpoint.")
365
-
366
- with gr.Tab("t-SNE"):
367
- with gr.Row():
368
- with gr.Column(scale=1, min_width=300):
369
- gr.Markdown("### Filters")
370
- tech_filter = gr.CheckboxGroup(choices=tech_choices, value=tech_choices[:1], label="Technology (default: single tech)")
371
- snr_filter = gr.Dropdown(choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)")
372
- mod_filter = gr.Dropdown(choices=mod_choices, value=None, multiselect=True, label="Modulation (Empty = All)")
373
- mob_filter = gr.Dropdown(choices=mob_choices, value=None, multiselect=True, label="Mobility (Empty = All)")
374
-
375
- gr.Markdown("### Visualization Settings")
376
- representation = gr.Radio(choices=["LWM Embedding", "Raw Spectrogram"], value="LWM Embedding", label="Representation")
377
- color_by = gr.Dropdown(choices=["tech", "snr", "mod", "mob"], value="snr", label="Color By")
378
-
379
- with gr.Accordion("Advanced t-SNE Settings", open=False):
380
- perplexity = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="Perplexity")
381
- n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations")
382
-
383
- btn = gr.Button("Update Plot", variant="primary")
384
- status = gr.Textbox(label="Status", interactive=False)
385
-
386
- with gr.Column(scale=3):
387
- plot = gr.Plot(label="t-SNE Visualization")
388
-
389
- btn.click(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])
390
-
391
- # Initial load
392
- demo.load(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])
393
-
394
- with gr.Tab("Evaluation (MoE)"):
395
- gr.Markdown("Uses the latest MoE checkpoint to score the bundled demo set.\n\n- **comm**: predicts communication type (LTE/WiFi/5G) via router gating.\n- **snr_mobility**: predicts the SNR/Mobility class via the classifier head.")
396
- task_choice = gr.Radio(choices=["comm", "snr_mobility"], value="snr_mobility", label="Task")
397
- eval_btn = gr.Button("Run Evaluation", variant="primary")
398
- cm_plot = gr.Plot(label="Confusion Matrix")
399
- eval_summary = gr.Textbox(label="Metrics", interactive=False)
400
-
401
- def _safe_run(task):
402
- try:
403
- return run_eval(task)
404
- except Exception as exc:
405
- return None, f"Error during evaluation: {exc}"
406
-
407
- eval_btn.click(_safe_run, inputs=[task_choice], outputs=[cm_plot, eval_summary])
408
- # Run once on load for convenience
409
- demo.load(_safe_run, inputs=[task_choice], outputs=[cm_plot, eval_summary])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
  if __name__ == "__main__":
412
  demo.launch()
 
1
 
2
+ import json
 
3
  import sys
4
  from pathlib import Path
5
+ from typing import Dict, List, Optional, Sequence, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  import gradio as gr
 
8
  import numpy as np
9
  import pandas as pd
10
+ import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ import torch
13
  from sklearn.decomposition import PCA
14
+ from sklearn.manifold import TSNE
15
+ from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
16
+
17
+ REPO_ROOT = Path(__file__).resolve().parents[1]
18
+ APP_DIR = Path(__file__).resolve().parent
19
+ DEMO_DATA_PATH = APP_DIR / "demo_data.pt"
20
+ MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt"
21
+ MOE_CHECKPOINT = REPO_ROOT / "mixture" / "runs" / "embedding_router" / "moe_checkpoint.pth"
22
+ SNR_MOB_MAPPING_PATH = REPO_ROOT / "mixture" / "runs" / "embedding_router" / "snr_mobility_mapping.json"
23
 
 
 
24
  if str(REPO_ROOT) not in sys.path:
25
  sys.path.append(str(REPO_ROOT))
26
 
27
+ from mixture.train_embedding_router import ( # type: ignore
28
+ MoEPredictor,
29
+ compute_selected_expert_embeddings,
30
+ normalize_per_sample_tensor,
31
+ stack_expert_embeddings,
32
+ )
33
+
34
+
35
+ def load_joint_mapping() -> Optional[Dict[str, object]]:
36
+ if not SNR_MOB_MAPPING_PATH.exists():
37
+ print(f"[WARN] Mapping file not found at {SNR_MOB_MAPPING_PATH}")
38
+ return None
39
+ raw = json.loads(SNR_MOB_MAPPING_PATH.read_text())
40
+ ordered_pairs: List[Tuple[str, str]] = []
41
+ for key in sorted(raw.keys(), key=lambda k: int(k)):
42
+ snr, mob = raw[key]
43
+ ordered_pairs.append((snr, mob))
44
+ label_names = [f"{snr} | {mob}" for snr, mob in ordered_pairs]
45
+ pair_to_name = {pair: name for pair, name in zip(ordered_pairs, label_names)}
46
+ name_to_id = {name: idx for idx, name in enumerate(label_names)}
47
+ pair_to_id = {pair: idx for idx, pair in enumerate(ordered_pairs)}
48
+ return {
49
+ "pairs": ordered_pairs,
50
+ "label_names": label_names,
51
+ "pair_to_name": pair_to_name,
52
+ "name_to_id": name_to_id,
53
+ "pair_to_id": pair_to_id,
54
+ }
55
+
56
+
57
+ def compute_moe_embeddings(
58
+ samples: Sequence[Dict[str, object]],
59
+ predictor: MoEPredictor,
60
+ batch_size: int = 64,
61
+ ) -> torch.Tensor:
62
+ router = predictor.router
63
+ experts = predictor.experts
64
+ device = predictor.device
65
+ embeddings: List[torch.Tensor] = []
66
+
67
+ with torch.no_grad():
68
+ for start in range(0, len(samples), batch_size):
69
+ batch = samples[start : start + batch_size]
70
+ specs = torch.cat([sample["data"] for sample in batch], dim=0).to(device)
71
+ specs_norm = normalize_per_sample_tensor(specs)
72
+
73
+ if router is not None:
74
+ router_logits = router(specs_norm)
75
+ probs = torch.softmax(router_logits, dim=1)
76
+ topk_vals, topk_idx = probs.topk(k=predictor.topk, dim=1)
77
+ weights = topk_vals / torch.clamp(topk_vals.sum(dim=1, keepdim=True), min=1e-6)
78
+ selected_embeddings = compute_selected_expert_embeddings(
79
+ experts,
80
+ specs_norm,
81
+ topk_idx,
82
+ allow_grad=False,
83
+ )
84
+ weighted = (weights.unsqueeze(-1) * selected_embeddings).sum(dim=1)
85
+ else:
86
+ stacked = stack_expert_embeddings(experts, specs_norm)
87
+ weighted = stacked.mean(dim=1)
88
+
89
+ embeddings.append(weighted.cpu())
90
+
91
+ return torch.cat(embeddings, dim=0)
92
+
93
+
94
+ def ensure_moe_embeddings(samples: List[Dict[str, object]]) -> Tuple[List[Dict[str, object]], bool]:
95
+ if MOE_DATA_PATH.exists():
96
+ cached = torch.load(MOE_DATA_PATH)
97
+ if len(cached) == len(samples):
98
+ print(f"[INFO] Loaded cached MoE embeddings from {MOE_DATA_PATH}")
99
+ return cached, True
100
+ print("[WARN] Cached MoE embeddings length mismatch. Recomputing...")
101
+
102
+ if not MOE_CHECKPOINT.exists():
103
+ print(f"[WARN] MoE checkpoint not found at {MOE_CHECKPOINT}. Skipping MoE embeddings.")
104
+ return samples, False
105
+
106
+ print("[INFO] Computing MoE embeddings using router checkpoint...")
107
+ predictor = MoEPredictor.from_checkpoint(MOE_CHECKPOINT)
108
+ moe_embeddings = compute_moe_embeddings(samples, predictor)
109
+ for sample, emb in zip(samples, moe_embeddings):
110
+ sample["moe_embedding"] = emb.detach().cpu()
111
+
112
+ torch.save(samples, MOE_DATA_PATH)
113
+ print(f"[INFO] Saved MoE-augmented dataset to {MOE_DATA_PATH}")
114
+ return samples, True
115
 
 
 
 
116
 
117
+ def load_data(mapping: Optional[Dict[str, object]]):
118
+ if not DEMO_DATA_PATH.exists():
119
+ raise FileNotFoundError(f"Dataset not found at {DEMO_DATA_PATH}")
120
+
121
+ print(f"[INFO] Loading base dataset from {DEMO_DATA_PATH}")
122
+ data: List[Dict[str, object]] = torch.load(DEMO_DATA_PATH)
123
+ data, has_moe = ensure_moe_embeddings(data)
124
+
125
+ pair_to_name = mapping["pair_to_name"] if mapping else {}
126
+ pair_to_id = mapping["pair_to_id"] if mapping else {}
127
+
128
  records = []
129
+ for i, sample in enumerate(data):
130
+ embedding = sample["embedding"]
131
+ if isinstance(embedding, torch.Tensor):
132
+ base_embedding = embedding.detach().cpu().numpy()
133
+ else:
134
+ base_embedding = np.asarray(embedding)
135
+
136
+ spectrogram = sample["data"]
137
+ if isinstance(spectrogram, torch.Tensor):
138
+ flat_spec = spectrogram.numpy().flatten()
139
+ else:
140
+ flat_spec = np.asarray(spectrogram).flatten()
141
+
142
+ moe_embedding = sample.get("moe_embedding")
143
+ if isinstance(moe_embedding, torch.Tensor):
144
+ moe_embedding = moe_embedding.numpy()
145
+ elif moe_embedding is not None:
146
+ moe_embedding = np.asarray(moe_embedding)
147
+
148
+ pair = (sample["snr"], sample["mob"])
149
+ joint_label = pair_to_name.get(pair)
150
+ joint_label_id = pair_to_id.get(pair)
151
+
152
+ records.append(
153
+ {
154
+ "index": i,
155
+ "tech": sample["tech"],
156
+ "snr": sample["snr"],
157
+ "mod": sample["mod"],
158
+ "mob": sample["mob"],
159
+ "embedding": base_embedding,
160
+ "moe_embedding": moe_embedding,
161
+ "spectrogram": flat_spec,
162
+ "joint_label": joint_label,
163
+ "joint_label_id": joint_label_id,
164
+ }
165
+ )
166
 
167
+ df = pd.DataFrame(records)
168
+ print(f"[INFO] Loaded {len(df)} samples.")
169
+ return df, has_moe
170
+
171
+
172
+ def apply_filters(
173
+ dataframe: pd.DataFrame,
174
+ tech_filter,
175
+ snr_filter,
176
+ mod_filter,
177
+ mob_filter,
178
+ ) -> pd.DataFrame:
179
+ filtered = dataframe.copy()
180
+ if tech_filter:
181
+ filtered = filtered[filtered["tech"].isin(tech_filter)]
182
+ if snr_filter:
183
+ filtered = filtered[filtered["snr"].isin(snr_filter)]
184
+ if mod_filter:
185
+ filtered = filtered[filtered["mod"].isin(mod_filter)]
186
+ if mob_filter:
187
+ filtered = filtered[filtered["mob"].isin(mob_filter)]
188
+ return filtered
189
 
 
 
 
 
 
190
 
191
  def plot_tsne(tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter):
192
+ filtered_df = apply_filters(df, tech_filter, snr_filter, mod_filter, mob_filter)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  if len(filtered_df) < 5:
194
  return None, f"Not enough data points ({len(filtered_df)}). Need at least 5."
195
+
 
196
  if representation == "LWM Embedding":
197
+ features = np.stack(filtered_df["embedding"].values)
198
  else:
199
+ features = np.stack(filtered_df["spectrogram"].values)
 
200
  if features.shape[1] > 50:
201
  pca = PCA(n_components=50, random_state=42)
202
  features = pca.fit_transform(features)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
+ eff_perplexity = min(perplexity, len(filtered_df) - 1)
205
+ tsne = TSNE(
206
+ n_components=2,
207
+ perplexity=eff_perplexity,
208
+ n_iter=n_iter,
209
+ random_state=42,
210
+ init="pca",
211
+ learning_rate="auto",
212
+ )
213
+ projections = tsne.fit_transform(features)
214
+ filtered_df = filtered_df.copy()
215
+ filtered_df["x"] = projections[:, 0]
216
+ filtered_df["y"] = projections[:, 1]
217
+
218
+ fig = px.scatter(
219
+ filtered_df,
220
+ x="x",
221
+ y="y",
222
+ color=color_by,
223
+ hover_data=["tech", "snr", "mod", "mob"],
224
+ title=f"t-SNE of {representation} ({len(filtered_df)} samples)",
225
+ template="plotly_white",
226
+ )
227
+ fig.update_layout(legend_title_text=color_by.capitalize())
228
+ return fig, f"Displayed {len(filtered_df)} samples."
229
+
230
+
231
+ def stratified_split(filtered_df: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
232
+ rng = np.random.default_rng(int(seed))
233
+ train_indices: List[int] = []
234
+ test_indices: List[int] = []
235
+
236
+ for label_id, group in filtered_df.groupby("joint_label_id"):
237
+ indices = group.index.to_numpy()
238
+ if indices.size < 2:
239
+ raise ValueError(f"Class '{CLASS_LABELS[int(label_id)]}' needs at least 2 samples for evaluation.")
240
+
241
+ rng.shuffle(indices)
242
+ split = int(round(indices.size * train_ratio))
243
+ split = max(1, min(indices.size - 1, split))
244
+ train_indices.extend(indices[:split])
245
+ test_indices.extend(indices[split:])
246
+
247
+ return np.array(train_indices), np.array(test_indices)
248
+
249
+
250
+ def compute_centroid_metrics(filtered_df: pd.DataFrame, train_idx: np.ndarray, test_idx: np.ndarray) -> Dict[str, object]:
251
+ train_subset = filtered_df.loc[train_idx]
252
+ test_subset = filtered_df.loc[test_idx]
253
+
254
+ train_embeddings = np.stack(train_subset["moe_embedding"].values)
255
+ test_embeddings = np.stack(test_subset["moe_embedding"].values)
256
+ train_labels = train_subset["joint_label_id"].to_numpy(dtype=int)
257
+ test_labels = test_subset["joint_label_id"].to_numpy(dtype=int)
258
+
259
+ unique_labels = np.unique(train_labels)
260
+ centroids = []
261
+ centroid_ids: List[int] = []
262
+ for label_id in unique_labels:
263
+ mask = train_labels == label_id
264
+ centroids.append(train_embeddings[mask].mean(axis=0))
265
+ centroid_ids.append(int(label_id))
266
+
267
+ centroids = np.stack(centroids)
268
+ centroid_ids = np.array(centroid_ids, dtype=int)
269
+
270
+ dists = ((test_embeddings[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
271
+ preds = centroid_ids[np.argmin(dists, axis=1)]
272
+
273
+ accuracy = accuracy_score(test_labels, preds)
274
+ macro_f1 = f1_score(test_labels, preds, average="macro", labels=centroid_ids, zero_division=0)
275
+
276
+ active_ids = sorted(np.unique(np.concatenate([test_labels, preds])))
277
+ label_names = [CLASS_LABELS[i] for i in active_ids]
278
+ cm = confusion_matrix(test_labels, preds, labels=active_ids)
279
+
280
+ return {
281
+ "accuracy": accuracy,
282
+ "macro_f1": macro_f1,
283
+ "confusion": cm,
284
+ "label_names": label_names,
285
+ "train_size": len(train_idx),
286
+ "test_size": len(test_idx),
287
+ }
288
+
289
+
290
+ def plot_confusion_heatmap(confusion: np.ndarray, label_names: List[str]) -> go.Figure:
291
+ fig = go.Figure(
292
+ data=go.Heatmap(
293
+ z=confusion,
294
+ x=label_names,
295
+ y=label_names,
296
+ colorscale="Viridis",
297
+ hovertemplate="Predicted %{x}<br>True %{y}<br>Count %{z}<extra></extra>",
298
  )
299
+ )
300
+ fig.update_layout(
301
+ title="Prototype Classifier Confusion Matrix",
302
+ xaxis_title="Predicted",
303
+ yaxis_title="True",
304
+ xaxis=dict(tickangle=45),
305
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  return fig
307
 
308
 
309
+ def run_joint_evaluation(train_pct, seed, tech_filter, snr_filter, mod_filter, mob_filter):
310
+ if joint_eval_df.empty:
311
+ fig = go.Figure()
312
+ fig.update_layout(title="MoE embeddings unavailable", xaxis=dict(visible=False), yaxis=dict(visible=False))
313
+ return fig, "MoE embeddings are not available for evaluation."
314
+
315
+ filtered = apply_filters(joint_eval_df, tech_filter, snr_filter, mod_filter, mob_filter)
316
+ if filtered.empty:
317
+ fig = go.Figure()
318
+ fig.update_layout(title="No samples after filtering", xaxis=dict(visible=False), yaxis=dict(visible=False))
319
+ return fig, "No samples match the selected filters."
320
 
321
+ if filtered["joint_label_id"].nunique() < 2:
322
+ fig = go.Figure()
323
+ fig.update_layout(title="Need at least two classes", xaxis=dict(visible=False), yaxis=dict(visible=False))
324
+ return fig, "Need at least two joint SNR/Doppler classes to evaluate."
325
+
326
+ try:
327
+ train_idx, test_idx = stratified_split(filtered, train_pct / 100.0, seed)
328
+ except ValueError as exc:
329
+ fig = go.Figure()
330
+ fig.update_layout(title="Unable to split dataset", xaxis=dict(visible=False), yaxis=dict(visible=False))
331
+ return fig, str(exc)
332
+
333
+ metrics = compute_centroid_metrics(filtered, train_idx, test_idx)
334
+ fig = plot_confusion_heatmap(metrics["confusion"], metrics["label_names"])
335
+ status = (
336
+ f"Train samples: {metrics['train_size']}\n"
337
+ f"Test samples: {metrics['test_size']}\n"
338
+ f"Accuracy: {metrics['accuracy'] * 100:.2f}%\n"
339
+ f"Macro F1: {metrics['macro_f1']:.3f}"
340
+ )
341
+ return fig, status
342
+
343
+
344
+ mapping_info = load_joint_mapping()
345
+ df, has_moe_embeddings = load_data(mapping_info)
346
+ CLASS_LABELS: List[str] = mapping_info["label_names"] if mapping_info else []
347
+
348
+ joint_eval_df = df.copy()
349
+ joint_eval_df = joint_eval_df[joint_eval_df["joint_label_id"].notna()]
350
+ joint_eval_df = joint_eval_df[joint_eval_df["moe_embedding"].notna()]
351
+
352
+ tech_choices = sorted(df["tech"].unique())
353
+ snr_choices = sorted(df["snr"].unique())
354
+ mod_choices = sorted(df["mod"].unique())
355
+ mob_choices = sorted(df["mob"].unique())
356
+
357
+ evaluation_disabled = joint_eval_df.empty
358
 
 
 
 
359
  with gr.Blocks(title="LWM-Spectro Demo") as demo:
360
  gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
361
+ gr.Markdown(
362
+ """
363
+ Compare **LWM embeddings** vs **Raw Spectrograms** for visualization, then evaluate **MoE embeddings**
364
+ with a lightweight prototype classifier for joint SNR/Doppler recognition.
365
+ """
366
+ )
367
+
368
+ with gr.Tabs():
369
+ with gr.Tab("Visualization"):
370
+ with gr.Row():
371
+ with gr.Column(scale=1, min_width=300):
372
+ gr.Markdown("### Filters")
373
+ tech_filter = gr.CheckboxGroup(choices=tech_choices, value=tech_choices, label="Technology")
374
+ snr_filter = gr.Dropdown(
375
+ choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)"
376
+ )
377
+ mod_filter = gr.Dropdown(
378
+ choices=mod_choices, value=None, multiselect=True, label="Modulation (Empty = All)"
379
+ )
380
+ mob_filter = gr.Dropdown(
381
+ choices=mob_choices, value=None, multiselect=True, label="Mobility (Empty = All)"
382
+ )
383
+
384
+ gr.Markdown("### Visualization Settings")
385
+ representation = gr.Radio(
386
+ choices=["LWM Embedding", "Raw Spectrogram"],
387
+ value="LWM Embedding",
388
+ label="Representation",
389
+ )
390
+ color_by = gr.Dropdown(choices=["tech", "snr", "mod", "mob"], value="tech", label="Color By")
391
+
392
+ with gr.Accordion("Advanced t-SNE Settings", open=False):
393
+ perplexity = gr.Slider(minimum=5, maximum=50, value=30, step=1, label="Perplexity")
394
+ n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations")
395
+
396
+ btn = gr.Button("Update Plot", variant="primary")
397
+ status = gr.Textbox(label="Status", interactive=False)
398
+
399
+ with gr.Column(scale=3):
400
+ plot = gr.Plot(label="t-SNE Visualization")
401
+
402
+ btn.click(
403
+ plot_tsne,
404
+ inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
405
+ outputs=[plot, status],
406
+ )
407
+
408
+ demo.load(
409
+ plot_tsne,
410
+ inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
411
+ outputs=[plot, status],
412
+ )
413
+
414
+ with gr.Tab("Evaluation (Joint SNR/Doppler)"):
415
+ if evaluation_disabled:
416
+ gr.Markdown(
417
+ "⚠️ MoE embeddings are unavailable. Ensure `demo_data_moe.pt` exists or the checkpoint is present."
418
+ )
419
+
420
+ with gr.Row():
421
+ with gr.Column(scale=1, min_width=320):
422
+ gr.Markdown("### Evaluation Filters")
423
+ eval_tech_filter = gr.CheckboxGroup(
424
+ choices=tech_choices,
425
+ value=tech_choices,
426
+ label="Technology",
427
+ interactive=not evaluation_disabled,
428
+ )
429
+ eval_snr_filter = gr.Dropdown(
430
+ choices=snr_choices,
431
+ value=None,
432
+ multiselect=True,
433
+ label="SNR (Empty = All)",
434
+ interactive=not evaluation_disabled,
435
+ )
436
+ eval_mod_filter = gr.Dropdown(
437
+ choices=mod_choices,
438
+ value=None,
439
+ multiselect=True,
440
+ label="Modulation (Empty = All)",
441
+ interactive=not evaluation_disabled,
442
+ )
443
+ eval_mob_filter = gr.Dropdown(
444
+ choices=mob_choices,
445
+ value=None,
446
+ multiselect=True,
447
+ label="Mobility (Empty = All)",
448
+ interactive=not evaluation_disabled,
449
+ )
450
+
451
+ gr.Markdown("### Prototype Settings")
452
+ train_pct = gr.Slider(
453
+ minimum=10,
454
+ maximum=80,
455
+ step=5,
456
+ value=60,
457
+ label="Training Percentage (%)",
458
+ interactive=not evaluation_disabled,
459
+ )
460
+ seed = gr.Slider(
461
+ minimum=0,
462
+ maximum=9999,
463
+ step=1,
464
+ value=42,
465
+ label="Random Seed",
466
+ interactive=not evaluation_disabled,
467
+ )
468
+ eval_btn = gr.Button("Run Evaluation", variant="primary", interactive=not evaluation_disabled)
469
+
470
+ with gr.Column(scale=3):
471
+ eval_plot = gr.Plot(label="Prototype Confusion Matrix")
472
+ eval_status = gr.Textbox(label="Metrics", interactive=False)
473
+
474
+ eval_btn.click(
475
+ run_joint_evaluation,
476
+ inputs=[train_pct, seed, eval_tech_filter, eval_snr_filter, eval_mod_filter, eval_mob_filter],
477
+ outputs=[eval_plot, eval_status],
478
+ )
479
 
480
  if __name__ == "__main__":
481
  demo.launch()
mixture/train_embedding_router.py DELETED
The diff for this file is too large to render. See raw diff
 
mixture/train_top1_router.py DELETED
@@ -1,1039 +0,0 @@
1
- #!/usr/bin/env python3
2
- """Train a communication-router with top-1 hard expert selection.
3
-
4
- The script builds a supervised mixture-of-experts pipeline:
5
-
6
- 1. Gather spectrogram samples for each communication profile (LTE/WiFi/5G).
7
- 2. Train a lightweight CNN router that predicts the communication label.
8
- 3. (Optional) Attach pre-trained experts and evaluate top-1 hard routing by
9
- running only the expert picked by the router's argmax for each sample.
10
-
11
- Expert checkpoints are expected to be LWM-based classifiers (for example those
12
- produced by `task2/train_joint_snr_mobility.py` or earlier mobility fine-tuning
13
- pipelines).
14
- Their architecture is inferred from the checkpoint to avoid manual plumbing.
15
-
16
- Example:
17
-
18
- ```bash
19
- python mixture/train_top1_router.py \
20
- --data-root spectrograms \
21
- --cities city_1_losangeles \
22
- --comm-types LTE WiFi 5G \
23
- --task snr_mobility \
24
- --mobilities vehicular pedestrian \
25
- --snrs SNR-5dB SNR0dB SNR5dB SNR10dB SNR15dB \
26
- --max-samples-per-comm 6000 \
27
- --max-per-combo 400 \
28
- --epochs 25 \
29
- --batch-size 128 \
30
- --lr 3e-4 \
31
- --output-dir mixture/runs/top1_router \
32
- --expert LTE=models/doppler_finetuned_binary/lte/lwm_lte_doppler_val90.67.pth \
33
- --expert WiFi=models/doppler_finetuned_binary/wifi/lwm_wifi_doppler_val95.01.pth \
34
- --expert 5G=models/doppler_finetuned_binary/5g/lwm_5g_doppler_val96.05.pth
35
- ```
36
- """
37
-
38
- from __future__ import annotations
39
-
40
- import argparse
41
- import json
42
- import random
43
- from collections import Counter, defaultdict
44
- from dataclasses import dataclass
45
- from pathlib import Path
46
- from typing import Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple
47
-
48
- import glob
49
- import numpy as np
50
- import torch
51
- import torch.nn as nn
52
- import torch.nn.functional as F
53
- from torch.amp import GradScaler, autocast
54
- from torch.utils.data import DataLoader, Dataset
55
-
56
-
57
- try:
58
- from task1.train_mcs_models import (
59
- MODULATION_LABELS,
60
- identify_modulation,
61
- load_all_samples,
62
- normalize_per_sample,
63
- _extract_metadata,
64
- )
65
- except ImportError as exc: # pragma: no cover - safety net
66
- raise ImportError(
67
- "Failed to import helpers from task1.train_mcs_models. "
68
- "Ensure the repository root is on PYTHONPATH."
69
- ) from exc
70
-
71
- try:
72
- from task2.train_joint_snr_mobility import snr_sort_key
73
- except ImportError: # pragma: no cover - fallback if task2 module is unavailable
74
-
75
- def snr_sort_key(snr: str) -> Tuple[int, str]:
76
- import re
77
-
78
- match = re.search(r"SNR(-?\d+)dB", snr)
79
- if match:
80
- return int(match.group(1)), snr
81
- return 0, snr
82
-
83
-
84
- from pretraining.pretrained_model import lwm as lwm_model
85
-
86
-
87
- COMM_CANONICAL = {"lte": "LTE", "wifi": "WiFi", "5g": "5G"}
88
-
89
-
90
- def canonical_comm_name(name: str) -> str:
91
- lower = name.strip().lower()
92
- if lower in COMM_CANONICAL:
93
- return COMM_CANONICAL[lower]
94
- for canonical in COMM_CANONICAL.values():
95
- if canonical.lower() == lower:
96
- return canonical
97
- raise ValueError(f"Unknown communication type: {name}")
98
-
99
-
100
- @dataclass(slots=True)
101
- class SampleMetadata:
102
- comm: str
103
- modulation: str
104
- snr: str
105
- mobility: str
106
- rate: str
107
- source: str
108
-
109
-
110
- @dataclass(slots=True)
111
- class ExpertSpec:
112
- comm: str
113
- checkpoint: Path
114
- stats_path: Optional[Path]
115
-
116
-
117
- class RoutedSpectrogramDataset(Dataset):
118
- """Spectrogram dataset that tracks both router and downstream labels."""
119
-
120
- def __init__(
121
- self,
122
- specs: np.ndarray,
123
- comm_labels: np.ndarray,
124
- task_labels: np.ndarray,
125
- metadata: List[SampleMetadata],
126
- ) -> None:
127
- if not (len(specs) == len(comm_labels) == len(task_labels) == len(metadata)):
128
- raise ValueError("All dataset inputs must have the same length")
129
- self.specs = torch.from_numpy(specs.astype(np.float32, copy=False))
130
- self.comm_labels = torch.from_numpy(comm_labels.astype(np.int64, copy=False))
131
- self.task_labels = torch.from_numpy(task_labels.astype(np.int64, copy=False))
132
- self.metadata = metadata
133
-
134
- def __len__(self) -> int:
135
- return self.specs.shape[0]
136
-
137
- def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, int]:
138
- return self.specs[idx], int(self.comm_labels[idx]), int(self.task_labels[idx])
139
-
140
-
141
- class RouterNet(nn.Module):
142
- """Lightweight CNN router for 128×128 spectrogram inputs."""
143
-
144
- def __init__(self, num_comm: int, dropout: float = 0.1) -> None:
145
- super().__init__()
146
- self.features = nn.Sequential(
147
- nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2),
148
- nn.BatchNorm2d(32),
149
- nn.SiLU(inplace=True),
150
- nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
151
- nn.BatchNorm2d(64),
152
- nn.SiLU(inplace=True),
153
- nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1),
154
- nn.BatchNorm2d(96),
155
- nn.SiLU(inplace=True),
156
- nn.Conv2d(96, 128, kernel_size=3, stride=2, padding=1),
157
- nn.BatchNorm2d(128),
158
- nn.SiLU(inplace=True),
159
- nn.AdaptiveAvgPool2d((1, 1)),
160
- )
161
- head_layers: List[nn.Module] = [nn.Flatten()]
162
- if dropout > 0:
163
- head_layers.append(nn.Dropout(dropout))
164
- head_layers.append(nn.Linear(128, num_comm))
165
- self.classifier = nn.Sequential(*head_layers)
166
-
167
- def forward(self, specs: torch.Tensor) -> torch.Tensor:
168
- x = specs
169
- if x.dim() == 3:
170
- x = x.unsqueeze(1)
171
- elif x.dim() != 4:
172
- raise ValueError(f"Expected specs rank 3 or 4, got shape {tuple(specs.shape)}")
173
- features = self.features(x)
174
- logits = self.classifier(features)
175
- return logits
176
-
177
-
178
- def set_seed(seed: int) -> None:
179
- random.seed(seed)
180
- np.random.seed(seed)
181
- torch.manual_seed(seed)
182
- if torch.cuda.is_available():
183
- torch.cuda.manual_seed(seed)
184
- torch.cuda.manual_seed_all(seed)
185
-
186
-
187
- def parse_expert_definitions(entries: Sequence[str]) -> Dict[str, ExpertSpec]:
188
- experts: Dict[str, ExpertSpec] = {}
189
- for entry in entries:
190
- if "=" not in entry:
191
- raise ValueError(f"Expert definition must use COMM=path syntax (got: {entry})")
192
- comm_part, _, path_part = entry.partition("=")
193
- comm = canonical_comm_name(comm_part)
194
- if not path_part:
195
- raise ValueError(f"Missing checkpoint path for expert '{comm}'")
196
- if ":" in path_part:
197
- ckpt_str, stats_str = path_part.split(":", 1)
198
- stats_path = Path(stats_str).expanduser().resolve()
199
- else:
200
- ckpt_str = path_part
201
- stats_path = None
202
- checkpoint = Path(ckpt_str).expanduser().resolve()
203
- if not checkpoint.exists():
204
- raise FileNotFoundError(f"Expert checkpoint not found: {checkpoint}")
205
- if stats_path is not None and not stats_path.exists():
206
- raise FileNotFoundError(f"Dataset stats file not found: {stats_path}")
207
- experts[comm] = ExpertSpec(comm=comm, checkpoint=checkpoint, stats_path=stats_path)
208
- return experts
209
-
210
-
211
- def discover_stats_path(comm: str, defaults_root: Path) -> Optional[Path]:
212
- candidates = [
213
- defaults_root / f"{comm}_models" / "dataset_stats.json",
214
- defaults_root / f"{comm.lower()}_models" / "dataset_stats.json",
215
- defaults_root / comm / "dataset_stats.json",
216
- defaults_root / comm.lower() / "dataset_stats.json",
217
- ]
218
- for candidate in candidates:
219
- if candidate.exists():
220
- return candidate
221
- return None
222
-
223
-
224
- def load_dataset_stats(stats_path: Optional[Path]) -> Mapping[str, float | str]:
225
- if stats_path is None:
226
- return {"mean": 0.0, "std": 1.0, "normalization": "per_sample"}
227
- with open(stats_path, "r", encoding="utf-8") as fh:
228
- return json.load(fh)
229
-
230
-
231
- def _collect_candidate_files(
232
- *,
233
- data_root: Path,
234
- cities: Sequence[str],
235
- comm: str,
236
- snr_filters: Optional[Sequence[str]],
237
- mobility_filters: Optional[Sequence[str]],
238
- modulation_filters: Optional[Sequence[str]],
239
- fft_filters: Optional[Sequence[str]],
240
- ) -> List[Tuple[Path, SampleMetadata]]:
241
- mobility_set = set(mobility_filters) if mobility_filters else None
242
- snr_set = set(snr_filters) if snr_filters else None
243
- modulation_set = {m.upper() for m in modulation_filters} if modulation_filters else None
244
- fft_set = set(fft_filters) if fft_filters else None
245
-
246
- candidates: List[Tuple[Path, SampleMetadata]] = []
247
- for city in cities:
248
- base = data_root / city / comm
249
- if not base.exists():
250
- continue
251
- pattern = str(base / "**" / "spectrograms" / "*.pkl")
252
- for path_str in glob.iglob(pattern, recursive=True):
253
- path = Path(path_str)
254
- _, modulation = identify_modulation(str(path))
255
- if modulation is None:
256
- continue
257
- if modulation_set is not None and modulation.upper() not in modulation_set:
258
- continue
259
- rate, snr, mobility = _extract_metadata(path.parts)
260
- if mobility_set is not None and mobility not in mobility_set:
261
- continue
262
- if snr_set is not None and snr not in snr_set:
263
- continue
264
- fft_folder = next((part for part in path.parts if part.startswith("win")), None)
265
- if fft_set is not None and fft_folder not in fft_set:
266
- continue
267
- meta = SampleMetadata(
268
- comm=comm,
269
- modulation=modulation,
270
- snr=snr,
271
- mobility=mobility,
272
- rate=rate,
273
- source=str(path),
274
- )
275
- candidates.append((path, meta))
276
- return candidates
277
-
278
-
279
- def _sample_from_file(
280
- array: np.ndarray,
281
- take: int,
282
- rng: np.random.Generator,
283
- ) -> np.ndarray:
284
- if take <= 0 or array.shape[0] == 0:
285
- return np.empty((0, 128, 128), dtype=np.float32)
286
- if take >= array.shape[0]:
287
- return array.astype(np.float32, copy=False)
288
- indices = rng.choice(array.shape[0], size=take, replace=False)
289
- return array[indices].astype(np.float32, copy=False)
290
-
291
-
292
- def collect_spectrograms_for_comm(
293
- *,
294
- data_root: Path,
295
- cities: Sequence[str],
296
- comm: str,
297
- snrs: Optional[Sequence[str]],
298
- mobilities: Optional[Sequence[str]],
299
- modulations: Optional[Sequence[str]],
300
- fft_folders: Optional[Sequence[str]],
301
- max_samples: int,
302
- max_per_combo: Optional[int],
303
- rng: np.random.Generator,
304
- ) -> Tuple[np.ndarray, List[SampleMetadata]]:
305
- candidates = _collect_candidate_files(
306
- data_root=data_root,
307
- cities=cities,
308
- comm=comm,
309
- snr_filters=snrs,
310
- mobility_filters=mobilities,
311
- modulation_filters=modulations,
312
- fft_filters=fft_folders,
313
- )
314
- if not candidates:
315
- raise RuntimeError(f"No spectrogram files matched filters for {comm}")
316
-
317
- rng.shuffle(candidates)
318
- combo_counts: MutableMapping[Tuple[str, str, str], int] = defaultdict(int)
319
- collected: List[np.ndarray] = []
320
- metadata: List[SampleMetadata] = []
321
- remaining: Optional[int] = max_samples if max_samples > 0 else None
322
- per_combo_limit: Optional[int] = max_per_combo if (max_per_combo is not None and max_per_combo > 0) else None
323
-
324
- for path, meta in candidates:
325
- if remaining is not None and remaining <= 0:
326
- break
327
- combo_key = (meta.modulation, meta.snr, meta.mobility)
328
- already = combo_counts[combo_key]
329
- if per_combo_limit is not None and already >= per_combo_limit:
330
- continue
331
-
332
- try:
333
- specs = load_all_samples(str(path))
334
- except Exception as exc: # pragma: no cover - guard against corrupted files
335
- print(f"[WARN] Failed to load {path}: {exc}")
336
- continue
337
-
338
- if specs.size == 0:
339
- continue
340
-
341
- remaining_for_combo = per_combo_limit - already if per_combo_limit is not None else specs.shape[0]
342
- allowed = min(remaining_for_combo, specs.shape[0])
343
- if remaining is not None:
344
- allowed = min(allowed, remaining)
345
- if allowed <= 0:
346
- continue
347
- chosen = _sample_from_file(specs, allowed, rng)
348
- if chosen.size == 0:
349
- continue
350
-
351
- collected.append(chosen)
352
- metadata.extend([meta] * chosen.shape[0])
353
- combo_counts[combo_key] += chosen.shape[0]
354
- if remaining is not None:
355
- remaining -= chosen.shape[0]
356
-
357
- if not collected:
358
- raise RuntimeError(f"Unable to collect samples for {comm} after applying limits")
359
-
360
- stacked = np.concatenate(collected, axis=0)
361
- return stacked.astype(np.float32, copy=False), metadata
362
-
363
-
364
- def stratified_split(
365
- labels: np.ndarray,
366
- *,
367
- train_ratio: float,
368
- val_ratio: float,
369
- seed: int,
370
- ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
371
- if not (0 < train_ratio < 1) or not (0 < val_ratio < 1):
372
- raise ValueError("train_ratio and val_ratio must be in (0, 1)")
373
- if train_ratio + val_ratio >= 1.0:
374
- raise ValueError("train_ratio + val_ratio must be < 1.0")
375
-
376
- rng = np.random.default_rng(seed)
377
- train_indices: List[int] = []
378
- val_indices: List[int] = []
379
- test_indices: List[int] = []
380
-
381
- for label in np.unique(labels):
382
- idx = np.flatnonzero(labels == label)
383
- if idx.size < 3:
384
- raise ValueError(f"Not enough samples for label {label} to form splits (need >=3, have {idx.size})")
385
- rng.shuffle(idx)
386
- train_end = int(round(train_ratio * idx.size))
387
- val_end = train_end + int(round(val_ratio * idx.size))
388
- train_indices.extend(idx[:train_end])
389
- val_indices.extend(idx[train_end:val_end])
390
- test_indices.extend(idx[val_end:])
391
-
392
- return (
393
- np.array(train_indices, dtype=np.int64),
394
- np.array(val_indices, dtype=np.int64),
395
- np.array(test_indices, dtype=np.int64),
396
- )
397
-
398
-
399
- def build_dataloaders(
400
- dataset: RoutedSpectrogramDataset,
401
- train_idx: np.ndarray,
402
- val_idx: np.ndarray,
403
- test_idx: np.ndarray,
404
- *,
405
- batch_size: int,
406
- num_workers: int,
407
- ) -> Tuple[DataLoader, DataLoader, DataLoader]:
408
- def subset(indices: np.ndarray) -> RoutedSpectrogramDataset:
409
- specs = dataset.specs[indices].numpy()
410
- comm = dataset.comm_labels[indices].numpy()
411
- task = dataset.task_labels[indices].numpy()
412
- meta = [dataset.metadata[int(i)] for i in indices]
413
- return RoutedSpectrogramDataset(specs, comm, task, meta)
414
-
415
- train_ds = subset(train_idx)
416
- val_ds = subset(val_idx)
417
- test_ds = subset(test_idx)
418
-
419
- train_loader = DataLoader(
420
- train_ds,
421
- batch_size=batch_size,
422
- shuffle=True,
423
- drop_last=len(train_ds) > batch_size,
424
- num_workers=num_workers,
425
- pin_memory=torch.cuda.is_available(),
426
- )
427
- val_loader = DataLoader(
428
- val_ds,
429
- batch_size=batch_size,
430
- shuffle=False,
431
- num_workers=num_workers,
432
- pin_memory=torch.cuda.is_available(),
433
- )
434
- test_loader = DataLoader(
435
- test_ds,
436
- batch_size=batch_size,
437
- shuffle=False,
438
- num_workers=num_workers,
439
- pin_memory=torch.cuda.is_available(),
440
- )
441
- return train_loader, val_loader, test_loader
442
-
443
-
444
- def infer_expert_signature(state_dict: Mapping[str, torch.Tensor]) -> Dict[str, object]:
445
- keys = set(state_dict.keys())
446
- # Determine input dimension (128 vs 130 if stats appended).
447
- layer_norm_weight = state_dict.get("classifier.0.weight")
448
- if layer_norm_weight is None:
449
- raise ValueError("Unable to infer classifier input dimension from checkpoint")
450
- input_dim = layer_norm_weight.numel()
451
- append_input_stats = input_dim > 128
452
-
453
- # Determine classifier type.
454
- if any(k.startswith("classifier.1.conv1") for k in keys):
455
- head_type = "res1dcnn"
456
- elif "classifier.1.weight" in keys:
457
- head_type = "mlp"
458
- elif "classifier.weight" in keys:
459
- head_type = "linear"
460
- else:
461
- raise ValueError("Unrecognized classifier architecture in checkpoint")
462
-
463
- # Hidden width for MLP head.
464
- classifier_dim = None
465
- if head_type == "mlp":
466
- classifier_dim = int(state_dict["classifier.1.weight"].shape[0])
467
-
468
- # Projection head dimensionality.
469
- if "projection_head.0.weight" in keys:
470
- projection_dim = int(state_dict["projection_head.0.weight"].shape[0])
471
- else:
472
- projection_dim = 0
473
-
474
- # Number of output classes from final linear weight.
475
- if head_type == "linear":
476
- num_classes = int(state_dict["classifier.weight"].shape[0])
477
- elif head_type == "mlp":
478
- num_classes = int(state_dict["classifier.2.weight"].shape[0])
479
- else: # res1dcnn
480
- num_classes = int(state_dict["classifier.1.fc.weight"].shape[0])
481
-
482
- return {
483
- "append_input_stats": append_input_stats,
484
- "input_dim": input_dim,
485
- "head_type": head_type,
486
- "classifier_dim": classifier_dim if classifier_dim is not None else 128,
487
- "projection_dim": projection_dim,
488
- "num_classes": num_classes,
489
- }
490
-
491
-
492
- def load_expert_model(
493
- spec: ExpertSpec,
494
- stats_root: Path,
495
- device: torch.device,
496
- ) -> Tuple[str, nn.Module, int]:
497
- raw_state = torch.load(spec.checkpoint, map_location="cpu")
498
- if any(k.startswith("module.") for k in raw_state):
499
- raw_state = {k.replace("module.", "", 1): v for k, v in raw_state.items()}
500
-
501
- signature = infer_expert_signature(raw_state)
502
-
503
- stats_path = spec.stats_path
504
- if stats_path is None:
505
- stats_path = discover_stats_path(spec.comm, stats_root)
506
- stats = load_dataset_stats(stats_path)
507
-
508
- model = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
509
- backbone_state = {
510
- k.split("backbone.", 1)[1]: v
511
- for k, v in raw_state.items()
512
- if k.startswith("backbone.")
513
- }
514
- model.load_state_dict(backbone_state, strict=False)
515
-
516
- classifier = LWMClassifierMinimalAdapter(
517
- backbone=model,
518
- num_classes=int(signature["num_classes"]),
519
- classifier_dim=int(signature["classifier_dim"]),
520
- head_type=str(signature["head_type"]),
521
- append_input_stats=bool(signature["append_input_stats"]),
522
- projection_dim=int(signature["projection_dim"]),
523
- normalization_stats=stats,
524
- )
525
- classifier.load_state_dict(raw_state, strict=True)
526
- classifier.eval()
527
- classifier.to(device)
528
- for param in classifier.parameters():
529
- param.requires_grad_(False)
530
- return spec.comm, classifier, int(signature["num_classes"])
531
-
532
-
533
- class LWMClassifierMinimalAdapter(nn.Module):
534
- """Thin wrapper matching task2.mobility_utils.LWMClassifierMinimal."""
535
-
536
- def __init__(
537
- self,
538
- *,
539
- backbone: nn.Module,
540
- num_classes: int,
541
- classifier_dim: int,
542
- head_type: str,
543
- append_input_stats: bool,
544
- projection_dim: int,
545
- normalization_stats: Mapping[str, float | str],
546
- ) -> None:
547
- super().__init__()
548
- from task2.mobility_utils import LWMClassifierMinimal # local import to avoid cycle
549
-
550
- self.inner = LWMClassifierMinimal(
551
- backbone=backbone,
552
- num_classes=num_classes,
553
- classifier_dim=classifier_dim,
554
- dropout=0.0,
555
- trainable_layers=0,
556
- projection_dim=projection_dim,
557
- append_input_stats=append_input_stats,
558
- normalization_stats=normalization_stats,
559
- head_type=head_type,
560
- )
561
-
562
- def forward(self, specs: torch.Tensor) -> torch.Tensor:
563
- return self.inner(specs)
564
-
565
-
566
- @torch.no_grad()
567
- def evaluate_router(
568
- model: nn.Module,
569
- loader: DataLoader,
570
- criterion: nn.Module,
571
- device: torch.device,
572
- ) -> Tuple[float, float, np.ndarray, np.ndarray]:
573
- model.eval()
574
- total_loss = 0.0
575
- correct = 0
576
- seen = 0
577
- y_true: List[int] = []
578
- y_pred: List[int] = []
579
-
580
- for specs, comm_labels, _ in loader:
581
- specs = specs.to(device, non_blocking=True)
582
- comm_labels = torch.as_tensor(comm_labels, device=device)
583
-
584
- logits = model(specs)
585
- loss = criterion(logits, comm_labels)
586
- preds = logits.argmax(dim=1)
587
- total_loss += loss.item() * specs.size(0)
588
- correct += (preds == comm_labels).sum().item()
589
- seen += specs.size(0)
590
- y_true.extend(comm_labels.detach().cpu().tolist())
591
- y_pred.extend(preds.detach().cpu().tolist())
592
-
593
- avg_loss = total_loss / max(seen, 1)
594
- acc = correct / max(seen, 1)
595
- return avg_loss, acc, np.array(y_true, dtype=np.int64), np.array(y_pred, dtype=np.int64)
596
-
597
-
598
- def compute_confusion(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> np.ndarray:
599
- matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
600
- for true, pred in zip(y_true, y_pred):
601
- if 0 <= true < num_classes and 0 <= pred < num_classes:
602
- matrix[true, pred] += 1
603
- return matrix
604
-
605
-
606
- @torch.no_grad()
607
- def evaluate_routing(
608
- router: nn.Module,
609
- experts: Mapping[int, Tuple[str, nn.Module]],
610
- loader: DataLoader,
611
- *,
612
- num_comm: int,
613
- num_task_classes: int,
614
- device: torch.device,
615
- routing_mode: str,
616
- routing_topk: int,
617
- ) -> Dict[str, object]:
618
- router.eval()
619
- for _, model in experts.values():
620
- model.eval()
621
-
622
- criterion = nn.CrossEntropyLoss()
623
- total_loss = 0.0
624
- total = 0
625
- correct_router = 0
626
- correct_task = 0
627
-
628
- confusion_router = np.zeros((num_comm, num_comm), dtype=np.int64)
629
- confusion_task = np.zeros((num_task_classes, num_task_classes), dtype=np.int64)
630
- coverage = Counter() # type: ignore[type-arg]
631
-
632
- for specs, comm_labels, task_labels in loader:
633
- specs = specs.to(device, non_blocking=True)
634
- comm_labels = torch.as_tensor(comm_labels, device=device)
635
- task_labels = torch.as_tensor(task_labels, device=device)
636
-
637
- logits = router(specs)
638
- loss = criterion(logits, comm_labels)
639
- probs = torch.softmax(logits, dim=1)
640
- router_pred = probs.argmax(dim=1)
641
-
642
- batch = specs.size(0)
643
- total_loss += loss.item() * batch
644
- total += batch
645
- correct_router += (router_pred == comm_labels).sum().item()
646
-
647
- confusion_router += compute_confusion(
648
- comm_labels.detach().cpu().numpy(),
649
- router_pred.detach().cpu().numpy(),
650
- num_comm,
651
- )
652
-
653
- if not experts:
654
- continue
655
-
656
- weights = torch.zeros_like(probs)
657
- if routing_mode == "hard":
658
- weights.scatter_(1, router_pred.unsqueeze(1), 1.0)
659
- elif routing_mode == "soft":
660
- weights = probs
661
- elif routing_mode == "topk":
662
- topk = max(1, min(routing_topk, num_comm))
663
- topk_vals, topk_indices = probs.topk(topk, dim=1)
664
- weights.zero_()
665
- weights.scatter_(1, topk_indices, topk_vals)
666
- else:
667
- raise ValueError(f"Unsupported routing mode: {routing_mode}")
668
-
669
- final_logits = torch.zeros(batch, num_task_classes, device=device)
670
- for comm_idx, (name, expert) in experts.items():
671
- weight_column = weights[:, comm_idx]
672
- if not torch.any(weight_column > 0):
673
- continue
674
- outputs = expert(specs)
675
- if outputs.size(1) != num_task_classes:
676
- raise ValueError(
677
- f"Expert '{name}' returned {outputs.size(1)} classes, expected {num_task_classes}"
678
- )
679
- final_logits += weight_column.unsqueeze(1) * outputs
680
- coverage[name] += float(weight_column.sum().item())
681
-
682
- task_pred = final_logits.argmax(dim=1)
683
- correct_task += (task_pred == task_labels).sum().item()
684
- confusion_task += compute_confusion(
685
- task_labels.detach().cpu().numpy(),
686
- task_pred.detach().cpu().numpy(),
687
- num_task_classes,
688
- )
689
-
690
- metrics: Dict[str, object] = {
691
- "router_loss": total_loss / max(total, 1),
692
- "router_acc": correct_router / max(total, 1),
693
- "router_confusion": confusion_router.tolist(),
694
- "coverage": dict(coverage),
695
- }
696
- if experts:
697
- metrics["task_acc"] = correct_task / max(total, 1)
698
- metrics["task_confusion"] = confusion_task.tolist()
699
- return metrics
700
-
701
-
702
- def modulation_labels_from_metadata(metadata: Sequence[SampleMetadata]) -> np.ndarray:
703
- labels: List[int] = []
704
- for meta in metadata:
705
- label = MODULATION_LABELS.get(meta.modulation.upper())
706
- if label is None:
707
- raise ValueError(f"Unknown modulation label in metadata: {meta.modulation}")
708
- labels.append(label)
709
- return np.array(labels, dtype=np.int64)
710
-
711
-
712
- def snr_mobility_labels_from_metadata(
713
- metadata: Sequence[SampleMetadata],
714
- *,
715
- snr_order: Sequence[str],
716
- mobility_order: Sequence[str],
717
- ) -> Tuple[np.ndarray, Dict[int, Tuple[str, str]]]:
718
- combos: List[Tuple[str, str]] = []
719
- for snr in snr_order:
720
- for mobility in mobility_order:
721
- combos.append((snr, mobility))
722
- combo_to_idx = {combo: idx for idx, combo in enumerate(combos)}
723
-
724
- labels: List[int] = []
725
- for meta in metadata:
726
- combo = (meta.snr, meta.mobility)
727
- if combo not in combo_to_idx:
728
- raise ValueError(f"Sample combo {combo} not present in configured (snr, mobility) grid")
729
- labels.append(combo_to_idx[combo])
730
- mapping = {idx: combo for combo, idx in combo_to_idx.items()}
731
- return np.array(labels, dtype=np.int64), mapping
732
-
733
-
734
- def prepare_dataset(
735
- *,
736
- data_root: Path,
737
- cities: Sequence[str],
738
- comm_types: Sequence[str],
739
- snrs: Optional[Sequence[str]],
740
- mobilities: Optional[Sequence[str]],
741
- modulations: Optional[Sequence[str]],
742
- fft_folders: Optional[Sequence[str]],
743
- max_samples_per_comm: int,
744
- max_per_combo: Optional[int],
745
- task: str,
746
- seed: int,
747
- ) -> Tuple[RoutedSpectrogramDataset, Dict[str, int], Optional[Dict[int, Tuple[str, str]]]]:
748
- rng = np.random.default_rng(seed)
749
- specs_list: List[np.ndarray] = []
750
- comm_labels_list: List[int] = []
751
- metadata_list: List[SampleMetadata] = []
752
- comm_to_idx = {comm: idx for idx, comm in enumerate(comm_types)}
753
-
754
- for comm in comm_types:
755
- samples, metadata = collect_spectrograms_for_comm(
756
- data_root=data_root,
757
- cities=cities,
758
- comm=comm,
759
- snrs=snrs,
760
- mobilities=mobilities,
761
- modulations=modulations,
762
- fft_folders=fft_folders,
763
- max_samples=max_samples_per_comm,
764
- max_per_combo=max_per_combo,
765
- rng=rng,
766
- )
767
- specs_list.append(samples)
768
- metadata_list.extend(metadata)
769
- comm_labels_list.extend([comm_to_idx[comm]] * samples.shape[0])
770
-
771
- specs = np.concatenate(specs_list, axis=0)
772
- metadata = metadata_list
773
- comm_labels = np.array(comm_labels_list, dtype=np.int64)
774
-
775
- order = rng.permutation(specs.shape[0])
776
- specs = specs[order]
777
- comm_labels = comm_labels[order]
778
- metadata = [metadata[idx] for idx in order]
779
-
780
- normalized = normalize_per_sample(specs)
781
-
782
- if task == "modulation":
783
- task_labels = modulation_labels_from_metadata(metadata)
784
- mapping = None
785
- else:
786
- if snrs is None:
787
- snr_order = sorted({meta.snr for meta in metadata}, key=snr_sort_key)
788
- else:
789
- snr_order = [snr for snr in snrs if any(meta.snr == snr for meta in metadata)]
790
- if mobilities is None:
791
- mobility_order = sorted({meta.mobility for meta in metadata})
792
- else:
793
- mobility_order = [mob for mob in mobilities if any(meta.mobility == mob for meta in metadata)]
794
- task_labels, mapping = snr_mobility_labels_from_metadata(
795
- metadata,
796
- snr_order=snr_order,
797
- mobility_order=mobility_order,
798
- )
799
-
800
- dataset = RoutedSpectrogramDataset(normalized, comm_labels, task_labels, metadata)
801
- return dataset, comm_to_idx, mapping
802
-
803
-
804
- def parse_args() -> argparse.Namespace:
805
- parser = argparse.ArgumentParser(description=__doc__)
806
- parser.add_argument("--data-root", type=Path, default=Path("spectrograms"), help="Root directory with spectrogram data")
807
- parser.add_argument("--cities", nargs="*", default=["city_1_losangeles"], help="City folders to include")
808
- parser.add_argument("--comm-types", nargs="*", default=["LTE", "WiFi", "5G"], help="Communication standards to model")
809
- parser.add_argument("--snrs", nargs="*", default=None, help="SNR folders to include")
810
- parser.add_argument("--mobilities", nargs="*", default=None, help="Mobility folders to include")
811
- parser.add_argument("--modulations", nargs="*", default=None, help="Modulation classes to include (default: all)")
812
- parser.add_argument("--fft-folders", nargs="*", default=None, help="Specific FFT/window folders to include")
813
- parser.add_argument("--task", choices=("modulation", "snr_mobility"), default="snr_mobility", help="Downstream task label")
814
- parser.add_argument("--max-samples-per-comm", type=int, default=6000, help="Maximum samples per communication profile")
815
- parser.add_argument("--max-per-combo", type=int, default=512, help="Cap per (modulation,SNR,mobility) combo (0=unbounded)")
816
- parser.add_argument("--seed", type=int, default=42, help="Random seed")
817
- parser.add_argument(
818
- "--routing-mode",
819
- choices=("hard", "soft", "topk"),
820
- default="hard",
821
- help="Routing strategy: hard (top-1), soft (probability-weighted), or topk (restricted soft) (default: %(default)s)",
822
- )
823
- parser.add_argument(
824
- "--routing-topk",
825
- type=int,
826
- default=2,
827
- help="When routing-mode=topk, number of experts to evaluate per sample (default: %(default)s)",
828
- )
829
-
830
- parser.add_argument("--train-ratio", type=float, default=0.7, help="Fraction of data for training")
831
- parser.add_argument("--val-ratio", type=float, default=0.15, help="Fraction of data for validation")
832
- parser.add_argument("--batch-size", type=int, default=128, help="Mini-batch size")
833
- parser.add_argument("--epochs", type=int, default=20, help="Training epochs")
834
- parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
835
- parser.add_argument("--weight-decay", type=float, default=1e-4, help="Weight decay for AdamW")
836
- parser.add_argument("--dropout", type=float, default=0.1, help="Router dropout probability")
837
- parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers")
838
- parser.add_argument("--use-amp", action="store_true", help="Enable mixed precision for router training")
839
- parser.add_argument("--spec-augment", action="store_true", help="Apply SpecAugment to router inputs")
840
- parser.add_argument("--spec-augment-freq", type=int, default=12, help="Frequency mask width for SpecAugment")
841
- parser.add_argument("--spec-augment-time", type=int, default=16, help="Time mask width for SpecAugment")
842
- parser.add_argument("--spec-augment-prob", type=float, default=0.5, help="Probability to apply SpecAugment to a sample")
843
-
844
- parser.add_argument("--expert", action="append", default=[], help="Expert definition COMM=checkpoint[:stats_path]")
845
- parser.add_argument("--expert-stats-root", type=Path, default=Path("models"), help="Root to auto-discover dataset_stats.json")
846
-
847
- parser.add_argument("--output-dir", type=Path, default=Path("mixture/runs/top1_router"), help="Directory for logs and checkpoints")
848
- parser.add_argument("--save-router", action="store_true", help="Save best router state_dict to output directory")
849
-
850
- args = parser.parse_args()
851
-
852
- if args.max_per_combo is not None and args.max_per_combo < 0:
853
- parser.error("--max-per-combo must be >= 0")
854
- if args.spec_augment and not (0.0 <= args.spec_augment_prob <= 1.0):
855
- parser.error("--spec-augment-prob must be between 0 and 1")
856
- if args.max_samples_per_comm <= 0:
857
- parser.error("--max-samples-per-comm must be positive")
858
- if args.train_ratio <= 0 or args.val_ratio <= 0:
859
- parser.error("--train-ratio and --val-ratio must be positive")
860
- if args.train_ratio + args.val_ratio >= 1.0:
861
- parser.error("--train-ratio + --val-ratio must be < 1.0")
862
-
863
- return args
864
-
865
-
866
- def maybe_apply_spec_augment(
867
- specs: torch.Tensor,
868
- *,
869
- enabled: bool,
870
- freq_width: int,
871
- time_width: int,
872
- prob: float,
873
- ) -> torch.Tensor:
874
- if not enabled:
875
- return specs
876
- from task1.train_mcs_models import apply_spec_augment
877
-
878
- return apply_spec_augment(
879
- specs,
880
- freq_mask_width=freq_width,
881
- time_mask_width=time_width,
882
- mask_prob=prob,
883
- )
884
-
885
-
886
- def main() -> None:
887
- args = parse_args()
888
- set_seed(args.seed)
889
-
890
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
891
-
892
- comm_types = [canonical_comm_name(comm) for comm in args.comm_types]
893
- dataset, comm_to_idx, combo_mapping = prepare_dataset(
894
- data_root=args.data_root.expanduser().resolve(),
895
- cities=args.cities,
896
- comm_types=comm_types,
897
- snrs=args.snrs,
898
- mobilities=args.mobilities,
899
- modulations=args.modulations,
900
- fft_folders=args.fft_folders,
901
- max_samples_per_comm=args.max_samples_per_comm,
902
- max_per_combo=args.max_per_combo,
903
- task=args.task,
904
- seed=args.seed,
905
- )
906
- num_comm = len(comm_types)
907
- num_task_classes = int(dataset.task_labels.max()) + 1
908
-
909
- train_idx, val_idx, test_idx = stratified_split(
910
- dataset.comm_labels.numpy(),
911
- train_ratio=args.train_ratio,
912
- val_ratio=args.val_ratio,
913
- seed=args.seed,
914
- )
915
- train_loader, val_loader, test_loader = build_dataloaders(
916
- dataset,
917
- train_idx=train_idx,
918
- val_idx=val_idx,
919
- test_idx=test_idx,
920
- batch_size=args.batch_size,
921
- num_workers=args.num_workers,
922
- )
923
-
924
- router = RouterNet(num_comm=num_comm, dropout=args.dropout).to(device)
925
- criterion = nn.CrossEntropyLoss()
926
- optimizer = torch.optim.AdamW(router.parameters(), lr=args.lr, weight_decay=args.weight_decay)
927
- scaler = GradScaler(enabled=args.use_amp and device.type == "cuda")
928
-
929
- best_state: Optional[Dict[str, torch.Tensor]] = None
930
- best_val_acc = 0.0
931
-
932
- for epoch in range(1, args.epochs + 1):
933
- router.train()
934
- running_loss = 0.0
935
- running_correct = 0
936
- total = 0
937
-
938
- for specs, comm_labels, _ in train_loader:
939
- specs = specs.to(device, non_blocking=True)
940
- comm_labels = torch.as_tensor(comm_labels, device=device)
941
- specs_aug = maybe_apply_spec_augment(
942
- specs,
943
- enabled=args.spec_augment,
944
- freq_width=args.spec_augment_freq,
945
- time_width=args.spec_augment_time,
946
- prob=args.spec_augment_prob,
947
- )
948
-
949
- optimizer.zero_grad(set_to_none=True)
950
- context = autocast(device_type=device.type, enabled=scaler.is_enabled())
951
- with context:
952
- logits = router(specs_aug)
953
- loss = criterion(logits, comm_labels)
954
- if scaler.is_enabled():
955
- scaler.scale(loss).backward()
956
- scaler.step(optimizer)
957
- scaler.update()
958
- else:
959
- loss.backward()
960
- optimizer.step()
961
-
962
- preds = logits.argmax(dim=1)
963
- running_loss += loss.item() * specs.size(0)
964
- running_correct += (preds == comm_labels).sum().item()
965
- total += specs.size(0)
966
-
967
- train_loss = running_loss / max(total, 1)
968
- train_acc = running_correct / max(total, 1)
969
-
970
- val_loss, val_acc, y_true_val, y_pred_val = evaluate_router(router, val_loader, criterion, device)
971
- val_confusion = compute_confusion(y_true_val, y_pred_val, num_comm)
972
-
973
- print(
974
- f"[Epoch {epoch:02d}] train_loss={train_loss:.4f} "
975
- f"train_acc={train_acc:.3f} val_loss={val_loss:.4f} val_acc={val_acc:.3f}"
976
- )
977
-
978
- if val_acc >= best_val_acc:
979
- best_val_acc = val_acc
980
- best_state = {k: v.detach().cpu() for k, v in router.state_dict().items()}
981
- print(f"[Epoch {epoch:02d}] Val confusion matrix:\n{val_confusion}")
982
-
983
- if best_state is None:
984
- best_state = {k: v.detach().cpu() for k, v in router.state_dict().items()}
985
- router.load_state_dict(best_state)
986
-
987
- output_dir = args.output_dir.expanduser().resolve()
988
- output_dir.mkdir(parents=True, exist_ok=True)
989
-
990
- experts: Dict[int, Tuple[str, nn.Module]] = {}
991
- expert_specs = parse_expert_definitions(args.expert)
992
- for comm, spec in expert_specs.items():
993
- comm_idx = comm_to_idx.get(comm)
994
- if comm_idx is None:
995
- print(f"[WARN] Expert for {comm} provided but communication not in dataset; skipping")
996
- continue
997
- name, model, out_classes = load_expert_model(
998
- spec,
999
- stats_root=args.expert_stats_root.expanduser().resolve(),
1000
- device=device,
1001
- )
1002
- if out_classes != num_task_classes:
1003
- print(
1004
- f"[WARN] Expert '{name}' outputs {out_classes} classes, "
1005
- f"but dataset task expects {num_task_classes}. Skipping expert."
1006
- )
1007
- continue
1008
- experts[comm_idx] = (name, model)
1009
-
1010
- test_metrics = evaluate_routing(
1011
- router,
1012
- experts,
1013
- test_loader,
1014
- num_comm=num_comm,
1015
- num_task_classes=num_task_classes,
1016
- device=device,
1017
- routing_mode=args.routing_mode,
1018
- routing_topk=args.routing_topk,
1019
- )
1020
- print("[RESULT] Test metrics:")
1021
- print(json.dumps(test_metrics, indent=2))
1022
-
1023
- metrics_path = output_dir / "metrics.json"
1024
- with open(metrics_path, "w", encoding="utf-8") as fh:
1025
- json.dump(test_metrics, fh, indent=2)
1026
-
1027
- if combo_mapping is not None:
1028
- mapping_path = output_dir / "snr_mobility_mapping.json"
1029
- with open(mapping_path, "w", encoding="utf-8") as fh:
1030
- json.dump({int(k): v for k, v in combo_mapping.items()}, fh, indent=2)
1031
-
1032
- if args.save_router:
1033
- ckpt_path = output_dir / "router_top1_state_dict.pth"
1034
- torch.save(best_state, ckpt_path)
1035
- print(f"[INFO] Saved router checkpoint to {ckpt_path}")
1036
-
1037
-
1038
- if __name__ == "__main__":
1039
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pretraining/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (159 Bytes)
 
pretraining/__pycache__/pretrained_model.cpython-311.pyc DELETED
Binary file (14.6 kB)
 
pretraining/pretrained_model.py CHANGED
@@ -178,10 +178,3 @@ def lwm(*args, **kwargs) -> LWM:
178
  """Factory to preserve backward compatibility with older imports."""
179
 
180
  return LWM(*args, **kwargs)
181
-
182
-
183
- class PretrainedLWM(LWM):
184
- """Alias retained for compatibility with existing inference scripts."""
185
-
186
- def __init__(self, *args, **kwargs) -> None:
187
- super().__init__(*args, **kwargs)
 
178
  """Factory to preserve backward compatibility with older imports."""
179
 
180
  return LWM(*args, **kwargs)
 
 
 
 
 
 
 
task1/plot_tsne.py DELETED
@@ -1,802 +0,0 @@
1
- #!/usr/bin/env python3
2
- """Visualise how strongly metadata drives the learned embedding space.
3
-
4
- This script mirrors the functionality of ``task1/plot_mod_tsne.py`` but groups
5
- spectrograms by their SNR folder name (e.g. ``SNR0dB``) instead of modulation.
6
- It is useful for checking whether the self-supervised LWM backbone mostly
7
- captures channel/SNR differences rather than modulation characteristics.
8
-
9
- Pass ``--label-field modulation`` to reuse the same sampled spectrograms while
10
- colouring and scoring them by their modulation folder instead of SNR. Use
11
- ``--label-field mobility`` to highlight link-level mobility categories when
12
- present in the dataset tree. Saved figures automatically include the detected
13
- communication profile (e.g. LTE/WiFi/5G) and label mode in the filename when
14
- those suffixes are not already present.
15
-
16
- Usage example:
17
-
18
- ```bash
19
- python task1/plot_snr_tsne.py \
20
- --data-root spectrograms/city_1_losangeles/LTE \
21
- --snrs SNR-5dB,SNR0dB,SNR10dB,SNR15dB,SNR20dB,SNR25dB \
22
- --save-path task1/snr_separation_plot_latest.png
23
- ```
24
- Shortcut presets:
25
-
26
- ```bash
27
- python task1/plot_snr_tsne.py --WiFi --report-metrics
28
- ```
29
- """
30
-
31
- from __future__ import annotations
32
-
33
- import argparse
34
- import glob
35
- import pickle
36
- import random
37
- import re
38
- from pathlib import Path
39
- from collections import Counter, defaultdict
40
- from typing import Dict, Iterable, List, Tuple
41
-
42
- import matplotlib.pyplot as plt
43
- import numpy as np
44
- import torch
45
- from sklearn.manifold import TSNE
46
- from sklearn.metrics import silhouette_score
47
- from sklearn.model_selection import StratifiedKFold
48
- from sklearn.neighbors import KNeighborsClassifier
49
- from sklearn.preprocessing import StandardScaler
50
-
51
- from pretraining.pretrained_model import lwm as lwm_model
52
- from utils import load_spectrogram_data # support .mat and .pkl uniformly
53
-
54
-
55
- DEFAULT_DATA_ROOT = "spectrograms/city_1_losangeles/LTE"
56
- DEFAULT_MODELS_ROOT = "models/LTE_models"
57
-
58
- PROFILE_PRESETS: Dict[str, Dict[str, str]] = {
59
- "LTE": {
60
- "data_root": DEFAULT_DATA_ROOT,
61
- "models_root": DEFAULT_MODELS_ROOT,
62
- },
63
- "WiFi": {
64
- "data_root": "spectrograms/city_1_losangeles/WiFi",
65
- "models_root": "models/WiFi_models",
66
- },
67
- "5G": {
68
- "data_root": "spectrograms/city_1_losangeles/5G",
69
- "models_root": "models/5G_models",
70
- },
71
- }
72
-
73
-
74
- def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
75
- means = specs.mean(axis=(1, 2), keepdims=True)
76
- stds = specs.std(axis=(1, 2), keepdims=True)
77
- stds = np.maximum(stds, eps)
78
- return ((specs - means) / stds).astype(np.float32, copy=False)
79
-
80
-
81
- def normalize_dataset(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
82
- mean = float(specs.mean())
83
- std = float(specs.std())
84
- std = max(std, eps)
85
- return ((specs - mean) / std).astype(np.float32, copy=False)
86
-
87
-
88
- # ---------------------------------------------------------------------------
89
- # Utility helpers
90
- # ---------------------------------------------------------------------------
91
-
92
- def parse_args() -> argparse.Namespace:
93
- parser = argparse.ArgumentParser(description=__doc__)
94
- parser.add_argument(
95
- "--data-root",
96
- default=DEFAULT_DATA_ROOT,
97
- help="Root directory containing modulation folders (default: %(default)s)",
98
- )
99
- parser.add_argument(
100
- "--modulation",
101
- default="all",
102
- help="Modulation folder to load (default: %(default)s)",
103
- )
104
- parser.add_argument(
105
- "--snrs",
106
- default="SNR-5dB,SNR0dB,SNR5dB,SNR10dB,SNR15dB,SNR20dB,SNR25dB",
107
- help=(
108
- "Comma-separated list of SNR folder names to include. Pass 'all' "
109
- "to include every SNR discovered under the modulation (default: %(default)s)"
110
- ),
111
- )
112
- parser.add_argument(
113
- "--mobility",
114
- nargs="+",
115
- default=["all"],
116
- help=(
117
- "Mobility folder(s) to filter on. Pass 'all' to include every mobility "
118
- "(default: %(default)s). Multiple values can be provided either as a "
119
- "space-separated list (e.g. '--mobility vehicular pedestrian') or a "
120
- "comma-separated string."
121
- ),
122
- )
123
- parser.add_argument(
124
- "--fft-folder",
125
- default="all",
126
- help=(
127
- "FFT size folder name to use. Pass 'all' to include every FFT variant "
128
- "(default: %(default)s)"
129
- ),
130
- )
131
- parser.add_argument(
132
- "--samples-per-snr",
133
- type=int,
134
- default=500,
135
- help="Maximum number of samples to draw for each SNR label",
136
- )
137
- parser.add_argument(
138
- "--seed",
139
- type=int,
140
- default=42,
141
- help="Random seed for sampling and t-SNE",
142
- )
143
- parser.add_argument(
144
- "--pooling",
145
- choices=("mean", "cls"),
146
- default="mean",
147
- help="How to collapse token embeddings into a single vector",
148
- )
149
- parser.add_argument(
150
- "--save-path",
151
- default="task1/snr_separation_plot_latest.png",
152
- help="Location to save the generated figure (default: %(default)s)",
153
- )
154
- parser.add_argument(
155
- "--checkpoint",
156
- default=None,
157
- help="Optional explicit checkpoint path; overrides automatic latest selection",
158
- )
159
- parser.add_argument(
160
- "--models-root",
161
- default=DEFAULT_MODELS_ROOT,
162
- help=(
163
- "Directory containing checkpoints. When --checkpoint is not given, "
164
- "the latest/best checkpoint inside this directory will be used "
165
- "(default: %(default)s)"
166
- ),
167
- )
168
- preset_group = parser.add_mutually_exclusive_group()
169
- preset_group.add_argument(
170
- "--profile",
171
- dest="profile",
172
- choices=tuple(PROFILE_PRESETS.keys()),
173
- help=(
174
- "Convenience preset that sets --data-root and --models-root when they "
175
- "are left at their defaults"
176
- ),
177
- )
178
- preset_group.add_argument(
179
- "--LTE",
180
- dest="profile",
181
- action="store_const",
182
- const="LTE",
183
- help="Shortcut for --profile LTE",
184
- )
185
- preset_group.add_argument(
186
- "--WiFi",
187
- dest="profile",
188
- action="store_const",
189
- const="WiFi",
190
- help="Shortcut for --profile WiFi",
191
- )
192
- preset_group.add_argument(
193
- "--5G",
194
- dest="profile",
195
- action="store_const",
196
- const="5G",
197
- help="Shortcut for --profile 5G",
198
- )
199
- parser.add_argument(
200
- "--report-metrics",
201
- action="store_true",
202
- help="Print clustering metrics (silhouette, 5-fold kNN accuracy)",
203
- )
204
- parser.add_argument(
205
- "--metrics-only",
206
- action="store_true",
207
- help="Exit after reporting metrics without running t-SNE or saving figures",
208
- )
209
- parser.add_argument(
210
- "--sampling-mode",
211
- choices=("first", "reservoir"),
212
- default="first",
213
- help="How to down-sample each class (default: first)",
214
- )
215
- parser.add_argument(
216
- "--complex-mode",
217
- choices=("auto", "magnitude", "interleaved"),
218
- default="auto",
219
- help=(
220
- "How to handle complex spectrograms: 'magnitude' (abs), 'interleaved' (real/imag interleaved along width), "
221
- "or 'auto' (prefer interleaved when complex). Real-valued inputs are unaffected."
222
- ),
223
- )
224
- parser.add_argument(
225
- "--label-field",
226
- choices=("snr", "modulation", "mobility"),
227
- default="snr",
228
- help="Choose which label to visualise and score (default: %(default)s)",
229
- )
230
- parser.add_argument(
231
- "--normalization",
232
- choices=("per-sample", "dataset"),
233
- default="per-sample",
234
- help="Normalisation strategy applied before embedding extraction",
235
- )
236
- return parser.parse_args()
237
-
238
-
239
- def find_latest_checkpoint(models_root: Path) -> Path:
240
- """Return a checkpoint path under ``models_root``.
241
-
242
- Works with either a parent directory that contains multiple run folders,
243
- or directly with a single run directory containing ``*.pth`` files.
244
- Chooses the checkpoint with the lowest parsed validation value when
245
- available, else falls back to most-recent modification time.
246
- """
247
-
248
- if not models_root.exists():
249
- raise FileNotFoundError(f"Models root not found: {models_root}")
250
-
251
- if models_root.is_file():
252
- raise FileNotFoundError(f"Expected a directory, got file: {models_root}")
253
-
254
- # If the provided directory itself contains checkpoints, use it directly.
255
- checkpoints = list(models_root.glob("*.pth"))
256
- if not checkpoints:
257
- # Otherwise, look for subdirectories that contain checkpoints and ignore others (e.g., tensorboard)
258
- run_dirs = [p for p in models_root.iterdir() if p.is_dir()]
259
- candidate_runs = [d for d in run_dirs if any(d.glob("*.pth"))]
260
- if not candidate_runs:
261
- raise FileNotFoundError(
262
- f"No checkpoints found under {models_root} (no .pth files in this dir or its run subdirs)"
263
- )
264
- latest_run = max(candidate_runs, key=lambda p: p.stat().st_mtime)
265
- checkpoints = list(latest_run.glob("*.pth"))
266
-
267
- def parse_val_metric(path: Path) -> float | None:
268
- match = re.search(r"_val([0-9]+(?:\.[0-9]+)?)", path.name)
269
- if match:
270
- try:
271
- return float(match.group(1))
272
- except ValueError:
273
- return None
274
- return None
275
-
276
- parsed = [(parse_val_metric(p), p) for p in checkpoints]
277
- valid = [item for item in parsed if item[0] is not None]
278
- if valid:
279
- valid.sort(key=lambda item: item[0])
280
- return valid[0][1]
281
-
282
- # Fallback to most recent modification time
283
- return max(checkpoints, key=lambda p: p.stat().st_mtime)
284
-
285
-
286
- def parse_snr_list(snr_argument: str | None) -> set[str] | None:
287
- if snr_argument is None or snr_argument.lower() == "all":
288
- return None
289
- values = [item.strip() for item in snr_argument.split(",") if item.strip()]
290
- return set(values)
291
-
292
-
293
- def list_snr_samples(
294
- data_root: Path,
295
- modulation: str,
296
- allowed_snrs: set[str] | None,
297
- mobility_filter: set[str] | None,
298
- fft_folder: str,
299
- max_per_class: int,
300
- rng: random.Random,
301
- mode: str,
302
- complex_mode: str,
303
- ) -> Dict[str, List[Tuple[np.ndarray, str, str]]]:
304
- """Collect spectrogram samples grouped by SNR label.
305
-
306
- Supports both legacy PKL layout with a trailing 'spectrograms/' folder and
307
- MATLAB .mat bundles saved directly under the mobility folder.
308
-
309
- Returns: mapping from SNR label to list of tuples: (spec, modulation, mobility)
310
- """
311
-
312
- class_samples: Dict[str, List[Tuple[np.ndarray, str, str]]] = defaultdict(list)
313
- seen_counts: Dict[str, int] = defaultdict(int)
314
-
315
- # Search patterns:
316
- # - PKL under .../spectrograms/*.pkl
317
- # - MAT under .../spectrogram_*.mat
318
- patterns = [
319
- str(data_root / "**" / "spectrograms" / "*.pkl"),
320
- str(data_root / "**" / "spectrogram_*.mat"),
321
- ]
322
-
323
- mobility_set = {"static", "pedestrian", "vehicular"}
324
-
325
- def extract_tokens(rel_parts: Tuple[str, ...]) -> Tuple[str, str, str, str] | None:
326
- # Heuristic extraction to support both layouts
327
- # modulation: first path segment below data_root
328
- if not rel_parts:
329
- return None
330
- modulation_folder = rel_parts[0]
331
-
332
- # snr: first segment like SNR(-?)NdB
333
- snr_folder = next((p for p in rel_parts if re.match(r"^SNR-?\d+dB$", p)), None)
334
- if snr_folder is None:
335
- return None
336
-
337
- # mobility: one of known labels
338
- mobility_folder = next((p for p in rel_parts if p.lower() in mobility_set), None)
339
- if mobility_folder is None:
340
- return None
341
-
342
- # fft/window folder if present (PKL layout), else fallback for MAT
343
- fft_folder_name = next((p for p in rel_parts if p.startswith("win") or p.startswith("fft")), "fft_unknown")
344
-
345
- return modulation_folder, snr_folder, mobility_folder, fft_folder_name
346
-
347
- for pattern in patterns:
348
- for path_str in glob.iglob(pattern, recursive=True):
349
- path = Path(path_str)
350
- try:
351
- rel_parts = path.relative_to(data_root).parts
352
- except ValueError:
353
- continue
354
-
355
- tokens = extract_tokens(rel_parts)
356
- if tokens is None:
357
- continue
358
- modulation_folder, snr_folder, mobility_folder, fft_folder_name = tokens
359
-
360
- # Apply filters
361
- if modulation.lower() != "all" and modulation_folder != modulation:
362
- continue
363
- if allowed_snrs is not None and snr_folder not in allowed_snrs:
364
- continue
365
- if mobility_filter is not None and mobility_folder.lower() not in mobility_filter:
366
- continue
367
- if fft_folder != "all" and fft_folder_name != fft_folder:
368
- continue
369
-
370
- class_label = snr_folder
371
- if mode == "first" and len(class_samples[class_label]) >= max_per_class:
372
- continue
373
-
374
- # Load spectrogram data (supports .pkl and .mat)
375
- try:
376
- arr = load_spectrogram_data(str(path))
377
- except Exception as exc: # pragma: no cover - I/O heavy
378
- print(f"[WARN] Failed to load {path}: {exc}")
379
- continue
380
-
381
- if not isinstance(arr, np.ndarray) or arr.size == 0:
382
- continue
383
-
384
- # If loaded spectrograms are complex, convert according to mode
385
- if np.iscomplexobj(arr):
386
- if complex_mode == "magnitude":
387
- arr = np.abs(arr)
388
- else:
389
- # Interleave real/imag parts along the width dimension
390
- if arr.ndim == 4 and arr.shape[1] == 1:
391
- arr = arr[:, 0]
392
- if arr.ndim == 3:
393
- real = arr.real.astype(np.float32, copy=False)
394
- imag = arr.imag.astype(np.float32, copy=False)
395
- n, h, w = real.shape
396
- inter = np.empty((n, h, w * 2), dtype=np.float32)
397
- inter[:, :, 0::2] = real
398
- inter[:, :, 1::2] = imag
399
- arr = inter
400
- else:
401
- # Fallback to magnitude for unsupported shapes
402
- arr = np.abs(arr)
403
-
404
- # Normalize shapes:
405
- # - (N, H, W)
406
- # - (N, C, H, W) -> collapse channels via mean
407
- if arr.ndim == 4:
408
- # (N, C, H, W) -> (N, H, W)
409
- if arr.shape[1] > 1:
410
- specs = arr.mean(axis=1)
411
- else:
412
- specs = arr[:, 0]
413
- elif arr.ndim == 3:
414
- specs = arr
415
- elif arr.ndim == 2:
416
- specs = arr[None, ...]
417
- else:
418
- print(f"[WARN] Unexpected spectrogram shape in {path}: {arr.shape}")
419
- continue
420
-
421
- for spec in specs:
422
- sample = np.asarray(spec, dtype=np.float32)
423
- bucket = class_samples[class_label]
424
-
425
- if len(bucket) < max_per_class:
426
- bucket.append((sample, modulation_folder, mobility_folder))
427
- seen_counts[class_label] += 1
428
- elif mode == "reservoir":
429
- seen_counts[class_label] += 1
430
- j = rng.randint(0, seen_counts[class_label] - 1)
431
- if j < max_per_class:
432
- bucket[j] = (sample, modulation_folder, mobility_folder)
433
- else: # mode == "first" and already full
434
- break
435
-
436
- return class_samples
437
-
438
-
439
- def sample_balanced_dataset(
440
- class_samples: Dict[str, List[Tuple[np.ndarray, str, str]]],
441
- ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[str]]:
442
- """Stack the sampled spectrograms alongside SNR, modulation, and mobility labels."""
443
-
444
- features: List[np.ndarray] = []
445
- snr_labels: List[str] = []
446
- modulation_labels: List[str] = []
447
- mobility_labels: List[str] = []
448
- class_names = sorted(class_samples.keys())
449
-
450
- for class_name in class_names:
451
- samples = class_samples[class_name]
452
- if not samples:
453
- continue
454
- for sample, modulation_label, mobility_label in samples:
455
- features.append(sample)
456
- snr_labels.append(class_name)
457
- modulation_labels.append(modulation_label)
458
- mobility_labels.append(mobility_label)
459
-
460
- if not features:
461
- raise RuntimeError("No spectrogram samples collected for the specified filters")
462
-
463
- stacked = np.stack(features) # [N, 128, 128]
464
- return (
465
- stacked,
466
- np.array(snr_labels),
467
- np.array(modulation_labels),
468
- np.array(mobility_labels),
469
- class_names,
470
- )
471
-
472
-
473
- def unfold_patches_square(x: torch.Tensor, patch_size: int = 4) -> torch.Tensor:
474
- # Input shape: [B, H, W]; extracts (patch_size x patch_size) patches
475
- patches_h = x.unfold(1, patch_size, patch_size)
476
- patches = patches_h.unfold(2, patch_size, patch_size)
477
- return patches.contiguous().view(x.shape[0], -1, patch_size * patch_size)
478
-
479
-
480
- def unfold_patches_rect(x: torch.Tensor, patch_rows: int = 4, patch_cols: int = 8) -> torch.Tensor:
481
- # Input shape: [B, H, W]; extracts (patch_rows x patch_cols) patches (for interleaved complex)
482
- patches_h = x.unfold(1, patch_rows, patch_rows)
483
- patches = patches_h.unfold(2, patch_cols, patch_cols)
484
- return patches.contiguous().view(x.shape[0], -1, patch_rows * patch_cols)
485
-
486
-
487
- def extract_tokens(spec: np.ndarray, device: torch.device, interleaved: bool) -> torch.Tensor:
488
- tensor = torch.from_numpy(spec).unsqueeze(0).to(device)
489
- if interleaved:
490
- # Rectangular patches 4x8 to cover 4x4 complex bins (real+imag)
491
- return unfold_patches_rect(tensor, 4, 8) # [1, 1024, 32]
492
- else:
493
- return unfold_patches_square(tensor, 4) # [1, 1024, 16]
494
-
495
-
496
- def pool_embeddings(
497
- tokens: torch.Tensor,
498
- model: torch.nn.Module,
499
- pooling: str,
500
- ) -> np.ndarray:
501
- # Append CLS token (value 0.2) before passing through the transformer.
502
- cls_token = torch.full((tokens.size(0), 1, tokens.size(-1)), 0.2, device=tokens.device)
503
- inputs = torch.cat([cls_token, tokens], dim=1) # [B, 1025, 16]
504
-
505
- with torch.no_grad():
506
- outputs = model(inputs) # [B, 1025, 128]
507
-
508
- if pooling == "cls":
509
- pooled = outputs[:, 0]
510
- else: # mean pooling across patch tokens (exclude CLS)
511
- pooled = outputs[:, 1:].mean(dim=1)
512
-
513
- return pooled.detach().cpu().numpy()
514
-
515
-
516
def sort_snr_labels(labels: List[str]) -> List[str]:
    """Sort SNR labels numerically (e.g. 'SNR-5dB' < 'SNR0dB' < 'SNR10dB').

    Labels that do not match the ``SNR<value>dB`` pattern sort last
    (they receive an infinite sort key).
    """
    import re

    pattern = re.compile(r'SNR(-?\d+)dB')

    def snr_key(label: str) -> float:
        found = pattern.search(label)
        # Non-SNR labels get +inf so they land at the end of the ordering.
        return float(found.group(1)) if found else float('inf')

    return sorted(labels, key=snr_key)
528
-
529
-
530
def run_tsne(x: np.ndarray, labels: np.ndarray, title: str, ax: plt.Axes) -> None:
    """Project features to 2-D with t-SNE and scatter-plot them by class."""
    features = StandardScaler().fit_transform(x)
    # Guard against NaN/Inf from upstream (normalisation or model outputs).
    features = np.nan_to_num(features, copy=False, nan=0.0, posinf=0.0, neginf=0.0)

    # Keep perplexity valid for sklearn (must be < n_samples) while scaling
    # it with the sample count.
    n_samples = len(features)
    max_perplexity = max(5, min(30, n_samples // 10))
    perplexity = max(min(max_perplexity, n_samples - 1), 5)

    projected = TSNE(n_components=2, perplexity=perplexity, random_state=42).fit_transform(features)

    class_names = sort_snr_labels(list(np.unique(labels)))
    palette = plt.cm.Set3(np.linspace(0, 1, len(class_names)))
    for color, class_name in zip(palette, class_names):
        selected = labels == class_name
        ax.scatter(projected[selected, 0], projected[selected, 1], c=[color], s=18, alpha=0.7, label=class_name)

    # ax.set_title(title, fontsize=14, fontweight="bold") # Title removed for paper
    ax.set_xlabel("t-SNE Component 1", fontsize=16)
    ax.set_ylabel("t-SNE Component 2", fontsize=16)
    ax.tick_params(labelsize=14)
    ax.grid(True, alpha=0.3)
    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=12)
555
-
556
-
557
def compute_metrics(name: str, features: np.ndarray, labels: np.ndarray) -> None:
    """Print silhouette score and 5-fold 5-NN accuracy for a feature set."""
    if len(np.unique(labels)) < 2:
        # Both metrics are undefined for a single class.
        print(f"[METRIC] {name}: skipped (only one class present)")
        return

    scaled = StandardScaler().fit_transform(features)
    silhouette = silhouette_score(scaled, labels)

    fold_scores: List[float] = []
    splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    for train_idx, test_idx in splitter.split(scaled, labels):
        knn = KNeighborsClassifier(n_neighbors=5)
        knn.fit(scaled[train_idx], labels[train_idx])
        fold_scores.append(knn.score(scaled[test_idx], labels[test_idx]))

    mean_acc = float(np.mean(fold_scores))
    std_acc = float(np.std(fold_scores))
    print(
        f"[METRIC] {name}: silhouette={silhouette:.3f}, "
        f"5-NN accuracy={mean_acc:.3f} ± {std_acc:.3f}"
    )
580
-
581
-
582
- # ---------------------------------------------------------------------------
583
- # Main execution
584
- # ---------------------------------------------------------------------------
585
-
586
-
587
def main() -> None:
    """CLI entry point: build raw vs. LWM-embedding t-SNE comparison plots.

    Loads balanced spectrogram samples, embeds them with a pretrained LWM
    transformer, then renders two side-by-side t-SNE scatter plots (raw
    spectrograms vs. pooled embeddings) and optionally prints clustering
    metrics.
    """
    args = parse_args()

    # Profile presets override the default data/model roots only when the
    # user did not set them explicitly on the command line.
    if args.profile:
        preset = PROFILE_PRESETS.get(args.profile)
        if not preset:
            raise ValueError(f"Unknown profile requested: {args.profile}")
        if args.data_root == DEFAULT_DATA_ROOT:
            args.data_root = preset["data_root"]
        if args.models_root == DEFAULT_MODELS_ROOT:
            args.models_root = preset["models_root"]

    if args.profile:
        print(f"[INFO] Profile preset active: {args.profile}")

    # Seed every RNG involved (sampling, numpy ops, torch init).
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    data_root = Path(args.data_root)
    if not data_root.exists():
        raise FileNotFoundError(f"Data root not found: {data_root}")

    allowed_snrs = parse_snr_list(args.snrs)

    # Mobility filter: comma-separated values; "all" (alone) disables it.
    mobility_filter: set[str] | None = None
    if args.mobility:
        mobility_values: List[str] = []
        for value in args.mobility:
            mobility_values.extend([item.strip() for item in value.split(",") if item.strip()])
        mobility_values = [value for value in mobility_values if value]
        if mobility_values and not (len(mobility_values) == 1 and mobility_values[0].lower() == "all"):
            mobility_filter = {value.lower() for value in mobility_values}
            print(
                "[INFO] Mobility filter active: "
                + ", ".join(sorted(mobility_filter))
            )

    class_samples = list_snr_samples(
        data_root,
        args.modulation,
        allowed_snrs,
        mobility_filter,
        args.fft_folder,
        args.samples_per_snr,
        random,
        args.sampling_mode,
        args.complex_mode,
    )
    samples, snr_labels, modulation_labels, mobility_labels, _ = sample_balanced_dataset(class_samples)

    # Choose which metadata field colors the plots.
    if args.label_field == "snr":
        labels = snr_labels
        label_name = "SNR"
        label_display = "SNR"
    elif args.label_field == "modulation":
        labels = modulation_labels
        label_name = "modulation"
        label_display = "Modulation"
    else:  # mobility
        labels = mobility_labels
        label_name = "mobility"
        label_display = "Mobility"

    unique_labels = np.unique(labels)
    print(
        f"[INFO] Loaded {samples.shape[0]} spectrograms across {len(unique_labels)} {label_name} buckets"
    )
    class_counts = Counter(labels)
    print(f"[INFO] Samples per {label_name}:")
    for name, count in sorted(class_counts.items()):
        print(f" {name}: {count}")

    # Report auxiliary distributions when not labelling by SNR.
    if args.label_field != "snr":
        snr_counts = Counter(snr_labels)
        print("[INFO] SNR distribution (sampling classes):")
        for name, count in sorted(snr_counts.items()):
            print(f" {name}: {count}")
        if args.label_field == "mobility":
            modulation_counts = Counter(modulation_labels)
            print("[INFO] Modulation distribution:")
            for name, count in sorted(modulation_counts.items()):
                print(f" {name}: {count}")

    normalization_mode = args.normalization
    if normalization_mode == "per-sample":
        normalized_samples = normalize_per_sample(samples)
    else:
        normalized_samples = normalize_dataset(samples)
    print(f"[INFO] Normalisation mode: {normalization_mode}")

    # Flatten spectrograms (after optional normalization) for the raw t-SNE view.
    raw_vectors = normalized_samples.reshape(normalized_samples.shape[0], -1)

    # Prepare LWM model and embeddings for the right subplot.
    if args.checkpoint:
        checkpoint_path = Path(args.checkpoint)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    else:
        checkpoint_path = find_latest_checkpoint(Path(args.models_root))
    print(f"[INFO] Using checkpoint: {checkpoint_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")
    print(f"[INFO] Pooling strategy: {args.pooling}")
    # Determine complex handling strategy for model/patching.
    use_interleaved = False
    if args.complex_mode == "interleaved":
        use_interleaved = True
    elif args.complex_mode == "auto":
        # Heuristic: if any sample contains width > 128, assume interleaved (e.g., 128x256).
        sample_shape = tuple(normalized_samples.shape[1:])
        if len(sample_shape) == 2 and sample_shape[1] > 128:
            use_interleaved = True

    # Token length: 32 for interleaved real/imag pairs, 16 for magnitude.
    element_length = 32 if use_interleaved else 16

    model = lwm_model(element_length=element_length, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Strip DataParallel's "module." prefix if present.
    if any(k.startswith("module.") for k in state_dict):
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    try:
        model.load_state_dict(state_dict, strict=False)
    except RuntimeError as e:
        msg = str(e)
        # Fallback: checkpoint expects element_length=16 (magnitude), but we constructed 32 (interleaved).
        mismatch16 = "[128, 16]" in msg or "[16]" in msg
        mismatch32 = "[128, 32]" in msg or "[32]" in msg
        if mismatch16 and not mismatch32:
            print("[WARN] Checkpoint expects token dimension 16. Falling back to magnitude embedding.")
            use_interleaved = False
            element_length = 16
            # Recreate model and reload with the magnitude token size.
            model = lwm_model(element_length=element_length, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
            model.load_state_dict(state_dict, strict=False)
        else:
            raise
    model = model.to(device).eval()

    def collapse_interleaved_to_magnitude(spec: np.ndarray) -> np.ndarray:
        # spec: [H, 2W] with interleaved real/imag along width -> [H, W] magnitude.
        h, w2 = spec.shape
        if w2 % 2 != 0:
            return spec  # cannot collapse; return as-is
        real = spec[:, 0::2]
        imag = spec[:, 1::2]
        return np.sqrt(np.maximum(real * real + imag * imag, 0.0, dtype=np.float32))

    # If we fell back to magnitude (use_interleaved False) but inputs are
    # interleaved, collapse for embeddings only (raw view stays untouched).
    embed_inputs = normalized_samples
    if not use_interleaved and normalized_samples.shape[2] > 128:
        collapsed = []
        for spec in normalized_samples:
            collapsed.append(collapse_interleaved_to_magnitude(spec))
        embed_inputs = np.stack(collapsed).astype(np.float32, copy=False)

    # Embed one spectrogram at a time (keeps memory bounded).
    embeddings: List[np.ndarray] = []
    for spec in embed_inputs:
        tokens = extract_tokens(spec, device, interleaved=use_interleaved)
        embedding = pool_embeddings(tokens, model, args.pooling)
        embeddings.append(embedding.squeeze(0))

    embeddings_np = np.vstack(embeddings)
    print(f"[INFO] Generated embeddings with shape {embeddings_np.shape}")

    if args.report_metrics:
        compute_metrics("Raw spectrogram", raw_vectors, labels)
        pool_label = "LWM mean" if args.pooling == "mean" else "LWM CLS"
        compute_metrics(pool_label, embeddings_np, labels)
        if args.metrics_only:
            return

    # Plot results (two subplots matching the original figure format).
    fig, axes = plt.subplots(1, 2, figsize=(18, 7))
    raw_title = f"Raw Spectrogram t-SNE (by {label_display})"
    pooling_label = "Mean Pool" if args.pooling == "mean" else "CLS Token"
    embedding_title = f"LWM Embedding t-SNE ({pooling_label}, by {label_display})"
    run_tsne(raw_vectors, labels, raw_title, axes[0])
    run_tsne(embeddings_np, labels, embedding_title, axes[1])

    fig.tight_layout()
    save_path = Path(args.save_path)

    # Tag the output filename with the profile (or data-root folder name).
    communication_tag: str | None = None
    if args.profile:
        communication_tag = args.profile
    else:
        root_name = Path(args.data_root).name
        if root_name:
            communication_tag = root_name

    def ensure_suffix(stem: str, suffix: str) -> str:
        # Append suffix to the filename stem only once.
        return stem if stem.endswith(suffix) else f"{stem}_{suffix}"

    updated_stem = save_path.stem
    if communication_tag:
        updated_stem = ensure_suffix(updated_stem, communication_tag)
    if args.label_field != "snr":
        label_suffix = f"by_{args.label_field}"
        updated_stem = ensure_suffix(updated_stem, label_suffix)

    if updated_stem != save_path.stem:
        save_path = save_path.with_name(f"{updated_stem}{save_path.suffix}")
    save_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=600, bbox_inches="tight")
    print(f"[INFO] Figure saved to {save_path}")

    # Also save PDF version for paper (vector format, no resolution limit).
    pdf_path = save_path.with_suffix('.pdf')
    plt.savefig(pdf_path, format='pdf', bbox_inches="tight")
    print(f"[INFO] PDF version saved to {pdf_path}")
801
- if __name__ == "__main__":
802
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
task1/train_mcs_models.py DELETED
The diff for this file is too large to render. See raw diff
 
task2/mobility_utils.py DELETED
@@ -1,414 +0,0 @@
1
- #!/usr/bin/env python3
2
- """Shared mobility-classification utilities used across Task 2 helpers.
3
-
4
- This module provides the lightweight LWM classifier head plus supporting
5
- sampling and normalization helpers that were previously bundled inside the
6
- stand-alone mobility fine-tuning scripts. They remain available so that
7
- benchmarking, router training, and visualisation pipelines can reuse the same
8
- logic without depending on a separate CLI.
9
- """
10
-
11
- from __future__ import annotations
12
-
13
- import glob
14
- import json
15
- from collections import defaultdict
16
- from pathlib import Path
17
- from typing import Any, Dict, Iterable, List, Sequence, Tuple
18
-
19
- import numpy as np
20
- import torch
21
- import torch.nn as nn
22
- import torch.nn.functional as F
23
-
24
- from pretraining.pretrained_model import lwm as lwm_model
25
- from task1.train_mcs_models import (
26
- _extract_metadata,
27
- identify_modulation,
28
- load_all_samples,
29
- )
30
-
31
# Canonical mobility class ordering; list position defines the integer label.
MOBILITY_LABELS = ["static", "pedestrian", "vehicular"]
# Two-class variant used by the binary mobility setting.
BINARY_MOBILITY_LABELS = ["vehicular", "pedestrian"]
33
-
34
-
35
def load_dataset_stats(models_root: Path) -> Dict[str, float | str]:
    """Load dataset statistics (mean/std/normalization mode) from a models directory.

    Falls back to a neutral per-sample configuration (mean=0, std=1) when
    ``dataset_stats.json`` is absent; a zero std is replaced by 1 so
    downstream normalization never divides by zero.
    """
    stats_path = models_root / "dataset_stats.json"
    if not stats_path.exists():
        print(
            f"[WARN] dataset_stats.json not found under {models_root}; "
            "falling back to per-sample normalization with mean=0/std=1.",
            flush=True,
        )
        return {"mean": 0.0, "std": 1.0, "normalization": "per_sample"}

    with open(stats_path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    mean_value = float(raw.get("mean", 0.0))
    std_value = float(raw.get("std", 1.0))
    if std_value == 0.0:
        std_value = 1.0
    # Older stats files used the key "mode" instead of "normalization".
    mode = str(raw.get("normalization", raw.get("mode", "dataset")))
    return {"mean": mean_value, "std": std_value, "normalization": mode}
57
-
58
-
59
def gather_controlled_groups(
    data_root: Path,
    cities: Sequence[str],
    comm: str,
    mobilities: Sequence[str],
    snrs: Sequence[str] | None,
    fft_whitelist: Sequence[str] | None,
) -> Dict[Tuple[str, str, str, str, str], Dict[str, List[str]]]:
    """Group spectrogram paths by (city, modulation, rate, SNR, FFT) while balancing mobilities."""
    grouped: Dict[Tuple[str, str, str, str, str], Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list))
    wanted_mobilities = set(mobilities)
    wanted_snrs = set(snrs) if snrs else None
    wanted_ffts = set(fft_whitelist) if fft_whitelist else None

    for city in cities:
        city_base = data_root / city / comm
        if not city_base.exists():
            continue
        search_pattern = str(city_base / "**" / "spectrograms" / "*.pkl")
        for path_str in glob.iglob(search_pattern, recursive=True):
            path = Path(path_str)
            rate, snr, mobility = _extract_metadata(path.parts)
            # Apply the mobility / SNR / FFT whitelists before any heavier work.
            if mobility not in wanted_mobilities:
                continue
            if wanted_snrs is not None and snr not in wanted_snrs:
                continue
            # FFT folder is the first path component starting with "win".
            fft = next((part for part in path.parts if part.startswith("win")), "fft_unknown")
            if wanted_ffts is not None and fft not in wanted_ffts:
                continue
            _, modulation = identify_modulation(path_str)
            if modulation is None:
                continue
            grouped[(city, modulation, rate, snr, fft)][mobility].append(str(path))

    # Convert the nested defaultdicts into plain dicts for callers.
    return {key: dict(per_mobility) for key, per_mobility in grouped.items()}
94
-
95
-
96
def _collect_balanced_arrays(
    groups: Dict[Tuple[str, str, str, str, str], Dict[str, List[str]]],
    mobilities: Sequence[str],
    max_per_config: int,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
    """Load spectrogram arrays with per-configuration balance across mobilities.

    Args:
        groups: Mapping from (city, modulation, rate, SNR, FFT) keys to a
            per-mobility list of spectrogram paths (see
            ``gather_controlled_groups``).
        mobilities: Mobility classes to balance; list order fixes the
            integer label assigned to each class.
        max_per_config: Cap on samples per mobility per configuration
            (values <= 0 mean "no cap").
        rng: Generator used to subsample when a mobility has more samples
            than the balanced limit.

    Returns:
        ``(features, labels, info)`` where ``features`` is a float32 array
        of stacked spectrograms, ``labels`` the matching int64 class
        indices, and ``info`` reports per-mobility totals, the number of
        fully matched configurations, and up to five preview keys.
    """
    features: List[np.ndarray] = []
    labels: List[np.ndarray] = []
    mobility_to_idx = {mob: idx for idx, mob in enumerate(mobilities)}
    per_mobility_totals = {mob: 0 for mob in mobilities}
    matched_configs = 0
    preview_configs: List[Tuple[str, str, str, str, str]] = []

    for key, mobility_map in groups.items():
        # Only use configurations where every requested mobility is present.
        if not all(mob in mobility_map for mob in mobilities):
            continue

        cached_arrays: Dict[str, np.ndarray] = {}
        per_mobility_counts: List[int] = []
        for mobility in mobilities:
            paths = mobility_map[mobility]
            collected: List[np.ndarray] = []
            for path in paths:
                arr = load_all_samples(path)
                if arr.size == 0:
                    continue
                collected.append(arr)
            if not collected:
                # One mobility yielded nothing: drop the whole configuration.
                cached_arrays = {}
                break
            stacked = np.concatenate(collected, axis=0)
            cached_arrays[mobility] = stacked
            per_mobility_counts.append(stacked.shape[0])

        if len(cached_arrays) != len(mobilities):
            continue

        # Balance: every mobility contributes the same number of samples.
        limit = min(per_mobility_counts)
        if max_per_config > 0:
            limit = min(limit, max_per_config)
        if limit == 0:
            continue

        for mobility in mobilities:
            arr = cached_arrays[mobility]
            if arr.shape[0] > limit:
                indices = rng.permutation(arr.shape[0])[:limit]
                arr = arr[indices]
            features.append(arr)
            # BUG FIX: this previously indexed with the undefined name
            # ``mob`` (only bound inside earlier comprehensions, whose
            # scope does not leak), raising NameError at runtime; the loop
            # variable here is ``mobility``.
            labels.append(np.full(arr.shape[0], mobility_to_idx[mobility], dtype=np.int64))
            per_mobility_totals[mobility] += arr.shape[0]

        if matched_configs < 5:
            preview_configs.append(key)
        matched_configs += 1

    if not features:
        # Empty result keeps the documented dtypes/shapes for callers.
        return (
            np.empty((0, 128, 128), dtype=np.float32),
            np.empty((0,), dtype=np.int64),
            {"per_mobility": per_mobility_totals, "matched_configs": matched_configs, "preview_configs": preview_configs},
        )

    stacked_features = np.concatenate(features, axis=0).astype(np.float32, copy=False)
    stacked_labels = np.concatenate(labels, axis=0).astype(np.int64, copy=False)
    return stacked_features, stacked_labels, {
        "per_mobility": per_mobility_totals,
        "matched_configs": matched_configs,
        "preview_configs": preview_configs,
    }
167
-
168
-
169
class ResidualBlock1D(nn.Module):
    """1D residual block: two conv+BN stages plus an identity/projection skip."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # Project the skip path only when the channel count changes;
        # otherwise the skip is a plain identity (empty Sequential).
        if in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Residual add followed by the final activation.
        return F.relu(out + skip)
191
-
192
-
193
class Res1DCNNHead(nn.Module):
    """Compact ResNet-style 1D head for classifying 128-d embeddings."""

    def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.5) -> None:
        super().__init__()
        channels = 64
        self.conv1 = nn.Conv1d(1, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(channels)
        self.res_block = ResidualBlock1D(channels, channels)
        self.fc = nn.Linear(channels, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Treat the embedding as a 1-channel sequence: [B, D] -> [B, 1, D].
        out = F.relu(self.bn1(self.conv1(x.unsqueeze(1))))
        out = self.res_block(out)
        # Global average pool over the sequence axis: [B, C, D] -> [B, C].
        out = F.adaptive_avg_pool1d(out, 1).squeeze(-1)
        return self.fc(self.dropout(out))
212
-
213
-
214
class LWMClassifierMinimal(nn.Module):
    """LWM backbone wrapper with configurable classifier and optional projection head.

    Tokenizes a [B, H, W] spectrogram into 4x4 patches (plus a constant
    CLS token), runs them through a frozen (or partially unfrozen) LWM
    backbone, pools the outputs, and classifies with a linear / MLP /
    Res1DCNN head. Optionally appends per-sample mean/std statistics to
    the classifier input and exposes a contrastive projection head.
    """

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        classifier_dim: int,
        dropout: float,
        trainable_layers: int,
        projection_dim: int,
        append_input_stats: bool,
        normalization_stats: Dict[str, object] | None,
        head_type: str = "mlp",
    ) -> None:
        super().__init__()
        self.backbone = backbone
        # 4x4 non-overlapping patches -> 16-d tokens.
        self.patch_size = 4
        self.unfold = nn.Unfold(kernel_size=self.patch_size, stride=self.patch_size)
        self.head_type = head_type

        self.append_input_stats = bool(append_input_stats)
        stats_info = normalization_stats or {}
        self.normalization_mode = str(stats_info.get("normalization", "dataset")).lower()
        self.dataset_mean = float(stats_info.get("mean", 0.0))
        self.dataset_std = float(stats_info.get("std", 1.0))
        # Guard against division/multiplication by ~0 in _collect_input_stats.
        if abs(self.dataset_std) < 1e-6:
            self.dataset_std = 1e-6
        base_dim = 128  # LWM d_model
        stats_dim = 2 if self.append_input_stats else 0
        input_dim = base_dim + stats_dim

        # Clamp head hyper-parameters to sane values.
        classifier_dim = max(32, int(classifier_dim))
        dropout = max(0.0, float(dropout))

        if head_type == "linear":
            self.classifier = nn.Sequential(
                nn.LayerNorm(input_dim),
                nn.Linear(input_dim, num_classes),
            )
        elif head_type == "res1dcnn":
            self.classifier = nn.Sequential(
                nn.LayerNorm(input_dim),
                Res1DCNNHead(input_dim, num_classes, dropout=dropout),
            )
        else:
            # Default "mlp" head: LayerNorm -> Linear -> GELU [-> Dropout] -> Linear.
            head_layers: List[nn.Module] = [
                nn.LayerNorm(input_dim),
                nn.Linear(input_dim, classifier_dim),
                nn.GELU(),
            ]
            if dropout > 0:
                head_layers.append(nn.Dropout(dropout))
            head_layers.append(nn.Linear(classifier_dim, num_classes))
            self.classifier = nn.Sequential(*head_layers)

        # Optional projection head for contrastive training (fed from the
        # 128-d pooled features, not the stats-augmented classifier input).
        proj_dim = int(projection_dim)
        if proj_dim > 0:
            self.projection_head = nn.Sequential(
                nn.Linear(128, proj_dim),
                nn.ReLU(inplace=True),
                nn.Linear(proj_dim, proj_dim),
            )
        else:
            self.projection_head = None

        # Freeze the whole backbone by default...
        for param in self.backbone.parameters():
            param.requires_grad = False

        # ...then selectively unfreeze the last `trainable_layers` blocks
        # (assumes the backbone exposes a `layers` attribute — TODO confirm
        # against the lwm_model definition).
        if trainable_layers > 0:
            layers = getattr(self.backbone, "layers", None)
            if layers is not None:
                trainable_layers = min(trainable_layers, len(layers))
                for layer in layers[-trainable_layers:]:
                    for param in layer.parameters():
                        param.requires_grad = True

    def spectrogram_to_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """Unfold [B, H, W] into patch tokens and prepend the constant CLS token."""
        x = x.unsqueeze(1)
        patches = self.unfold(x).transpose(1, 2)
        # CLS token filled with 0.2 — must match the pretraining convention.
        cls_token = torch.full(
            (patches.size(0), 1, patches.size(-1)),
            0.2,
            dtype=patches.dtype,
            device=patches.device,
        )
        return torch.cat([cls_token, patches], dim=1)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return one pooled 128-d feature vector per input spectrogram."""
        tokens = self.spectrogram_to_tokens(x)
        outputs = self.backbone(tokens)
        # Degenerate case (CLS only): return the CLS output directly;
        # otherwise mean-pool over the patch positions, excluding CLS.
        if outputs.size(1) <= 1:
            return outputs[:, 0, :]
        return outputs[:, 1:, :].mean(dim=1)

    def _collect_input_stats(self, x: torch.Tensor) -> torch.Tensor:
        """Per-sample mean/std of the input, de-normalized when dataset stats apply."""
        mean = x.mean(dim=(1, 2))
        std = x.std(dim=(1, 2), unbiased=False)
        if self.normalization_mode == "dataset":
            # Undo dataset-level normalization so the stats describe raw data.
            mean = mean * self.dataset_std + self.dataset_mean
            std = std * self.dataset_std
        return torch.stack([mean, std], dim=1)

    def forward(
        self,
        x: torch.Tensor,
        *,
        input_stats: torch.Tensor | None = None,
        return_projection: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
        """Classify spectrograms; optionally also return the projection output.

        Args:
            x: Spectrogram batch [B, H, W].
            input_stats: Optional precomputed [B, 2] mean/std pairs; computed
                on the fly when omitted and stats are enabled.
            return_projection: When True, return ``(logits, projection)``
                where projection is None if no projection head exists.
        """
        features = self.forward_features(x)
        classifier_input = features
        if self.append_input_stats:
            stats = input_stats if input_stats is not None else self._collect_input_stats(x)
            if stats.dtype != classifier_input.dtype:
                stats = stats.to(classifier_input.dtype)
            stats = stats.to(classifier_input.device)
            classifier_input = torch.cat([classifier_input, stats], dim=1)
        logits = self.classifier(classifier_input)
        if return_projection:
            projection = self.projection_head(features) if self.projection_head is not None else None
            return logits, projection
        return logits
337
-
338
-
339
def prepare_model(
    checkpoint: Path,
    num_classes: int,
    classifier_dim: int,
    dropout: float,
    trainable_layers: int,
    projection_dim: int,
    *,
    append_input_stats: bool = False,
    normalization_stats: Dict[str, object] | None = None,
    head_type: str = "mlp",
) -> nn.Module:
    """Instantiate an LWM backbone with the minimal classifier head.

    Loads the pretrained weights from ``checkpoint`` (non-strict) and wraps
    the backbone in :class:`LWMClassifierMinimal` with the given head
    configuration.
    """
    backbone = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
    checkpoint_state = torch.load(checkpoint, map_location="cpu")
    # Strip DataParallel's "module." prefix if the checkpoint was saved that way.
    if any(key.startswith("module.") for key in checkpoint_state):
        checkpoint_state = {key.replace("module.", ""): value for key, value in checkpoint_state.items()}
    backbone.load_state_dict(checkpoint_state, strict=False)
    return LWMClassifierMinimal(
        backbone,
        num_classes=num_classes,
        classifier_dim=classifier_dim,
        dropout=dropout,
        trainable_layers=trainable_layers,
        projection_dim=projection_dim,
        append_input_stats=append_input_stats,
        normalization_stats=normalization_stats,
        head_type=head_type,
    )
368
-
369
-
370
def supervised_contrastive_loss(
    features: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
) -> torch.Tensor:
    """Supervised contrastive loss over a batch of feature embeddings.

    Returns a zero tensor when the batch has fewer than two samples or
    when no sample has a positive partner (another sample with the same
    label).
    """
    n = features.size(0)
    if n < 2:
        return features.new_tensor(0.0)

    normed = F.normalize(features, dim=1)
    logits = torch.matmul(normed, normed.T) / max(temperature, 1e-6)
    # Subtract the detached row-wise max for numerical stability.
    logits = logits - logits.max(dim=1, keepdim=True)[0].detach()

    device = normed.device
    label_col = labels.contiguous().view(-1, 1)
    # Positive pairs share a label; self-pairs are masked out everywhere.
    positive_mask = torch.eq(label_col, label_col.T).float().to(device)
    self_mask = torch.ones_like(positive_mask) - torch.eye(n, device=device)
    positive_mask = positive_mask * self_mask

    exp_logits = torch.exp(logits) * self_mask
    log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-12)

    positives_per_row = positive_mask.sum(dim=1)
    has_positive = positives_per_row > 0
    if not torch.any(has_positive):
        return normed.new_tensor(0.0)

    mean_log_prob_pos = (positive_mask * log_prob).sum(dim=1) / positives_per_row.clamp_min(1e-12)
    return -mean_log_prob_pos[has_positive].mean()
402
-
403
-
404
# Public API of this module (ASCII-sorted); names not listed here are
# considered internal even if importable.
__all__ = [
    "BINARY_MOBILITY_LABELS",
    "LWMClassifierMinimal",
    "MOBILITY_LABELS",
    "Res1DCNNHead",
    "_collect_balanced_arrays",
    "gather_controlled_groups",
    "load_dataset_stats",
    "prepare_model",
    "supervised_contrastive_loss",
]