Spaces: Running
Namhyun-Kim committed on
Commit · aebafe2
Parent(s): 24c4d80
Update demo with MoE centroid evaluation
Browse files
- .gitattributes +0 -22
- README.md +0 -10
- app.py +443 -374
- mixture/train_embedding_router.py +0 -0
- mixture/train_top1_router.py +0 -1039
- pretraining/__pycache__/__init__.cpython-311.pyc +0 -0
- pretraining/__pycache__/pretrained_model.cpython-311.pyc +0 -0
- pretraining/pretrained_model.py +0 -7
- task1/plot_tsne.py +0 -802
- task1/train_mcs_models.py +0 -0
- task2/mobility_utils.py +0 -414
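
The headline change is the new "MoE centroid evaluation" in app.py: MoE embeddings are scored with a nearest-centroid ("prototype") classifier for joint SNR/Doppler recognition. As orientation before reading the diff, here is a minimal, self-contained sketch of that evaluation idea; the class counts, dimensions, and data below are made up for illustration and are not taken from the repo:

```python
import numpy as np

def prototype_eval(train_X, train_y, test_X, test_y):
    """Nearest-centroid evaluation: one mean embedding ("prototype") per class."""
    classes = np.unique(train_y)
    centroids = np.stack([train_X[train_y == c].mean(axis=0) for c in classes])
    # Squared Euclidean distance from every test embedding to every centroid.
    d = ((test_X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    preds = classes[d.argmin(axis=1)]
    return (preds == test_y).mean()

# Toy usage with random 128-d "embeddings" (all shapes are assumptions).
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 128))
y = rng.integers(0, 4, size=200)
print(f"accuracy: {prototype_eval(X[:150], y[:150], X[150:], y[150:]):.2f}")
```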
.gitattributes CHANGED
@@ -1,24 +1,2 @@
-# Git LFS configuration for large model files
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
 
-# Large data files
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tar.gz filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-
-# Large image files (if needed)
-*.png filter=lfs diff=lfs merge=lfs -text
-*.jpg filter=lfs diff=lfs merge=lfs -text
-*.jpeg filter=lfs diff=lfs merge=lfs -text
README.md DELETED
@@ -1,10 +0,0 @@
----
-title: LWM Spectro Demo
-emoji: 🔬
-colorFrom: blue
-colorTo: indigo
-sdk: gradio
-sdk_version: 5.5.0
-app_file: app.py
-pinned: false
----
app.py CHANGED
@@ -1,412 +1,481 @@
(old version: removed lines are marked "-", unchanged lines are unmarked)
 
-import inspect
-import random
 import sys
 from pathlib import Path
-
-import huggingface_hub as hf_hub
-
-# Gradio imports HfFolder; add shim before importing gradio.
-if not hasattr(hf_hub, "HfFolder"):
-    class _HfFolderShim:
-        @staticmethod
-        def get_token():
-            return None
-
-        @staticmethod
-        def save_token(token):
-            return None
-
-    hf_hub.HfFolder = _HfFolderShim  # type: ignore[attr-defined]
 
 import gradio as gr
-import torch
 import numpy as np
 import pandas as pd
-
 from sklearn.decomposition import PCA
-from sklearn.…
-from sklearn.metrics import …
-
 
-# Repo root for local imports
-REPO_ROOT = Path(__file__).resolve().parent
 if str(REPO_ROOT) not in sys.path:
     sys.path.append(str(REPO_ROOT))
 
-from mixture.train_embedding_router import …
 
-# ------------------------------------------------------------------------------
-# Data loading (t-SNE + evaluation)
-# ------------------------------------------------------------------------------
 
-[… old lines 43-46 blanked in the rendered diff …]
 records = []
-for i, …
-[… old lines 49-60 blanked in the rendered diff …]
 
-df …
 
-# Get unique values for filters
-tech_choices = sorted(list(df['tech'].unique()))
-snr_choices = sorted(list(df['snr'].unique()))
-mod_choices = sorted(list(df['mod'].unique()))
-mob_choices = sorted(list(df['mob'].unique()))
 
 def plot_tsne(tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter):
-
-    filtered_df = df.copy()
-    if not tech_filter:
-        return None, "Select at least one technology."
-
-    if tech_filter and len(tech_filter) > 0:
-        filtered_df = filtered_df[filtered_df['tech'].isin(tech_filter)]
-
-    if snr_filter and len(snr_filter) > 0:
-        filtered_df = filtered_df[filtered_df['snr'].isin(snr_filter)]
-
-    if mod_filter and len(mod_filter) > 0:
-        filtered_df = filtered_df[filtered_df['mod'].isin(mod_filter)]
-
-    if mob_filter and len(mob_filter) > 0:
-        filtered_df = filtered_df[filtered_df['mob'].isin(mob_filter)]
-
     if len(filtered_df) < 5:
         return None, f"Not enough data points ({len(filtered_df)}). Need at least 5."
-
-    # Select features
     if representation == "LWM Embedding":
-        features = np.stack(filtered_df[…
     else:
-        features = np.stack(filtered_df[…
-    # PCA for raw spectrograms to speed up t-SNE
     if features.shape[1] > 50:
         pca = PCA(n_components=50, random_state=42)
         features = pca.fit_transform(features)
-
-    # Clean up NaNs/Infs that can blank out t-SNE plots
-    features = np.nan_to_num(features, copy=False)
-    # Match task1/plot_tsne.py preprocessing: standardize, clamp, float32
-    scaler = StandardScaler()
-    features = scaler.fit_transform(features)
-    features = np.nan_to_num(features, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
-    features = np.clip(features, -1e6, 1e6).astype(np.float32, copy=False)
-
-    # Run t-SNE
-    # Adjust perplexity if N is small; cap similarly to task1/plot_tsne.py
-    max_perplexity = max(5, min(30, len(filtered_df) // 10 if len(filtered_df) > 10 else len(filtered_df) - 1))
-    eff_perplexity = min(perplexity, len(filtered_df) - 1, max_perplexity)
-    eff_perplexity = max(eff_perplexity, 5)
-
-    tsne_kwargs = {"n_components": 2, "perplexity": eff_perplexity, "random_state": 42}
-    sig = inspect.signature(TSNE.__init__)
-    if "init" in sig.parameters:
-        tsne_kwargs["init"] = "random"
-    if "learning_rate" in sig.parameters:
-        tsne_kwargs["learning_rate"] = "auto"
-    if "n_iter" in sig.parameters:
-        tsne_kwargs["n_iter"] = n_iter
-    elif "max_iter" in sig.parameters:
-        tsne_kwargs["max_iter"] = n_iter
 
-[… old lines 126-137 blanked in the rendered diff …]
-    filtered_df[…
-[… old lines 139-164 blanked in the rendered diff …]
     )
-[… old lines 166-172 blanked in the rendered diff …]
-    fig.tight_layout()
-
-    coord_info = f"x[{x_min:.3f},{x_max:.3f}] y[{y_min:.3f},{y_max:.3f}]"
-    trace_info = f"traces: {len(filtered_df[color_by].unique())}"
-    return fig, f"{status_msg} | filtered samples: {len(filtered_df)} | {coord_info} | {trace_info}"
-
-# ------------------------------------------------------------------------------
-# Evaluation utilities (confusion matrix, F1) using the MoE checkpoint
-# ------------------------------------------------------------------------------
-
-_predictor: MoEPredictor | None = None
-
-
-def load_predictor() -> MoEPredictor:
-    global _predictor
-    if _predictor is not None:
-        return _predictor
-
-    # Prefer local checkpoint if present; otherwise pull from Hub
-    candidates = [
-        REPO_ROOT / "mixture" / "runs" / "embedding_router" / "moe_checkpoint.pth",
-        REPO_ROOT / "moe_checkpoint.pth",
-    ]
-    ckpt_path = None
-    for cand in candidates:
-        if cand.exists():
-            ckpt_path = cand
-            break
-    if ckpt_path is None:
-        ckpt_path = Path(
-            hf_hub.hf_hub_download(repo_id="wi-lab/lwm-spectro", filename="moe_checkpoint.pth")
-        )
-
-    # Ensure expert checkpoints are resolvable in the Space (paths inside ckpt are absolute)
-    def ensure_expert(name: str, comm: str) -> Path:
-        """Return a local path to the expert checkpoint, downloading if needed."""
-        fname = Path(name).name
-        comm_tag = comm.replace("/", "_")
-        local_candidates = [
-            REPO_ROOT / "experts" / fname,
-            REPO_ROOT / fname,
-            REPO_ROOT / "experts" / f"{comm_tag}_expert.pth",
-            REPO_ROOT / f"{comm_tag}_expert.pth",
-        ]
-        for cand in local_candidates:
-            if cand.exists():
-                return cand
-        # Download from model repo with multiple filename guesses
-        download_candidates = [
-            f"experts/{fname}",
-            f"experts/{comm_tag}_expert.pth",
-            fname,
-        ]
-        last_err = None
-        for rel in download_candidates:
-            try:
-                downloaded = hf_hub.hf_hub_download(
-                    repo_id="wi-lab/lwm-spectro",
-                    filename=rel,
-                )
-                return Path(downloaded)
-            except Exception as exc:  # pragma: no cover - network/permissions issues
-                last_err = exc
-                continue
-        raise RuntimeError(f"Could not resolve expert checkpoint for {comm} ({fname}): {last_err}")
-
-    # Rewrite expert paths into a temp checkpoint so MoEPredictor loads cleanly
-    import torch  # local import to keep top import list compact
-
-    raw_ckpt = torch.load(ckpt_path, map_location="cpu")
-    experts = raw_ckpt.get("experts", [])
-    if experts:
-        patched = False
-        for expert in experts:
-            ckpt_field = expert.get("checkpoint")
-            if not ckpt_field:
-                continue
-            fname = Path(ckpt_field).name
-            comm = expert.get("comm", "unknown")
-            local_path = ensure_expert(fname, comm)
-            if str(local_path) != ckpt_field:
-                expert["checkpoint"] = str(local_path)
-                patched = True
-        if patched:
-            tmp_path = Path("/tmp/moe_checkpoint_patched.pth")
-            torch.save(raw_ckpt, tmp_path)
-            ckpt_path = tmp_path
-
-    _predictor = MoEPredictor.from_checkpoint(ckpt_path)
-    return _predictor
-
-
-def _to_tensor(spec) -> torch.Tensor:
-    t = spec
-    if not isinstance(t, torch.Tensor):
-        t = torch.as_tensor(t)
-    if t.dim() == 2:
-        t = t.unsqueeze(0)
-    return t
-
-
-def _normalize_label(val):
-    """Convert labels to a simple string for metrics."""
-    if isinstance(val, (list, tuple)):
-        return " | ".join(str(v) for v in val)
-    return str(val)
-
-
-def compute_eval(task: str):
-    """Compute confusion matrix + macro F1 with balanced sampling per class."""
-    predictor = load_predictor()
-    y_true, y_pred = [], []
-
-    # Balanced sampling per class
-    rng = random.Random(42)
-    per_class_target = 100
-
-    def class_key(sample):
-        if task == "comm":
-            return _normalize_label(sample["tech"])
-        return _normalize_label((sample["snr"], sample["mob"]))
-
-    buckets = {}
-    for s in raw_samples:
-        key = class_key(s)
-        buckets.setdefault(key, []).append(s)
-
-    selected = []
-    for key, items in buckets.items():
-        rng.shuffle(items)
-        take = min(per_class_target, len(items))
-        selected.extend(items[:take])
-
-    rng.shuffle(selected)
-
-    for sample in selected:
-        spec = _to_tensor(sample["data"])
-        try:
-            res = predictor.predict(spec, return_routing=True)
-        except Exception as exc:
-            print(f"[WARN] predict failed: {exc}")
-            continue
-
-        if task == "comm":
-            routing = res.get("routing") or []
-            pred = _normalize_label(routing[0]["comm"]) if routing else "Unknown"
-            true = _normalize_label(sample["tech"])
-        else:  # snr_mobility
-            pred_raw = res.get("label", res["predicted_class"])
-            pred = _normalize_label(pred_raw)
-            true = _normalize_label((sample["snr"], sample["mob"]))
-        y_true.append(true)
-        y_pred.append(pred)
-
-    if not y_true or not y_pred:
-        raise RuntimeError("No samples were evaluated; check data or predictions.")
-
-    labels = sorted(list({*y_true, *y_pred}))
-    cm = confusion_matrix(y_true, y_pred, labels=labels)
-    f1 = f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0)
-    acc = (np.array(y_true) == np.array(y_pred)).mean()
-    return cm, labels, f1, acc, len(y_true)
-
-
-def plot_confusion(cm: np.ndarray, labels):
-    fig, ax = plt.subplots(figsize=(6, 5))
-    im = ax.imshow(cm, cmap="Blues")
-    ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
-    ax.set_xticks(np.arange(len(labels)), labels=labels, rotation=45, ha="right")
-    ax.set_yticks(np.arange(len(labels)), labels=labels)
-    ax.set_xlabel("Predicted")
-    ax.set_ylabel("True")
-    for i in range(cm.shape[0]):
-        for j in range(cm.shape[1]):
-            ax.text(j, i, int(cm[i, j]), ha="center", va="center", color="black")
-    fig.tight_layout()
     return fig
 
 
-def …
-[… old lines 353-356 blanked in the rendered diff …]
 
 
-# ------------------------------------------------------------------------------
-# UI
-# ------------------------------------------------------------------------------
 with gr.Blocks(title="LWM-Spectro Demo") as demo:
     gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
-    gr.Markdown(
-[… old lines 365-409 (UI layout) blanked in the rendered diff …]
 
 if __name__ == "__main__":
     demo.launch()
(new version: added lines are marked "+", unchanged lines are unmarked)
 
+import json
 import sys
 from pathlib import Path
+from typing import Dict, List, Optional, Sequence, Tuple
 
 import gradio as gr
 import numpy as np
 import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+import torch
 from sklearn.decomposition import PCA
+from sklearn.manifold import TSNE
+from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+APP_DIR = Path(__file__).resolve().parent
+DEMO_DATA_PATH = APP_DIR / "demo_data.pt"
+MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt"
+MOE_CHECKPOINT = REPO_ROOT / "mixture" / "runs" / "embedding_router" / "moe_checkpoint.pth"
+SNR_MOB_MAPPING_PATH = REPO_ROOT / "mixture" / "runs" / "embedding_router" / "snr_mobility_mapping.json"
 
 if str(REPO_ROOT) not in sys.path:
     sys.path.append(str(REPO_ROOT))
 
+from mixture.train_embedding_router import (  # type: ignore
+    MoEPredictor,
+    compute_selected_expert_embeddings,
+    normalize_per_sample_tensor,
+    stack_expert_embeddings,
+)
+
+
+def load_joint_mapping() -> Optional[Dict[str, object]]:
+    if not SNR_MOB_MAPPING_PATH.exists():
+        print(f"[WARN] Mapping file not found at {SNR_MOB_MAPPING_PATH}")
+        return None
+    raw = json.loads(SNR_MOB_MAPPING_PATH.read_text())
+    ordered_pairs: List[Tuple[str, str]] = []
+    for key in sorted(raw.keys(), key=lambda k: int(k)):
+        snr, mob = raw[key]
+        ordered_pairs.append((snr, mob))
+    label_names = [f"{snr} | {mob}" for snr, mob in ordered_pairs]
+    pair_to_name = {pair: name for pair, name in zip(ordered_pairs, label_names)}
+    name_to_id = {name: idx for idx, name in enumerate(label_names)}
+    pair_to_id = {pair: idx for idx, pair in enumerate(ordered_pairs)}
+    return {
+        "pairs": ordered_pairs,
+        "label_names": label_names,
+        "pair_to_name": pair_to_name,
+        "name_to_id": name_to_id,
+        "pair_to_id": pair_to_id,
+    }
+
+
+def compute_moe_embeddings(
+    samples: Sequence[Dict[str, object]],
+    predictor: MoEPredictor,
+    batch_size: int = 64,
+) -> torch.Tensor:
+    router = predictor.router
+    experts = predictor.experts
+    device = predictor.device
+    embeddings: List[torch.Tensor] = []
+
+    with torch.no_grad():
+        for start in range(0, len(samples), batch_size):
+            batch = samples[start : start + batch_size]
+            specs = torch.cat([sample["data"] for sample in batch], dim=0).to(device)
+            specs_norm = normalize_per_sample_tensor(specs)
+
+            if router is not None:
+                router_logits = router(specs_norm)
+                probs = torch.softmax(router_logits, dim=1)
+                topk_vals, topk_idx = probs.topk(k=predictor.topk, dim=1)
+                weights = topk_vals / torch.clamp(topk_vals.sum(dim=1, keepdim=True), min=1e-6)
+                selected_embeddings = compute_selected_expert_embeddings(
+                    experts,
+                    specs_norm,
+                    topk_idx,
+                    allow_grad=False,
+                )
+                weighted = (weights.unsqueeze(-1) * selected_embeddings).sum(dim=1)
+            else:
+                stacked = stack_expert_embeddings(experts, specs_norm)
+                weighted = stacked.mean(dim=1)
+
+            embeddings.append(weighted.cpu())
+
+    return torch.cat(embeddings, dim=0)
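
The routing math in compute_moe_embeddings keeps only the top-k router probabilities and renormalizes them before mixing expert embeddings. A tiny standalone illustration of that renormalization step (the numbers are made up):

```python
import torch

# Router softmax over 3 experts for one sample (illustrative values).
probs = torch.tensor([[0.5, 0.3, 0.2]])
vals, idx = probs.topk(k=2, dim=1)        # keep the two most likely experts
w = vals / vals.sum(dim=1, keepdim=True)  # renormalize -> [0.625, 0.375]
print(idx.tolist(), w.tolist())
# The MoE embedding is then sum_i w[i] * expert_embedding[idx[i]].
```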
+
+
+def ensure_moe_embeddings(samples: List[Dict[str, object]]) -> Tuple[List[Dict[str, object]], bool]:
+    if MOE_DATA_PATH.exists():
+        cached = torch.load(MOE_DATA_PATH)
+        if len(cached) == len(samples):
+            print(f"[INFO] Loaded cached MoE embeddings from {MOE_DATA_PATH}")
+            return cached, True
+        print("[WARN] Cached MoE embeddings length mismatch. Recomputing...")
+
+    if not MOE_CHECKPOINT.exists():
+        print(f"[WARN] MoE checkpoint not found at {MOE_CHECKPOINT}. Skipping MoE embeddings.")
+        return samples, False
+
+    print("[INFO] Computing MoE embeddings using router checkpoint...")
+    predictor = MoEPredictor.from_checkpoint(MOE_CHECKPOINT)
+    moe_embeddings = compute_moe_embeddings(samples, predictor)
+    for sample, emb in zip(samples, moe_embeddings):
+        sample["moe_embedding"] = emb.detach().cpu()
+
+    torch.save(samples, MOE_DATA_PATH)
+    print(f"[INFO] Saved MoE-augmented dataset to {MOE_DATA_PATH}")
+    return samples, True
 
 
+def load_data(mapping: Optional[Dict[str, object]]):
+    if not DEMO_DATA_PATH.exists():
+        raise FileNotFoundError(f"Dataset not found at {DEMO_DATA_PATH}")
+
+    print(f"[INFO] Loading base dataset from {DEMO_DATA_PATH}")
+    data: List[Dict[str, object]] = torch.load(DEMO_DATA_PATH)
+    data, has_moe = ensure_moe_embeddings(data)
+
+    pair_to_name = mapping["pair_to_name"] if mapping else {}
+    pair_to_id = mapping["pair_to_id"] if mapping else {}
+
     records = []
+    for i, sample in enumerate(data):
+        embedding = sample["embedding"]
+        if isinstance(embedding, torch.Tensor):
+            base_embedding = embedding.detach().cpu().numpy()
+        else:
+            base_embedding = np.asarray(embedding)
+
+        spectrogram = sample["data"]
+        if isinstance(spectrogram, torch.Tensor):
+            flat_spec = spectrogram.numpy().flatten()
+        else:
+            flat_spec = np.asarray(spectrogram).flatten()
+
+        moe_embedding = sample.get("moe_embedding")
+        if isinstance(moe_embedding, torch.Tensor):
+            moe_embedding = moe_embedding.numpy()
+        elif moe_embedding is not None:
+            moe_embedding = np.asarray(moe_embedding)
+
+        pair = (sample["snr"], sample["mob"])
+        joint_label = pair_to_name.get(pair)
+        joint_label_id = pair_to_id.get(pair)
+
+        records.append(
+            {
+                "index": i,
+                "tech": sample["tech"],
+                "snr": sample["snr"],
+                "mod": sample["mod"],
+                "mob": sample["mob"],
+                "embedding": base_embedding,
+                "moe_embedding": moe_embedding,
+                "spectrogram": flat_spec,
+                "joint_label": joint_label,
+                "joint_label_id": joint_label_id,
+            }
+        )
 
+    df = pd.DataFrame(records)
+    print(f"[INFO] Loaded {len(df)} samples.")
+    return df, has_moe
+
+
+def apply_filters(
+    dataframe: pd.DataFrame,
+    tech_filter,
+    snr_filter,
+    mod_filter,
+    mob_filter,
+) -> pd.DataFrame:
+    filtered = dataframe.copy()
+    if tech_filter:
+        filtered = filtered[filtered["tech"].isin(tech_filter)]
+    if snr_filter:
+        filtered = filtered[filtered["snr"].isin(snr_filter)]
+    if mod_filter:
+        filtered = filtered[filtered["mod"].isin(mod_filter)]
+    if mob_filter:
+        filtered = filtered[filtered["mob"].isin(mob_filter)]
+    return filtered
 
 
 def plot_tsne(tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter):
+    filtered_df = apply_filters(df, tech_filter, snr_filter, mod_filter, mob_filter)
     if len(filtered_df) < 5:
         return None, f"Not enough data points ({len(filtered_df)}). Need at least 5."
+
     if representation == "LWM Embedding":
+        features = np.stack(filtered_df["embedding"].values)
     else:
+        features = np.stack(filtered_df["spectrogram"].values)
     if features.shape[1] > 50:
         pca = PCA(n_components=50, random_state=42)
         features = pca.fit_transform(features)
 
+    eff_perplexity = min(perplexity, len(filtered_df) - 1)
+    tsne = TSNE(
+        n_components=2,
+        perplexity=eff_perplexity,
+        n_iter=n_iter,
+        random_state=42,
+        init="pca",
+        learning_rate="auto",
+    )
+    projections = tsne.fit_transform(features)
+    filtered_df = filtered_df.copy()
+    filtered_df["x"] = projections[:, 0]
+    filtered_df["y"] = projections[:, 1]
+
+    fig = px.scatter(
+        filtered_df,
+        x="x",
+        y="y",
+        color=color_by,
+        hover_data=["tech", "snr", "mod", "mob"],
+        title=f"t-SNE of {representation} ({len(filtered_df)} samples)",
+        template="plotly_white",
+    )
+    fig.update_layout(legend_title_text=color_by.capitalize())
+    return fig, f"Displayed {len(filtered_df)} samples."
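
One detail worth noting in the rewritten plot_tsne: scikit-learn requires perplexity to be strictly smaller than the number of samples, which is what the min(perplexity, len(filtered_df) - 1) clamp guarantees. A quick standalone check with random data (sizes are arbitrary):

```python
import numpy as np
from sklearn.manifold import TSNE

X = np.random.default_rng(0).normal(size=(20, 8))  # 20 toy samples
emb = TSNE(
    n_components=2,
    perplexity=min(30, len(X) - 1),  # clamp, since 30 >= n_samples here
    init="pca",
    random_state=42,
).fit_transform(X)
print(emb.shape)  # (20, 2)
```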
+
+
+def stratified_split(filtered_df: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
+    rng = np.random.default_rng(int(seed))
+    train_indices: List[int] = []
+    test_indices: List[int] = []
+
+    for label_id, group in filtered_df.groupby("joint_label_id"):
+        indices = group.index.to_numpy()
+        if indices.size < 2:
+            raise ValueError(f"Class '{CLASS_LABELS[int(label_id)]}' needs at least 2 samples for evaluation.")
+
+        rng.shuffle(indices)
+        split = int(round(indices.size * train_ratio))
+        split = max(1, min(indices.size - 1, split))
+        train_indices.extend(indices[:split])
+        test_indices.extend(indices[split:])
+
+    return np.array(train_indices), np.array(test_indices)
+
+
+def compute_centroid_metrics(filtered_df: pd.DataFrame, train_idx: np.ndarray, test_idx: np.ndarray) -> Dict[str, object]:
+    train_subset = filtered_df.loc[train_idx]
+    test_subset = filtered_df.loc[test_idx]
+
+    train_embeddings = np.stack(train_subset["moe_embedding"].values)
+    test_embeddings = np.stack(test_subset["moe_embedding"].values)
+    train_labels = train_subset["joint_label_id"].to_numpy(dtype=int)
+    test_labels = test_subset["joint_label_id"].to_numpy(dtype=int)
+
+    unique_labels = np.unique(train_labels)
+    centroids = []
+    centroid_ids: List[int] = []
+    for label_id in unique_labels:
+        mask = train_labels == label_id
+        centroids.append(train_embeddings[mask].mean(axis=0))
+        centroid_ids.append(int(label_id))
+
+    centroids = np.stack(centroids)
+    centroid_ids = np.array(centroid_ids, dtype=int)
+
+    dists = ((test_embeddings[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
+    preds = centroid_ids[np.argmin(dists, axis=1)]
+
+    accuracy = accuracy_score(test_labels, preds)
+    macro_f1 = f1_score(test_labels, preds, average="macro", labels=centroid_ids, zero_division=0)
+
+    active_ids = sorted(np.unique(np.concatenate([test_labels, preds])))
+    label_names = [CLASS_LABELS[i] for i in active_ids]
+    cm = confusion_matrix(test_labels, preds, labels=active_ids)
+
+    return {
+        "accuracy": accuracy,
+        "macro_f1": macro_f1,
+        "confusion": cm,
+        "label_names": label_names,
+        "train_size": len(train_idx),
+        "test_size": len(test_idx),
+    }
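
The distance computation in compute_centroid_metrics relies on NumPy broadcasting: expanding test embeddings to (N_test, 1, D) against centroids at (1, N_class, D) yields one squared distance per (sample, centroid) pair. A shape check with arbitrary sizes:

```python
import numpy as np

test = np.zeros((7, 128))   # 7 test embeddings, 128-d (sizes are arbitrary)
cents = np.zeros((4, 128))  # 4 class centroids
d = ((test[:, None, :] - cents[None, :, :]) ** 2).sum(axis=-1)
assert d.shape == (7, 4)    # one squared distance per (test sample, centroid)
```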
+
+
+def plot_confusion_heatmap(confusion: np.ndarray, label_names: List[str]) -> go.Figure:
+    fig = go.Figure(
+        data=go.Heatmap(
+            z=confusion,
+            x=label_names,
+            y=label_names,
+            colorscale="Viridis",
+            hovertemplate="Predicted %{x}<br>True %{y}<br>Count %{z}<extra></extra>",
         )
+    )
+    fig.update_layout(
+        title="Prototype Classifier Confusion Matrix",
+        xaxis_title="Predicted",
+        yaxis_title="True",
+        xaxis=dict(tickangle=45),
+    )
     return fig
 
 
+def run_joint_evaluation(train_pct, seed, tech_filter, snr_filter, mod_filter, mob_filter):
+    if joint_eval_df.empty:
+        fig = go.Figure()
+        fig.update_layout(title="MoE embeddings unavailable", xaxis=dict(visible=False), yaxis=dict(visible=False))
+        return fig, "MoE embeddings are not available for evaluation."
+
+    filtered = apply_filters(joint_eval_df, tech_filter, snr_filter, mod_filter, mob_filter)
+    if filtered.empty:
+        fig = go.Figure()
+        fig.update_layout(title="No samples after filtering", xaxis=dict(visible=False), yaxis=dict(visible=False))
+        return fig, "No samples match the selected filters."
 
+    if filtered["joint_label_id"].nunique() < 2:
+        fig = go.Figure()
+        fig.update_layout(title="Need at least two classes", xaxis=dict(visible=False), yaxis=dict(visible=False))
+        return fig, "Need at least two joint SNR/Doppler classes to evaluate."
+
+    try:
+        train_idx, test_idx = stratified_split(filtered, train_pct / 100.0, seed)
+    except ValueError as exc:
+        fig = go.Figure()
+        fig.update_layout(title="Unable to split dataset", xaxis=dict(visible=False), yaxis=dict(visible=False))
+        return fig, str(exc)
+
+    metrics = compute_centroid_metrics(filtered, train_idx, test_idx)
+    fig = plot_confusion_heatmap(metrics["confusion"], metrics["label_names"])
+    status = (
+        f"Train samples: {metrics['train_size']}\n"
+        f"Test samples: {metrics['test_size']}\n"
+        f"Accuracy: {metrics['accuracy'] * 100:.2f}%\n"
+        f"Macro F1: {metrics['macro_f1']:.3f}"
+    )
+    return fig, status
+
+
+mapping_info = load_joint_mapping()
+df, has_moe_embeddings = load_data(mapping_info)
+CLASS_LABELS: List[str] = mapping_info["label_names"] if mapping_info else []
+
+joint_eval_df = df.copy()
+joint_eval_df = joint_eval_df[joint_eval_df["joint_label_id"].notna()]
+joint_eval_df = joint_eval_df[joint_eval_df["moe_embedding"].notna()]
+
+tech_choices = sorted(df["tech"].unique())
+snr_choices = sorted(df["snr"].unique())
+mod_choices = sorted(df["mod"].unique())
+mob_choices = sorted(df["mob"].unique())
+
+evaluation_disabled = joint_eval_df.empty
 
 with gr.Blocks(title="LWM-Spectro Demo") as demo:
     gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
+    gr.Markdown(
+        """
+        Compare **LWM embeddings** vs **Raw Spectrograms** for visualization, then evaluate **MoE embeddings**
+        with a lightweight prototype classifier for joint SNR/Doppler recognition.
+        """
+    )
+
+    with gr.Tabs():
+        with gr.Tab("Visualization"):
+            with gr.Row():
+                with gr.Column(scale=1, min_width=300):
+                    gr.Markdown("### Filters")
+                    tech_filter = gr.CheckboxGroup(choices=tech_choices, value=tech_choices, label="Technology")
+                    snr_filter = gr.Dropdown(
+                        choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)"
+                    )
+                    mod_filter = gr.Dropdown(
+                        choices=mod_choices, value=None, multiselect=True, label="Modulation (Empty = All)"
+                    )
+                    mob_filter = gr.Dropdown(
+                        choices=mob_choices, value=None, multiselect=True, label="Mobility (Empty = All)"
+                    )
+
+                    gr.Markdown("### Visualization Settings")
+                    representation = gr.Radio(
+                        choices=["LWM Embedding", "Raw Spectrogram"],
+                        value="LWM Embedding",
+                        label="Representation",
+                    )
+                    color_by = gr.Dropdown(choices=["tech", "snr", "mod", "mob"], value="tech", label="Color By")
+
+                    with gr.Accordion("Advanced t-SNE Settings", open=False):
+                        perplexity = gr.Slider(minimum=5, maximum=50, value=30, step=1, label="Perplexity")
+                        n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations")
+
+                    btn = gr.Button("Update Plot", variant="primary")
+                    status = gr.Textbox(label="Status", interactive=False)
+
+                with gr.Column(scale=3):
+                    plot = gr.Plot(label="t-SNE Visualization")
+
+            btn.click(
+                plot_tsne,
+                inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
+                outputs=[plot, status],
+            )
+
+            demo.load(
+                plot_tsne,
+                inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
+                outputs=[plot, status],
+            )
+
+        with gr.Tab("Evaluation (Joint SNR/Doppler)"):
+            if evaluation_disabled:
+                gr.Markdown(
+                    "⚠️ MoE embeddings are unavailable. Ensure `demo_data_moe.pt` exists or the checkpoint is present."
+                )
+
+            with gr.Row():
+                with gr.Column(scale=1, min_width=320):
+                    gr.Markdown("### Evaluation Filters")
+                    eval_tech_filter = gr.CheckboxGroup(
+                        choices=tech_choices,
+                        value=tech_choices,
+                        label="Technology",
+                        interactive=not evaluation_disabled,
+                    )
+                    eval_snr_filter = gr.Dropdown(
+                        choices=snr_choices,
+                        value=None,
+                        multiselect=True,
+                        label="SNR (Empty = All)",
+                        interactive=not evaluation_disabled,
+                    )
+                    eval_mod_filter = gr.Dropdown(
+                        choices=mod_choices,
+                        value=None,
+                        multiselect=True,
+                        label="Modulation (Empty = All)",
+                        interactive=not evaluation_disabled,
+                    )
+                    eval_mob_filter = gr.Dropdown(
+                        choices=mob_choices,
+                        value=None,
+                        multiselect=True,
+                        label="Mobility (Empty = All)",
+                        interactive=not evaluation_disabled,
+                    )
+
+                    gr.Markdown("### Prototype Settings")
+                    train_pct = gr.Slider(
+                        minimum=10,
+                        maximum=80,
+                        step=5,
+                        value=60,
+                        label="Training Percentage (%)",
+                        interactive=not evaluation_disabled,
+                    )
+                    seed = gr.Slider(
+                        minimum=0,
+                        maximum=9999,
+                        step=1,
+                        value=42,
+                        label="Random Seed",
+                        interactive=not evaluation_disabled,
+                    )
+                    eval_btn = gr.Button("Run evaluation", variant="primary", interactive=not evaluation_disabled)
+
+                with gr.Column(scale=3):
+                    eval_plot = gr.Plot(label="Prototype Confusion Matrix")
+                    eval_status = gr.Textbox(label="Metrics", interactive=False)
+
+            eval_btn.click(
+                run_joint_evaluation,
+                inputs=[train_pct, seed, eval_tech_filter, eval_snr_filter, eval_mod_filter, eval_mob_filter],
+                outputs=[eval_plot, eval_status],
+            )
 
 if __name__ == "__main__":
     demo.launch()
mixture/train_embedding_router.py DELETED
The diff for this file is too large to render. See raw diff.
mixture/train_top1_router.py DELETED
@@ -1,1039 +0,0 @@
-#!/usr/bin/env python3
-"""Train a communication-router with top-1 hard expert selection.
-
-The script builds a supervised mixture-of-experts pipeline:
-
-1. Gather spectrogram samples for each communication profile (LTE/WiFi/5G).
-2. Train a lightweight CNN router that predicts the communication label.
-3. (Optional) Attach pre-trained experts and evaluate top-1 hard routing by
-   running only the expert picked by the router's argmax for each sample.
-
-Expert checkpoints are expected to be LWM-based classifiers (for example those
-produced by `task2/train_joint_snr_mobility.py` or earlier mobility fine-tuning
-pipelines).
-Their architecture is inferred from the checkpoint to avoid manual plumbing.
-
-Example:
-
-```bash
-python mixture/train_top1_router.py \
-    --data-root spectrograms \
-    --cities city_1_losangeles \
-    --comm-types LTE WiFi 5G \
-    --task snr_mobility \
-    --mobilities vehicular pedestrian \
-    --snrs SNR-5dB SNR0dB SNR5dB SNR10dB SNR15dB \
-    --max-samples-per-comm 6000 \
-    --max-per-combo 400 \
-    --epochs 25 \
-    --batch-size 128 \
-    --lr 3e-4 \
-    --output-dir mixture/runs/top1_router \
-    --expert LTE=models/doppler_finetuned_binary/lte/lwm_lte_doppler_val90.67.pth \
-    --expert WiFi=models/doppler_finetuned_binary/wifi/lwm_wifi_doppler_val95.01.pth \
-    --expert 5G=models/doppler_finetuned_binary/5g/lwm_5g_doppler_val96.05.pth
-```
-"""
-
-
from __future__ import annotations
|
| 39 |
-
|
| 40 |
-
import argparse
|
| 41 |
-
import json
|
| 42 |
-
import random
|
| 43 |
-
from collections import Counter, defaultdict
|
| 44 |
-
from dataclasses import dataclass
|
| 45 |
-
from pathlib import Path
|
| 46 |
-
from typing import Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple
|
| 47 |
-
|
| 48 |
-
import glob
|
| 49 |
-
import numpy as np
|
| 50 |
-
import torch
|
| 51 |
-
import torch.nn as nn
|
| 52 |
-
import torch.nn.functional as F
|
| 53 |
-
from torch.amp import GradScaler, autocast
|
| 54 |
-
from torch.utils.data import DataLoader, Dataset
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
try:
|
| 58 |
-
from task1.train_mcs_models import (
|
| 59 |
-
MODULATION_LABELS,
|
| 60 |
-
identify_modulation,
|
| 61 |
-
load_all_samples,
|
| 62 |
-
normalize_per_sample,
|
| 63 |
-
_extract_metadata,
|
| 64 |
-
)
|
| 65 |
-
except ImportError as exc: # pragma: no cover - safety net
|
| 66 |
-
raise ImportError(
|
| 67 |
-
"Failed to import helpers from task1.train_mcs_models. "
|
| 68 |
-
"Ensure the repository root is on PYTHONPATH."
|
| 69 |
-
) from exc
|
| 70 |
-
|
| 71 |
-
try:
|
| 72 |
-
from task2.train_joint_snr_mobility import snr_sort_key
|
| 73 |
-
except ImportError: # pragma: no cover - fallback if task2 module is unavailable
|
| 74 |
-
|
| 75 |
-
def snr_sort_key(snr: str) -> Tuple[int, str]:
|
| 76 |
-
import re
|
| 77 |
-
|
| 78 |
-
match = re.search(r"SNR(-?\d+)dB", snr)
|
| 79 |
-
if match:
|
| 80 |
-
return int(match.group(1)), snr
|
| 81 |
-
return 0, snr
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
from pretraining.pretrained_model import lwm as lwm_model
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
COMM_CANONICAL = {"lte": "LTE", "wifi": "WiFi", "5g": "5G"}
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def canonical_comm_name(name: str) -> str:
|
| 91 |
-
lower = name.strip().lower()
|
| 92 |
-
if lower in COMM_CANONICAL:
|
| 93 |
-
return COMM_CANONICAL[lower]
|
| 94 |
-
for canonical in COMM_CANONICAL.values():
|
| 95 |
-
if canonical.lower() == lower:
|
| 96 |
-
return canonical
|
| 97 |
-
raise ValueError(f"Unknown communication type: {name}")
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
@dataclass(slots=True)
|
| 101 |
-
class SampleMetadata:
|
| 102 |
-
comm: str
|
| 103 |
-
modulation: str
|
| 104 |
-
snr: str
|
| 105 |
-
mobility: str
|
| 106 |
-
rate: str
|
| 107 |
-
source: str
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
@dataclass(slots=True)
|
| 111 |
-
class ExpertSpec:
|
| 112 |
-
comm: str
|
| 113 |
-
checkpoint: Path
|
| 114 |
-
stats_path: Optional[Path]
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
class RoutedSpectrogramDataset(Dataset):
|
| 118 |
-
"""Spectrogram dataset that tracks both router and downstream labels."""
|
| 119 |
-
|
| 120 |
-
def __init__(
|
| 121 |
-
self,
|
| 122 |
-
specs: np.ndarray,
|
| 123 |
-
comm_labels: np.ndarray,
|
| 124 |
-
task_labels: np.ndarray,
|
| 125 |
-
metadata: List[SampleMetadata],
|
| 126 |
-
) -> None:
|
| 127 |
-
if not (len(specs) == len(comm_labels) == len(task_labels) == len(metadata)):
|
| 128 |
-
raise ValueError("All dataset inputs must have the same length")
|
| 129 |
-
self.specs = torch.from_numpy(specs.astype(np.float32, copy=False))
|
| 130 |
-
self.comm_labels = torch.from_numpy(comm_labels.astype(np.int64, copy=False))
|
| 131 |
-
self.task_labels = torch.from_numpy(task_labels.astype(np.int64, copy=False))
|
| 132 |
-
self.metadata = metadata
|
| 133 |
-
|
| 134 |
-
def __len__(self) -> int:
|
| 135 |
-
return self.specs.shape[0]
|
| 136 |
-
|
| 137 |
-
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, int]:
|
| 138 |
-
return self.specs[idx], int(self.comm_labels[idx]), int(self.task_labels[idx])
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
class RouterNet(nn.Module):
|
| 142 |
-
"""Lightweight CNN router for 128×128 spectrogram inputs."""
|
| 143 |
-
|
| 144 |
-
def __init__(self, num_comm: int, dropout: float = 0.1) -> None:
|
| 145 |
-
super().__init__()
|
| 146 |
-
self.features = nn.Sequential(
|
| 147 |
-
nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2),
|
| 148 |
-
nn.BatchNorm2d(32),
|
| 149 |
-
nn.SiLU(inplace=True),
|
| 150 |
-
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
|
| 151 |
-
nn.BatchNorm2d(64),
|
| 152 |
-
nn.SiLU(inplace=True),
|
| 153 |
-
nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1),
|
| 154 |
-
nn.BatchNorm2d(96),
|
| 155 |
-
nn.SiLU(inplace=True),
|
| 156 |
-
nn.Conv2d(96, 128, kernel_size=3, stride=2, padding=1),
|
| 157 |
-
nn.BatchNorm2d(128),
|
| 158 |
-
nn.SiLU(inplace=True),
|
| 159 |
-
nn.AdaptiveAvgPool2d((1, 1)),
|
| 160 |
-
)
|
| 161 |
-
head_layers: List[nn.Module] = [nn.Flatten()]
|
| 162 |
-
if dropout > 0:
|
| 163 |
-
head_layers.append(nn.Dropout(dropout))
|
| 164 |
-
head_layers.append(nn.Linear(128, num_comm))
|
| 165 |
-
self.classifier = nn.Sequential(*head_layers)
|
| 166 |
-
|
| 167 |
-
def forward(self, specs: torch.Tensor) -> torch.Tensor:
|
| 168 |
-
x = specs
|
| 169 |
-
if x.dim() == 3:
|
| 170 |
-
x = x.unsqueeze(1)
|
| 171 |
-
elif x.dim() != 4:
|
| 172 |
-
raise ValueError(f"Expected specs rank 3 or 4, got shape {tuple(specs.shape)}")
|
| 173 |
-
features = self.features(x)
|
| 174 |
-
logits = self.classifier(features)
|
| 175 |
-
return logits
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
def set_seed(seed: int) -> None:
|
| 179 |
-
random.seed(seed)
|
| 180 |
-
np.random.seed(seed)
|
| 181 |
-
torch.manual_seed(seed)
|
| 182 |
-
if torch.cuda.is_available():
|
| 183 |
-
torch.cuda.manual_seed(seed)
|
| 184 |
-
torch.cuda.manual_seed_all(seed)
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
def parse_expert_definitions(entries: Sequence[str]) -> Dict[str, ExpertSpec]:
|
| 188 |
-
experts: Dict[str, ExpertSpec] = {}
|
| 189 |
-
for entry in entries:
|
| 190 |
-
if "=" not in entry:
|
| 191 |
-
raise ValueError(f"Expert definition must use COMM=path syntax (got: {entry})")
|
| 192 |
-
comm_part, _, path_part = entry.partition("=")
|
| 193 |
-
comm = canonical_comm_name(comm_part)
|
| 194 |
-
if not path_part:
|
| 195 |
-
raise ValueError(f"Missing checkpoint path for expert '{comm}'")
|
| 196 |
-
if ":" in path_part:
|
| 197 |
-
ckpt_str, stats_str = path_part.split(":", 1)
|
| 198 |
-
stats_path = Path(stats_str).expanduser().resolve()
|
| 199 |
-
else:
|
| 200 |
-
ckpt_str = path_part
|
| 201 |
-
stats_path = None
|
| 202 |
-
checkpoint = Path(ckpt_str).expanduser().resolve()
|
| 203 |
-
if not checkpoint.exists():
|
| 204 |
-
raise FileNotFoundError(f"Expert checkpoint not found: {checkpoint}")
|
| 205 |
-
if stats_path is not None and not stats_path.exists():
|
| 206 |
-
raise FileNotFoundError(f"Dataset stats file not found: {stats_path}")
|
| 207 |
-
experts[comm] = ExpertSpec(comm=comm, checkpoint=checkpoint, stats_path=stats_path)
|
| 208 |
-
return experts
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
def discover_stats_path(comm: str, defaults_root: Path) -> Optional[Path]:
|
| 212 |
-
candidates = [
|
| 213 |
-
defaults_root / f"{comm}_models" / "dataset_stats.json",
|
| 214 |
-
defaults_root / f"{comm.lower()}_models" / "dataset_stats.json",
|
| 215 |
-
defaults_root / comm / "dataset_stats.json",
|
| 216 |
-
defaults_root / comm.lower() / "dataset_stats.json",
|
| 217 |
-
]
|
| 218 |
-
for candidate in candidates:
|
| 219 |
-
if candidate.exists():
|
| 220 |
-
return candidate
|
| 221 |
-
return None
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
def load_dataset_stats(stats_path: Optional[Path]) -> Mapping[str, float | str]:
|
| 225 |
-
if stats_path is None:
|
| 226 |
-
return {"mean": 0.0, "std": 1.0, "normalization": "per_sample"}
|
| 227 |
-
with open(stats_path, "r", encoding="utf-8") as fh:
|
| 228 |
-
return json.load(fh)
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
def _collect_candidate_files(
|
| 232 |
-
*,
|
| 233 |
-
data_root: Path,
|
| 234 |
-
cities: Sequence[str],
|
| 235 |
-
comm: str,
|
| 236 |
-
snr_filters: Optional[Sequence[str]],
|
| 237 |
-
mobility_filters: Optional[Sequence[str]],
|
| 238 |
-
modulation_filters: Optional[Sequence[str]],
|
| 239 |
-
fft_filters: Optional[Sequence[str]],
|
| 240 |
-
) -> List[Tuple[Path, SampleMetadata]]:
|
| 241 |
-
mobility_set = set(mobility_filters) if mobility_filters else None
|
| 242 |
-
snr_set = set(snr_filters) if snr_filters else None
|
| 243 |
-
modulation_set = {m.upper() for m in modulation_filters} if modulation_filters else None
|
| 244 |
-
fft_set = set(fft_filters) if fft_filters else None
|
| 245 |
-
|
| 246 |
-
candidates: List[Tuple[Path, SampleMetadata]] = []
|
| 247 |
-
for city in cities:
|
| 248 |
-
base = data_root / city / comm
|
| 249 |
-
if not base.exists():
|
| 250 |
-
continue
|
| 251 |
-
pattern = str(base / "**" / "spectrograms" / "*.pkl")
|
| 252 |
-
for path_str in glob.iglob(pattern, recursive=True):
|
| 253 |
-
path = Path(path_str)
|
| 254 |
-
_, modulation = identify_modulation(str(path))
|
| 255 |
-
if modulation is None:
|
| 256 |
-
continue
|
| 257 |
-
if modulation_set is not None and modulation.upper() not in modulation_set:
|
| 258 |
-
continue
|
| 259 |
-
rate, snr, mobility = _extract_metadata(path.parts)
|
| 260 |
-
if mobility_set is not None and mobility not in mobility_set:
|
| 261 |
-
continue
|
| 262 |
-
if snr_set is not None and snr not in snr_set:
|
| 263 |
-
continue
|
| 264 |
-
fft_folder = next((part for part in path.parts if part.startswith("win")), None)
|
| 265 |
-
if fft_set is not None and fft_folder not in fft_set:
|
| 266 |
-
continue
|
| 267 |
-
meta = SampleMetadata(
|
| 268 |
-
comm=comm,
|
| 269 |
-
modulation=modulation,
|
| 270 |
-
snr=snr,
|
| 271 |
-
mobility=mobility,
|
| 272 |
-
rate=rate,
|
| 273 |
-
source=str(path),
|
| 274 |
-
)
|
| 275 |
-
candidates.append((path, meta))
|
| 276 |
-
return candidates
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
def _sample_from_file(
|
| 280 |
-
array: np.ndarray,
|
| 281 |
-
take: int,
|
| 282 |
-
rng: np.random.Generator,
|
| 283 |
-
) -> np.ndarray:
|
| 284 |
-
if take <= 0 or array.shape[0] == 0:
|
| 285 |
-
return np.empty((0, 128, 128), dtype=np.float32)
|
| 286 |
-
if take >= array.shape[0]:
|
| 287 |
-
return array.astype(np.float32, copy=False)
|
| 288 |
-
indices = rng.choice(array.shape[0], size=take, replace=False)
|
| 289 |
-
return array[indices].astype(np.float32, copy=False)
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
def collect_spectrograms_for_comm(
|
| 293 |
-
*,
|
| 294 |
-
data_root: Path,
|
| 295 |
-
cities: Sequence[str],
|
| 296 |
-
comm: str,
|
| 297 |
-
snrs: Optional[Sequence[str]],
|
| 298 |
-
mobilities: Optional[Sequence[str]],
|
| 299 |
-
modulations: Optional[Sequence[str]],
|
| 300 |
-
fft_folders: Optional[Sequence[str]],
|
| 301 |
-
max_samples: int,
|
| 302 |
-
max_per_combo: Optional[int],
|
| 303 |
-
rng: np.random.Generator,
|
| 304 |
-
) -> Tuple[np.ndarray, List[SampleMetadata]]:
|
| 305 |
-
candidates = _collect_candidate_files(
|
| 306 |
-
data_root=data_root,
|
| 307 |
-
cities=cities,
|
| 308 |
-
comm=comm,
|
| 309 |
-
snr_filters=snrs,
|
| 310 |
-
mobility_filters=mobilities,
|
| 311 |
-
modulation_filters=modulations,
|
| 312 |
-
fft_filters=fft_folders,
|
| 313 |
-
)
|
| 314 |
-
if not candidates:
|
| 315 |
-
raise RuntimeError(f"No spectrogram files matched filters for {comm}")
|
| 316 |
-
|
| 317 |
-
rng.shuffle(candidates)
|
| 318 |
-
combo_counts: MutableMapping[Tuple[str, str, str], int] = defaultdict(int)
|
| 319 |
-
collected: List[np.ndarray] = []
|
| 320 |
-
metadata: List[SampleMetadata] = []
|
| 321 |
-
remaining: Optional[int] = max_samples if max_samples > 0 else None
|
| 322 |
-
per_combo_limit: Optional[int] = max_per_combo if (max_per_combo is not None and max_per_combo > 0) else None
|
| 323 |
-
|
| 324 |
-
for path, meta in candidates:
|
| 325 |
-
if remaining is not None and remaining <= 0:
|
| 326 |
-
break
|
| 327 |
-
combo_key = (meta.modulation, meta.snr, meta.mobility)
|
| 328 |
-
already = combo_counts[combo_key]
|
| 329 |
-
if per_combo_limit is not None and already >= per_combo_limit:
|
| 330 |
-
continue
|
| 331 |
-
|
| 332 |
-
try:
|
| 333 |
-
specs = load_all_samples(str(path))
|
| 334 |
-
except Exception as exc: # pragma: no cover - guard against corrupted files
|
| 335 |
-
print(f"[WARN] Failed to load {path}: {exc}")
|
| 336 |
-
continue
|
| 337 |
-
|
| 338 |
-
if specs.size == 0:
|
| 339 |
-
continue
|
| 340 |
-
|
| 341 |
-
remaining_for_combo = per_combo_limit - already if per_combo_limit is not None else specs.shape[0]
|
| 342 |
-
allowed = min(remaining_for_combo, specs.shape[0])
|
| 343 |
-
if remaining is not None:
|
| 344 |
-
allowed = min(allowed, remaining)
|
| 345 |
-
if allowed <= 0:
|
| 346 |
-
continue
|
| 347 |
-
chosen = _sample_from_file(specs, allowed, rng)
|
| 348 |
-
if chosen.size == 0:
|
| 349 |
-
continue
|
| 350 |
-
|
| 351 |
-
collected.append(chosen)
|
| 352 |
-
metadata.extend([meta] * chosen.shape[0])
|
| 353 |
-
combo_counts[combo_key] += chosen.shape[0]
|
| 354 |
-
if remaining is not None:
|
| 355 |
-
remaining -= chosen.shape[0]
|
| 356 |
-
|
| 357 |
-
if not collected:
|
| 358 |
-
raise RuntimeError(f"Unable to collect samples for {comm} after applying limits")
|
| 359 |
-
|
| 360 |
-
stacked = np.concatenate(collected, axis=0)
|
| 361 |
-
return stacked.astype(np.float32, copy=False), metadata
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
def stratified_split(
|
| 365 |
-
labels: np.ndarray,
|
| 366 |
-
*,
|
| 367 |
-
train_ratio: float,
|
| 368 |
-
val_ratio: float,
|
| 369 |
-
seed: int,
|
| 370 |
-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 371 |
-
if not (0 < train_ratio < 1) or not (0 < val_ratio < 1):
|
| 372 |
-
raise ValueError("train_ratio and val_ratio must be in (0, 1)")
|
| 373 |
-
if train_ratio + val_ratio >= 1.0:
|
| 374 |
-
raise ValueError("train_ratio + val_ratio must be < 1.0")
|
| 375 |
-
|
| 376 |
-
rng = np.random.default_rng(seed)
|
| 377 |
-
train_indices: List[int] = []
|
| 378 |
-
val_indices: List[int] = []
|
| 379 |
-
test_indices: List[int] = []
|
| 380 |
-
|
| 381 |
-
for label in np.unique(labels):
|
| 382 |
-
idx = np.flatnonzero(labels == label)
|
| 383 |
-
if idx.size < 3:
|
| 384 |
-
raise ValueError(f"Not enough samples for label {label} to form splits (need >=3, have {idx.size})")
|
| 385 |
-
rng.shuffle(idx)
|
| 386 |
-
train_end = int(round(train_ratio * idx.size))
|
| 387 |
-
val_end = train_end + int(round(val_ratio * idx.size))
|
| 388 |
-
train_indices.extend(idx[:train_end])
|
| 389 |
-
val_indices.extend(idx[train_end:val_end])
|
| 390 |
-
test_indices.extend(idx[val_end:])
|
| 391 |
-
|
| 392 |
-
return (
|
| 393 |
-
np.array(train_indices, dtype=np.int64),
|
| 394 |
-
np.array(val_indices, dtype=np.int64),
|
| 395 |
-
np.array(test_indices, dtype=np.int64),
|
| 396 |
-
)
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
def build_dataloaders(
    dataset: RoutedSpectrogramDataset,
    train_idx: np.ndarray,
    val_idx: np.ndarray,
    test_idx: np.ndarray,
    *,
    batch_size: int,
    num_workers: int,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    def subset(indices: np.ndarray) -> RoutedSpectrogramDataset:
        specs = dataset.specs[indices].numpy()
        comm = dataset.comm_labels[indices].numpy()
        task = dataset.task_labels[indices].numpy()
        meta = [dataset.metadata[int(i)] for i in indices]
        return RoutedSpectrogramDataset(specs, comm, task, meta)

    train_ds = subset(train_idx)
    val_ds = subset(val_idx)
    test_ds = subset(test_idx)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=len(train_ds) > batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    return train_loader, val_loader, test_loader

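A hedged consumption sketch; `dataset` and the index arrays are assumed to come from `prepare_dataset` and `stratified_split` elsewhere in this module:

```python
# Hypothetical usage; `dataset` is a RoutedSpectrogramDataset built upstream.
train_loader, val_loader, test_loader = build_dataloaders(
    dataset, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx,
    batch_size=128, num_workers=4,
)
for specs, comm_labels, task_labels in train_loader:
    # specs: float32 spectrogram batch; comm/task labels: int64 class indices
    break
```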
def infer_expert_signature(state_dict: Mapping[str, torch.Tensor]) -> Dict[str, object]:
    keys = set(state_dict.keys())
    # Determine input dimension (128 vs 130 if stats appended).
    layer_norm_weight = state_dict.get("classifier.0.weight")
    if layer_norm_weight is None:
        raise ValueError("Unable to infer classifier input dimension from checkpoint")
    input_dim = layer_norm_weight.numel()
    append_input_stats = input_dim > 128

    # Determine classifier type.
    if any(k.startswith("classifier.1.conv1") for k in keys):
        head_type = "res1dcnn"
    elif "classifier.1.weight" in keys:
        head_type = "mlp"
    elif "classifier.weight" in keys:
        head_type = "linear"
    else:
        raise ValueError("Unrecognized classifier architecture in checkpoint")

    # Hidden width for MLP head.
    classifier_dim = None
    if head_type == "mlp":
        classifier_dim = int(state_dict["classifier.1.weight"].shape[0])

    # Projection head dimensionality.
    if "projection_head.0.weight" in keys:
        projection_dim = int(state_dict["projection_head.0.weight"].shape[0])
    else:
        projection_dim = 0

    # Number of output classes from final linear weight.
    if head_type == "linear":
        num_classes = int(state_dict["classifier.weight"].shape[0])
    elif head_type == "mlp":
        num_classes = int(state_dict["classifier.2.weight"].shape[0])
    else:  # res1dcnn
        num_classes = int(state_dict["classifier.1.fc.weight"].shape[0])

    return {
        "append_input_stats": append_input_stats,
        "input_dim": input_dim,
        "head_type": head_type,
        "classifier_dim": classifier_dim if classifier_dim is not None else 128,
        "projection_dim": projection_dim,
        "num_classes": num_classes,
    }

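For illustration, a sketch of probing a checkpoint's signature; the path below is hypothetical, any expert checkpoint produced by this repo is assumed to work:

```python
import torch

state = torch.load("models/LTE_models/expert_best.pth", map_location="cpu")  # illustrative path
sig = infer_expert_signature(state)
print(sig["head_type"], sig["num_classes"], sig["append_input_stats"])
```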
def load_expert_model(
    spec: ExpertSpec,
    stats_root: Path,
    device: torch.device,
) -> Tuple[str, nn.Module, int]:
    raw_state = torch.load(spec.checkpoint, map_location="cpu")
    if any(k.startswith("module.") for k in raw_state):
        raw_state = {k.replace("module.", "", 1): v for k, v in raw_state.items()}

    signature = infer_expert_signature(raw_state)

    stats_path = spec.stats_path
    if stats_path is None:
        stats_path = discover_stats_path(spec.comm, stats_root)
    stats = load_dataset_stats(stats_path)

    model = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
    backbone_state = {
        k.split("backbone.", 1)[1]: v
        for k, v in raw_state.items()
        if k.startswith("backbone.")
    }
    model.load_state_dict(backbone_state, strict=False)

    classifier = LWMClassifierMinimalAdapter(
        backbone=model,
        num_classes=int(signature["num_classes"]),
        classifier_dim=int(signature["classifier_dim"]),
        head_type=str(signature["head_type"]),
        append_input_stats=bool(signature["append_input_stats"]),
        projection_dim=int(signature["projection_dim"]),
        normalization_stats=stats,
    )
    classifier.load_state_dict(raw_state, strict=True)
    classifier.eval()
    classifier.to(device)
    for param in classifier.parameters():
        param.requires_grad_(False)
    return spec.comm, classifier, int(signature["num_classes"])

class LWMClassifierMinimalAdapter(nn.Module):
    """Thin wrapper matching task2.mobility_utils.LWMClassifierMinimal."""

    def __init__(
        self,
        *,
        backbone: nn.Module,
        num_classes: int,
        classifier_dim: int,
        head_type: str,
        append_input_stats: bool,
        projection_dim: int,
        normalization_stats: Mapping[str, float | str],
    ) -> None:
        super().__init__()
        from task2.mobility_utils import LWMClassifierMinimal  # local import to avoid cycle

        self.inner = LWMClassifierMinimal(
            backbone=backbone,
            num_classes=num_classes,
            classifier_dim=classifier_dim,
            dropout=0.0,
            trainable_layers=0,
            projection_dim=projection_dim,
            append_input_stats=append_input_stats,
            normalization_stats=normalization_stats,
            head_type=head_type,
        )

    def forward(self, specs: torch.Tensor) -> torch.Tensor:
        return self.inner(specs)

@torch.no_grad()
def evaluate_router(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, float, np.ndarray, np.ndarray]:
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    y_true: List[int] = []
    y_pred: List[int] = []

    for specs, comm_labels, _ in loader:
        specs = specs.to(device, non_blocking=True)
        comm_labels = torch.as_tensor(comm_labels, device=device)

        logits = model(specs)
        loss = criterion(logits, comm_labels)
        preds = logits.argmax(dim=1)
        total_loss += loss.item() * specs.size(0)
        correct += (preds == comm_labels).sum().item()
        seen += specs.size(0)
        y_true.extend(comm_labels.detach().cpu().tolist())
        y_pred.extend(preds.detach().cpu().tolist())

    avg_loss = total_loss / max(seen, 1)
    acc = correct / max(seen, 1)
    return avg_loss, acc, np.array(y_true, dtype=np.int64), np.array(y_pred, dtype=np.int64)

def compute_confusion(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> np.ndarray:
    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    for true, pred in zip(y_true, y_pred):
        if 0 <= true < num_classes and 0 <= pred < num_classes:
            matrix[true, pred] += 1
    return matrix

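A toy example of the resulting matrix (rows are true labels, columns are predictions):

```python
import numpy as np

y_true = np.array([0, 0, 1, 2])
y_pred = np.array([0, 1, 1, 2])
compute_confusion(y_true, y_pred, num_classes=3)
# array([[1, 1, 0],
#        [0, 1, 0],
#        [0, 0, 1]])
```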
@torch.no_grad()
def evaluate_routing(
    router: nn.Module,
    experts: Mapping[int, Tuple[str, nn.Module]],
    loader: DataLoader,
    *,
    num_comm: int,
    num_task_classes: int,
    device: torch.device,
    routing_mode: str,
    routing_topk: int,
) -> Dict[str, object]:
    router.eval()
    for _, model in experts.values():
        model.eval()

    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total = 0
    correct_router = 0
    correct_task = 0

    confusion_router = np.zeros((num_comm, num_comm), dtype=np.int64)
    confusion_task = np.zeros((num_task_classes, num_task_classes), dtype=np.int64)
    coverage = Counter()  # type: ignore[type-arg]

    for specs, comm_labels, task_labels in loader:
        specs = specs.to(device, non_blocking=True)
        comm_labels = torch.as_tensor(comm_labels, device=device)
        task_labels = torch.as_tensor(task_labels, device=device)

        logits = router(specs)
        loss = criterion(logits, comm_labels)
        probs = torch.softmax(logits, dim=1)
        router_pred = probs.argmax(dim=1)

        batch = specs.size(0)
        total_loss += loss.item() * batch
        total += batch
        correct_router += (router_pred == comm_labels).sum().item()

        confusion_router += compute_confusion(
            comm_labels.detach().cpu().numpy(),
            router_pred.detach().cpu().numpy(),
            num_comm,
        )

        if not experts:
            continue

        weights = torch.zeros_like(probs)
        if routing_mode == "hard":
            weights.scatter_(1, router_pred.unsqueeze(1), 1.0)
        elif routing_mode == "soft":
            weights = probs
        elif routing_mode == "topk":
            topk = max(1, min(routing_topk, num_comm))
            topk_vals, topk_indices = probs.topk(topk, dim=1)
            weights.zero_()
            weights.scatter_(1, topk_indices, topk_vals)
        else:
            raise ValueError(f"Unsupported routing mode: {routing_mode}")

        final_logits = torch.zeros(batch, num_task_classes, device=device)
        for comm_idx, (name, expert) in experts.items():
            weight_column = weights[:, comm_idx]
            if not torch.any(weight_column > 0):
                continue
            outputs = expert(specs)
            if outputs.size(1) != num_task_classes:
                raise ValueError(
                    f"Expert '{name}' returned {outputs.size(1)} classes, expected {num_task_classes}"
                )
            final_logits += weight_column.unsqueeze(1) * outputs
            coverage[name] += float(weight_column.sum().item())

        task_pred = final_logits.argmax(dim=1)
        correct_task += (task_pred == task_labels).sum().item()
        confusion_task += compute_confusion(
            task_labels.detach().cpu().numpy(),
            task_pred.detach().cpu().numpy(),
            num_task_classes,
        )

    metrics: Dict[str, object] = {
        "router_loss": total_loss / max(total, 1),
        "router_acc": correct_router / max(total, 1),
        "router_confusion": confusion_router.tolist(),
        "coverage": dict(coverage),
    }
    if experts:
        metrics["task_acc"] = correct_task / max(total, 1)
        metrics["task_confusion"] = confusion_task.tolist()
    return metrics

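To make the routing modes concrete, a small sketch of the `topk` weight construction for one sample (probability values are illustrative):

```python
import torch

probs = torch.tensor([[0.6, 0.3, 0.1]])          # router softmax over 3 comm types
vals, idx = probs.topk(2, dim=1)                 # keep the two most likely experts
weights = torch.zeros_like(probs).scatter_(1, idx, vals)
# weights == [[0.6, 0.3, 0.0]]; the blended logits are then
# 0.6 * expert_0(specs) + 0.3 * expert_1(specs).
```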
def modulation_labels_from_metadata(metadata: Sequence[SampleMetadata]) -> np.ndarray:
    labels: List[int] = []
    for meta in metadata:
        label = MODULATION_LABELS.get(meta.modulation.upper())
        if label is None:
            raise ValueError(f"Unknown modulation label in metadata: {meta.modulation}")
        labels.append(label)
    return np.array(labels, dtype=np.int64)

def snr_mobility_labels_from_metadata(
    metadata: Sequence[SampleMetadata],
    *,
    snr_order: Sequence[str],
    mobility_order: Sequence[str],
) -> Tuple[np.ndarray, Dict[int, Tuple[str, str]]]:
    combos: List[Tuple[str, str]] = []
    for snr in snr_order:
        for mobility in mobility_order:
            combos.append((snr, mobility))
    combo_to_idx = {combo: idx for idx, combo in enumerate(combos)}

    labels: List[int] = []
    for meta in metadata:
        combo = (meta.snr, meta.mobility)
        if combo not in combo_to_idx:
            raise ValueError(f"Sample combo {combo} not present in configured (snr, mobility) grid")
        labels.append(combo_to_idx[combo])
    mapping = {idx: combo for combo, idx in combo_to_idx.items()}
    return np.array(labels, dtype=np.int64), mapping

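For example, two SNRs crossed with two mobilities enumerate four task classes, in SNR-major order (folder names below are illustrative):

```python
snr_order = ["SNR0dB", "SNR10dB"]
mobility_order = ["static", "vehicular"]
combos = [(s, m) for s in snr_order for m in mobility_order]
# class 0 -> ("SNR0dB", "static"),  class 1 -> ("SNR0dB", "vehicular"),
# class 2 -> ("SNR10dB", "static"), class 3 -> ("SNR10dB", "vehicular")
```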
def prepare_dataset(
    *,
    data_root: Path,
    cities: Sequence[str],
    comm_types: Sequence[str],
    snrs: Optional[Sequence[str]],
    mobilities: Optional[Sequence[str]],
    modulations: Optional[Sequence[str]],
    fft_folders: Optional[Sequence[str]],
    max_samples_per_comm: int,
    max_per_combo: Optional[int],
    task: str,
    seed: int,
) -> Tuple[RoutedSpectrogramDataset, Dict[str, int], Optional[Dict[int, Tuple[str, str]]]]:
    rng = np.random.default_rng(seed)
    specs_list: List[np.ndarray] = []
    comm_labels_list: List[int] = []
    metadata_list: List[SampleMetadata] = []
    comm_to_idx = {comm: idx for idx, comm in enumerate(comm_types)}

    for comm in comm_types:
        samples, metadata = collect_spectrograms_for_comm(
            data_root=data_root,
            cities=cities,
            comm=comm,
            snrs=snrs,
            mobilities=mobilities,
            modulations=modulations,
            fft_folders=fft_folders,
            max_samples=max_samples_per_comm,
            max_per_combo=max_per_combo,
            rng=rng,
        )
        specs_list.append(samples)
        metadata_list.extend(metadata)
        comm_labels_list.extend([comm_to_idx[comm]] * samples.shape[0])

    specs = np.concatenate(specs_list, axis=0)
    metadata = metadata_list
    comm_labels = np.array(comm_labels_list, dtype=np.int64)

    order = rng.permutation(specs.shape[0])
    specs = specs[order]
    comm_labels = comm_labels[order]
    metadata = [metadata[idx] for idx in order]

    normalized = normalize_per_sample(specs)

    if task == "modulation":
        task_labels = modulation_labels_from_metadata(metadata)
        mapping = None
    else:
        if snrs is None:
            snr_order = sorted({meta.snr for meta in metadata}, key=snr_sort_key)
        else:
            snr_order = [snr for snr in snrs if any(meta.snr == snr for meta in metadata)]
        if mobilities is None:
            mobility_order = sorted({meta.mobility for meta in metadata})
        else:
            mobility_order = [mob for mob in mobilities if any(meta.mobility == mob for meta in metadata)]
        task_labels, mapping = snr_mobility_labels_from_metadata(
            metadata,
            snr_order=snr_order,
            mobility_order=mobility_order,
        )

    dataset = RoutedSpectrogramDataset(normalized, comm_labels, task_labels, metadata)
    return dataset, comm_to_idx, mapping

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--data-root", type=Path, default=Path("spectrograms"), help="Root directory with spectrogram data")
    parser.add_argument("--cities", nargs="*", default=["city_1_losangeles"], help="City folders to include")
    parser.add_argument("--comm-types", nargs="*", default=["LTE", "WiFi", "5G"], help="Communication standards to model")
    parser.add_argument("--snrs", nargs="*", default=None, help="SNR folders to include")
    parser.add_argument("--mobilities", nargs="*", default=None, help="Mobility folders to include")
    parser.add_argument("--modulations", nargs="*", default=None, help="Modulation classes to include (default: all)")
    parser.add_argument("--fft-folders", nargs="*", default=None, help="Specific FFT/window folders to include")
    parser.add_argument("--task", choices=("modulation", "snr_mobility"), default="snr_mobility", help="Downstream task label")
    parser.add_argument("--max-samples-per-comm", type=int, default=6000, help="Maximum samples per communication profile")
    parser.add_argument("--max-per-combo", type=int, default=512, help="Cap per (modulation,SNR,mobility) combo (0=unbounded)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--routing-mode",
        choices=("hard", "soft", "topk"),
        default="hard",
        help="Routing strategy: hard (top-1), soft (probability-weighted), or topk (restricted soft) (default: %(default)s)",
    )
    parser.add_argument(
        "--routing-topk",
        type=int,
        default=2,
        help="When routing-mode=topk, number of experts to evaluate per sample (default: %(default)s)",
    )

    parser.add_argument("--train-ratio", type=float, default=0.7, help="Fraction of data for training")
    parser.add_argument("--val-ratio", type=float, default=0.15, help="Fraction of data for validation")
    parser.add_argument("--batch-size", type=int, default=128, help="Mini-batch size")
    parser.add_argument("--epochs", type=int, default=20, help="Training epochs")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--weight-decay", type=float, default=1e-4, help="Weight decay for AdamW")
    parser.add_argument("--dropout", type=float, default=0.1, help="Router dropout probability")
    parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers")
    parser.add_argument("--use-amp", action="store_true", help="Enable mixed precision for router training")
    parser.add_argument("--spec-augment", action="store_true", help="Apply SpecAugment to router inputs")
    parser.add_argument("--spec-augment-freq", type=int, default=12, help="Frequency mask width for SpecAugment")
    parser.add_argument("--spec-augment-time", type=int, default=16, help="Time mask width for SpecAugment")
    parser.add_argument("--spec-augment-prob", type=float, default=0.5, help="Probability to apply SpecAugment to a sample")

    parser.add_argument("--expert", action="append", default=[], help="Expert definition COMM=checkpoint[:stats_path]")
    parser.add_argument("--expert-stats-root", type=Path, default=Path("models"), help="Root to auto-discover dataset_stats.json")

    parser.add_argument("--output-dir", type=Path, default=Path("mixture/runs/top1_router"), help="Directory for logs and checkpoints")
    parser.add_argument("--save-router", action="store_true", help="Save best router state_dict to output directory")

    args = parser.parse_args()

    if args.max_per_combo is not None and args.max_per_combo < 0:
        parser.error("--max-per-combo must be >= 0")
    if args.spec_augment and not (0.0 <= args.spec_augment_prob <= 1.0):
        parser.error("--spec-augment-prob must be between 0 and 1")
    if args.max_samples_per_comm <= 0:
        parser.error("--max-samples-per-comm must be positive")
    if args.train_ratio <= 0 or args.val_ratio <= 0:
        parser.error("--train-ratio and --val-ratio must be positive")
    if args.train_ratio + args.val_ratio >= 1.0:
        parser.error("--train-ratio + --val-ratio must be < 1.0")

    return args

def maybe_apply_spec_augment(
    specs: torch.Tensor,
    *,
    enabled: bool,
    freq_width: int,
    time_width: int,
    prob: float,
) -> torch.Tensor:
    if not enabled:
        return specs
    from task1.train_mcs_models import apply_spec_augment

    return apply_spec_augment(
        specs,
        freq_mask_width=freq_width,
        time_mask_width=time_width,
        mask_prob=prob,
    )

def main() -> None:
    args = parse_args()
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    comm_types = [canonical_comm_name(comm) for comm in args.comm_types]
    dataset, comm_to_idx, combo_mapping = prepare_dataset(
        data_root=args.data_root.expanduser().resolve(),
        cities=args.cities,
        comm_types=comm_types,
        snrs=args.snrs,
        mobilities=args.mobilities,
        modulations=args.modulations,
        fft_folders=args.fft_folders,
        max_samples_per_comm=args.max_samples_per_comm,
        max_per_combo=args.max_per_combo,
        task=args.task,
        seed=args.seed,
    )
    num_comm = len(comm_types)
    num_task_classes = int(dataset.task_labels.max()) + 1

    train_idx, val_idx, test_idx = stratified_split(
        dataset.comm_labels.numpy(),
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
        seed=args.seed,
    )
    train_loader, val_loader, test_loader = build_dataloaders(
        dataset,
        train_idx=train_idx,
        val_idx=val_idx,
        test_idx=test_idx,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    router = RouterNet(num_comm=num_comm, dropout=args.dropout).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(router.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scaler = GradScaler(enabled=args.use_amp and device.type == "cuda")

    best_state: Optional[Dict[str, torch.Tensor]] = None
    best_val_acc = 0.0

    for epoch in range(1, args.epochs + 1):
        router.train()
        running_loss = 0.0
        running_correct = 0
        total = 0

        for specs, comm_labels, _ in train_loader:
            specs = specs.to(device, non_blocking=True)
            comm_labels = torch.as_tensor(comm_labels, device=device)
            specs_aug = maybe_apply_spec_augment(
                specs,
                enabled=args.spec_augment,
                freq_width=args.spec_augment_freq,
                time_width=args.spec_augment_time,
                prob=args.spec_augment_prob,
            )

            optimizer.zero_grad(set_to_none=True)
            context = autocast(device_type=device.type, enabled=scaler.is_enabled())
            with context:
                logits = router(specs_aug)
                loss = criterion(logits, comm_labels)
            if scaler.is_enabled():
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            preds = logits.argmax(dim=1)
            running_loss += loss.item() * specs.size(0)
            running_correct += (preds == comm_labels).sum().item()
            total += specs.size(0)

        train_loss = running_loss / max(total, 1)
        train_acc = running_correct / max(total, 1)

        val_loss, val_acc, y_true_val, y_pred_val = evaluate_router(router, val_loader, criterion, device)
        val_confusion = compute_confusion(y_true_val, y_pred_val, num_comm)

        print(
            f"[Epoch {epoch:02d}] train_loss={train_loss:.4f} "
            f"train_acc={train_acc:.3f} val_loss={val_loss:.4f} val_acc={val_acc:.3f}"
        )

        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            best_state = {k: v.detach().cpu() for k, v in router.state_dict().items()}
        print(f"[Epoch {epoch:02d}] Val confusion matrix:\n{val_confusion}")

    if best_state is None:
        best_state = {k: v.detach().cpu() for k, v in router.state_dict().items()}
    router.load_state_dict(best_state)

    output_dir = args.output_dir.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    experts: Dict[int, Tuple[str, nn.Module]] = {}
    expert_specs = parse_expert_definitions(args.expert)
    for comm, spec in expert_specs.items():
        comm_idx = comm_to_idx.get(comm)
        if comm_idx is None:
            print(f"[WARN] Expert for {comm} provided but communication not in dataset; skipping")
            continue
        name, model, out_classes = load_expert_model(
            spec,
            stats_root=args.expert_stats_root.expanduser().resolve(),
            device=device,
        )
        if out_classes != num_task_classes:
            print(
                f"[WARN] Expert '{name}' outputs {out_classes} classes, "
                f"but dataset task expects {num_task_classes}. Skipping expert."
            )
            continue
        experts[comm_idx] = (name, model)

    test_metrics = evaluate_routing(
        router,
        experts,
        test_loader,
        num_comm=num_comm,
        num_task_classes=num_task_classes,
        device=device,
        routing_mode=args.routing_mode,
        routing_topk=args.routing_topk,
    )
    print("[RESULT] Test metrics:")
    print(json.dumps(test_metrics, indent=2))

    metrics_path = output_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as fh:
        json.dump(test_metrics, fh, indent=2)

    if combo_mapping is not None:
        mapping_path = output_dir / "snr_mobility_mapping.json"
        with open(mapping_path, "w", encoding="utf-8") as fh:
            json.dump({int(k): v for k, v in combo_mapping.items()}, fh, indent=2)

    if args.save_router:
        ckpt_path = output_dir / "router_top1_state_dict.pth"
        torch.save(best_state, ckpt_path)
        print(f"[INFO] Saved router checkpoint to {ckpt_path}")


if __name__ == "__main__":
    main()
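A hypothetical end-to-end invocation; the flags are the ones defined in `parse_args` above, but the checkpoint path is illustrative, not prescriptive:

```python
# python mixture/train_top1_router.py \
#     --comm-types LTE WiFi 5G --task snr_mobility \
#     --routing-mode topk --routing-topk 2 \
#     --expert LTE=models/LTE_models/best.pth \
#     --save-router
```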
pretraining/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (159 Bytes)

pretraining/__pycache__/pretrained_model.cpython-311.pyc
DELETED
Binary file (14.6 kB)

pretraining/pretrained_model.py
CHANGED

@@ -178,10 +178,3 @@ def lwm(*args, **kwargs) -> LWM:
     """Factory to preserve backward compatibility with older imports."""
 
     return LWM(*args, **kwargs)
-
-
-class PretrainedLWM(LWM):
-    """Alias retained for compatibility with existing inference scripts."""
-
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)

task1/plot_tsne.py
DELETED

@@ -1,802 +0,0 @@
#!/usr/bin/env python3
"""Visualise how strongly metadata drives the learned embedding space.

This script mirrors the functionality of ``task1/plot_mod_tsne.py`` but groups
spectrograms by their SNR folder name (e.g. ``SNR0dB``) instead of modulation.
It is useful for checking whether the self-supervised LWM backbone mostly
captures channel/SNR differences rather than modulation characteristics.

Pass ``--label-field modulation`` to reuse the same sampled spectrograms while
colouring and scoring them by their modulation folder instead of SNR. Use
``--label-field mobility`` to highlight link-level mobility categories when
present in the dataset tree. Saved figures automatically include the detected
communication profile (e.g. LTE/WiFi/5G) and label mode in the filename when
those suffixes are not already present.

Usage example:

```bash
python task1/plot_snr_tsne.py \
    --data-root spectrograms/city_1_losangeles/LTE \
    --snrs SNR-5dB,SNR0dB,SNR10dB,SNR15dB,SNR20dB,SNR25dB \
    --save-path task1/snr_separation_plot_latest.png
```

Shortcut presets:

```bash
python task1/plot_snr_tsne.py --WiFi --report-metrics
```
"""

from __future__ import annotations

import argparse
import glob
import pickle
import random
import re
from pathlib import Path
from collections import Counter, defaultdict
from typing import Dict, Iterable, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

from pretraining.pretrained_model import lwm as lwm_model
from utils import load_spectrogram_data  # support .mat and .pkl uniformly


DEFAULT_DATA_ROOT = "spectrograms/city_1_losangeles/LTE"
DEFAULT_MODELS_ROOT = "models/LTE_models"

PROFILE_PRESETS: Dict[str, Dict[str, str]] = {
    "LTE": {
        "data_root": DEFAULT_DATA_ROOT,
        "models_root": DEFAULT_MODELS_ROOT,
    },
    "WiFi": {
        "data_root": "spectrograms/city_1_losangeles/WiFi",
        "models_root": "models/WiFi_models",
    },
    "5G": {
        "data_root": "spectrograms/city_1_losangeles/5G",
        "models_root": "models/5G_models",
    },
}


def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    means = specs.mean(axis=(1, 2), keepdims=True)
    stds = specs.std(axis=(1, 2), keepdims=True)
    stds = np.maximum(stds, eps)
    return ((specs - means) / stds).astype(np.float32, copy=False)


def normalize_dataset(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    mean = float(specs.mean())
    std = float(specs.std())
    std = max(std, eps)
    return ((specs - mean) / std).astype(np.float32, copy=False)

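For reference, a minimal sketch contrasting the two normalisation modes on a toy batch (the 128x128 shape is assumed from the spectrograms used elsewhere in this script):

```python
import numpy as np

batch = np.random.rand(4, 128, 128).astype(np.float32)
per_sample = normalize_per_sample(batch)   # each spectrogram zero-mean/unit-std on its own
global_norm = normalize_dataset(batch)     # one mean/std shared across the whole batch
assert abs(float(per_sample[0].mean())) < 1e-4
```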
# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--data-root",
        default=DEFAULT_DATA_ROOT,
        help="Root directory containing modulation folders (default: %(default)s)",
    )
    parser.add_argument(
        "--modulation",
        default="all",
        help="Modulation folder to load (default: %(default)s)",
    )
    parser.add_argument(
        "--snrs",
        default="SNR-5dB,SNR0dB,SNR5dB,SNR10dB,SNR15dB,SNR20dB,SNR25dB",
        help=(
            "Comma-separated list of SNR folder names to include. Pass 'all' "
            "to include every SNR discovered under the modulation (default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--mobility",
        nargs="+",
        default=["all"],
        help=(
            "Mobility folder(s) to filter on. Pass 'all' to include every mobility "
            "(default: %(default)s). Multiple values can be provided either as a "
            "space-separated list (e.g. '--mobility vehicular pedestrian') or a "
            "comma-separated string."
        ),
    )
    parser.add_argument(
        "--fft-folder",
        default="all",
        help=(
            "FFT size folder name to use. Pass 'all' to include every FFT variant "
            "(default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--samples-per-snr",
        type=int,
        default=500,
        help="Maximum number of samples to draw for each SNR label",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for sampling and t-SNE",
    )
    parser.add_argument(
        "--pooling",
        choices=("mean", "cls"),
        default="mean",
        help="How to collapse token embeddings into a single vector",
    )
    parser.add_argument(
        "--save-path",
        default="task1/snr_separation_plot_latest.png",
        help="Location to save the generated figure (default: %(default)s)",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="Optional explicit checkpoint path; overrides automatic latest selection",
    )
    parser.add_argument(
        "--models-root",
        default=DEFAULT_MODELS_ROOT,
        help=(
            "Directory containing checkpoints. When --checkpoint is not given, "
            "the latest/best checkpoint inside this directory will be used "
            "(default: %(default)s)"
        ),
    )
    preset_group = parser.add_mutually_exclusive_group()
    preset_group.add_argument(
        "--profile",
        dest="profile",
        choices=tuple(PROFILE_PRESETS.keys()),
        help=(
            "Convenience preset that sets --data-root and --models-root when they "
            "are left at their defaults"
        ),
    )
    preset_group.add_argument(
        "--LTE",
        dest="profile",
        action="store_const",
        const="LTE",
        help="Shortcut for --profile LTE",
    )
    preset_group.add_argument(
        "--WiFi",
        dest="profile",
        action="store_const",
        const="WiFi",
        help="Shortcut for --profile WiFi",
    )
    preset_group.add_argument(
        "--5G",
        dest="profile",
        action="store_const",
        const="5G",
        help="Shortcut for --profile 5G",
    )
    parser.add_argument(
        "--report-metrics",
        action="store_true",
        help="Print clustering metrics (silhouette, 5-fold kNN accuracy)",
    )
    parser.add_argument(
        "--metrics-only",
        action="store_true",
        help="Exit after reporting metrics without running t-SNE or saving figures",
    )
    parser.add_argument(
        "--sampling-mode",
        choices=("first", "reservoir"),
        default="first",
        help="How to down-sample each class (default: first)",
    )
    parser.add_argument(
        "--complex-mode",
        choices=("auto", "magnitude", "interleaved"),
        default="auto",
        help=(
            "How to handle complex spectrograms: 'magnitude' (abs), 'interleaved' (real/imag interleaved along width), "
            "or 'auto' (prefer interleaved when complex). Real-valued inputs are unaffected."
        ),
    )
    parser.add_argument(
        "--label-field",
        choices=("snr", "modulation", "mobility"),
        default="snr",
        help="Choose which label to visualise and score (default: %(default)s)",
    )
    parser.add_argument(
        "--normalization",
        choices=("per-sample", "dataset"),
        default="per-sample",
        help="Normalisation strategy applied before embedding extraction",
    )
    return parser.parse_args()

def find_latest_checkpoint(models_root: Path) -> Path:
    """Return a checkpoint path under ``models_root``.

    Works with either a parent directory that contains multiple run folders,
    or directly with a single run directory containing ``*.pth`` files.
    Chooses the checkpoint with the lowest parsed validation value when
    available, else falls back to most-recent modification time.
    """

    if not models_root.exists():
        raise FileNotFoundError(f"Models root not found: {models_root}")

    if models_root.is_file():
        raise FileNotFoundError(f"Expected a directory, got file: {models_root}")

    # If the provided directory itself contains checkpoints, use it directly.
    checkpoints = list(models_root.glob("*.pth"))
    if not checkpoints:
        # Otherwise, look for subdirectories that contain checkpoints and ignore others (e.g., tensorboard)
        run_dirs = [p for p in models_root.iterdir() if p.is_dir()]
        candidate_runs = [d for d in run_dirs if any(d.glob("*.pth"))]
        if not candidate_runs:
            raise FileNotFoundError(
                f"No checkpoints found under {models_root} (no .pth files in this dir or its run subdirs)"
            )
        latest_run = max(candidate_runs, key=lambda p: p.stat().st_mtime)
        checkpoints = list(latest_run.glob("*.pth"))

    def parse_val_metric(path: Path) -> float | None:
        match = re.search(r"_val([0-9]+(?:\.[0-9]+)?)", path.name)
        if match:
            try:
                return float(match.group(1))
            except ValueError:
                return None
        return None

    parsed = [(parse_val_metric(p), p) for p in checkpoints]
    valid = [item for item in parsed if item[0] is not None]
    if valid:
        valid.sort(key=lambda item: item[0])
        return valid[0][1]

    # Fallback to most recent modification time
    return max(checkpoints, key=lambda p: p.stat().st_mtime)

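For reference, a hypothetical checkpoint filename and how the validation metric is parsed out of it (the `_val<number>` naming convention is assumed from the regex above):

```python
import re

name = "epoch20_val0.0123.pth"  # illustrative filename
match = re.search(r"_val([0-9]+(?:\.[0-9]+)?)", name)
print(float(match.group(1)))    # 0.0123 -- the lowest value is preferred
```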
def parse_snr_list(snr_argument: str | None) -> set[str] | None:
    if snr_argument is None or snr_argument.lower() == "all":
        return None
    values = [item.strip() for item in snr_argument.split(",") if item.strip()]
    return set(values)

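Two illustrative calls:

```python
parse_snr_list("SNR0dB, SNR10dB")  # -> {"SNR0dB", "SNR10dB"}
parse_snr_list("all")              # -> None, i.e. no SNR filtering
```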
def list_snr_samples(
    data_root: Path,
    modulation: str,
    allowed_snrs: set[str] | None,
    mobility_filter: set[str] | None,
    fft_folder: str,
    max_per_class: int,
    rng: random.Random,
    mode: str,
    complex_mode: str,
) -> Dict[str, List[Tuple[np.ndarray, str, str]]]:
    """Collect spectrogram samples grouped by SNR label.

    Supports both legacy PKL layout with a trailing 'spectrograms/' folder and
    MATLAB .mat bundles saved directly under the mobility folder.

    Returns: mapping from SNR label to list of tuples: (spec, modulation, mobility)
    """

    class_samples: Dict[str, List[Tuple[np.ndarray, str, str]]] = defaultdict(list)
    seen_counts: Dict[str, int] = defaultdict(int)

    # Search patterns:
    # - PKL under .../spectrograms/*.pkl
    # - MAT under .../spectrogram_*.mat
    patterns = [
        str(data_root / "**" / "spectrograms" / "*.pkl"),
        str(data_root / "**" / "spectrogram_*.mat"),
    ]

    mobility_set = {"static", "pedestrian", "vehicular"}

    def extract_tokens(rel_parts: Tuple[str, ...]) -> Tuple[str, str, str, str] | None:
        # Heuristic extraction to support both layouts
        # modulation: first path segment below data_root
        if not rel_parts:
            return None
        modulation_folder = rel_parts[0]

        # snr: first segment like SNR(-?)NdB
        snr_folder = next((p for p in rel_parts if re.match(r"^SNR-?\d+dB$", p)), None)
        if snr_folder is None:
            return None

        # mobility: one of known labels
        mobility_folder = next((p for p in rel_parts if p.lower() in mobility_set), None)
        if mobility_folder is None:
            return None

        # fft/window folder if present (PKL layout), else fallback for MAT
        fft_folder_name = next((p for p in rel_parts if p.startswith("win") or p.startswith("fft")), "fft_unknown")

        return modulation_folder, snr_folder, mobility_folder, fft_folder_name

    for pattern in patterns:
        for path_str in glob.iglob(pattern, recursive=True):
            path = Path(path_str)
            try:
                rel_parts = path.relative_to(data_root).parts
            except ValueError:
                continue

            tokens = extract_tokens(rel_parts)
            if tokens is None:
                continue
            modulation_folder, snr_folder, mobility_folder, fft_folder_name = tokens

            # Apply filters
            if modulation.lower() != "all" and modulation_folder != modulation:
                continue
            if allowed_snrs is not None and snr_folder not in allowed_snrs:
                continue
            if mobility_filter is not None and mobility_folder.lower() not in mobility_filter:
                continue
            if fft_folder != "all" and fft_folder_name != fft_folder:
                continue

            class_label = snr_folder
            if mode == "first" and len(class_samples[class_label]) >= max_per_class:
                continue

            # Load spectrogram data (supports .pkl and .mat)
            try:
                arr = load_spectrogram_data(str(path))
            except Exception as exc:  # pragma: no cover - I/O heavy
                print(f"[WARN] Failed to load {path}: {exc}")
                continue

            if not isinstance(arr, np.ndarray) or arr.size == 0:
                continue

            # If loaded spectrograms are complex, convert according to mode
            if np.iscomplexobj(arr):
                if complex_mode == "magnitude":
                    arr = np.abs(arr)
                else:
                    # Interleave real/imag parts along the width dimension
                    if arr.ndim == 4 and arr.shape[1] == 1:
                        arr = arr[:, 0]
                    if arr.ndim == 3:
                        real = arr.real.astype(np.float32, copy=False)
                        imag = arr.imag.astype(np.float32, copy=False)
                        n, h, w = real.shape
                        inter = np.empty((n, h, w * 2), dtype=np.float32)
                        inter[:, :, 0::2] = real
                        inter[:, :, 1::2] = imag
                        arr = inter
                    else:
                        # Fallback to magnitude for unsupported shapes
                        arr = np.abs(arr)

            # Normalize shapes:
            # - (N, H, W)
            # - (N, C, H, W) -> collapse channels via mean
            if arr.ndim == 4:
                # (N, C, H, W) -> (N, H, W)
                if arr.shape[1] > 1:
                    specs = arr.mean(axis=1)
                else:
                    specs = arr[:, 0]
            elif arr.ndim == 3:
                specs = arr
            elif arr.ndim == 2:
                specs = arr[None, ...]
            else:
                print(f"[WARN] Unexpected spectrogram shape in {path}: {arr.shape}")
                continue

            for spec in specs:
                sample = np.asarray(spec, dtype=np.float32)
                bucket = class_samples[class_label]

                if len(bucket) < max_per_class:
                    bucket.append((sample, modulation_folder, mobility_folder))
                    seen_counts[class_label] += 1
                elif mode == "reservoir":
                    seen_counts[class_label] += 1
                    j = rng.randint(0, seen_counts[class_label] - 1)
                    if j < max_per_class:
                        bucket[j] = (sample, modulation_folder, mobility_folder)
                else:  # mode == "first" and already full
                    break

    return class_samples

def sample_balanced_dataset(
    class_samples: Dict[str, List[Tuple[np.ndarray, str, str]]],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[str]]:
    """Stack the sampled spectrograms alongside SNR, modulation, and mobility labels."""

    features: List[np.ndarray] = []
    snr_labels: List[str] = []
    modulation_labels: List[str] = []
    mobility_labels: List[str] = []
    class_names = sorted(class_samples.keys())

    for class_name in class_names:
        samples = class_samples[class_name]
        if not samples:
            continue
        for sample, modulation_label, mobility_label in samples:
            features.append(sample)
            snr_labels.append(class_name)
            modulation_labels.append(modulation_label)
            mobility_labels.append(mobility_label)

    if not features:
        raise RuntimeError("No spectrogram samples collected for the specified filters")

    stacked = np.stack(features)  # [N, 128, 128]
    return (
        stacked,
        np.array(snr_labels),
        np.array(modulation_labels),
        np.array(mobility_labels),
        class_names,
    )

def unfold_patches_square(x: torch.Tensor, patch_size: int = 4) -> torch.Tensor:
    # Input shape: [B, H, W]; extracts (patch_size x patch_size) patches
    patches_h = x.unfold(1, patch_size, patch_size)
    patches = patches_h.unfold(2, patch_size, patch_size)
    return patches.contiguous().view(x.shape[0], -1, patch_size * patch_size)


def unfold_patches_rect(x: torch.Tensor, patch_rows: int = 4, patch_cols: int = 8) -> torch.Tensor:
    # Input shape: [B, H, W]; extracts (patch_rows x patch_cols) patches (for interleaved complex)
    patches_h = x.unfold(1, patch_rows, patch_rows)
    patches = patches_h.unfold(2, patch_cols, patch_cols)
    return patches.contiguous().view(x.shape[0], -1, patch_rows * patch_cols)


def extract_tokens(spec: np.ndarray, device: torch.device, interleaved: bool) -> torch.Tensor:
    tensor = torch.from_numpy(spec).unsqueeze(0).to(device)
    if interleaved:
        # Rectangular patches 4x8 to cover 4x4 complex bins (real+imag)
        return unfold_patches_rect(tensor, 4, 8)  # [1, 1024, 32]
    else:
        return unfold_patches_square(tensor, 4)  # [1, 1024, 16]

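A quick shape check for the two patching paths (input sizes assumed to match the 128x128 spectrograms, with interleaving doubling the width):

```python
import torch

x = torch.randn(2, 128, 128)               # [B, H, W] real-valued spectrograms
print(unfold_patches_square(x, 4).shape)   # torch.Size([2, 1024, 16])

y = torch.randn(2, 128, 256)               # interleaved real/imag doubles the width
print(unfold_patches_rect(y, 4, 8).shape)  # torch.Size([2, 1024, 32])
```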
def pool_embeddings(
    tokens: torch.Tensor,
    model: torch.nn.Module,
    pooling: str,
) -> np.ndarray:
    # Append CLS token (value 0.2) before passing through the transformer.
    cls_token = torch.full((tokens.size(0), 1, tokens.size(-1)), 0.2, device=tokens.device)
    inputs = torch.cat([cls_token, tokens], dim=1)  # [B, 1025, 16]

    with torch.no_grad():
        outputs = model(inputs)  # [B, 1025, 128]

    if pooling == "cls":
        pooled = outputs[:, 0]
    else:  # mean pooling across patch tokens (exclude CLS)
        pooled = outputs[:, 1:].mean(dim=1)

    return pooled.detach().cpu().numpy()

def sort_snr_labels(labels: List[str]) -> List[str]:
    """Sort SNR labels by numeric value instead of lexicographic order."""
    def extract_snr_value(label: str) -> float:
        """Extract numeric SNR value from label like 'SNR-5dB' -> -5.0"""
        import re
        match = re.search(r'SNR(-?\d+)dB', label)
        if match:
            return float(match.group(1))
        else:
            return float('inf')  # Put non-SNR labels at the end

    return sorted(labels, key=extract_snr_value)

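A quick illustration of why the numeric key matters:

```python
sort_snr_labels(["SNR5dB", "SNR10dB", "SNR-5dB"])
# -> ["SNR-5dB", "SNR5dB", "SNR10dB"]; plain sorted() would yield
# ["SNR-5dB", "SNR10dB", "SNR5dB"] because "1" < "5" lexicographically.
```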
def run_tsne(x: np.ndarray, labels: np.ndarray, title: str, ax: plt.Axes) -> None:
|
| 531 |
-
scaler = StandardScaler()
|
| 532 |
-
x_scaled = scaler.fit_transform(x)
|
| 533 |
-
# Guard against NaN/Inf from upstream (normalisation or model outputs)
|
| 534 |
-
x_scaled = np.nan_to_num(x_scaled, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
|
| 535 |
-
# Use a safe perplexity relative to sample count (sklearn requirement: < n_samples).
|
| 536 |
-
max_perplexity = max(5, min(30, len(x_scaled) // 10))
|
| 537 |
-
perplexity = min(max_perplexity, len(x_scaled) - 1)
|
| 538 |
-
perplexity = max(perplexity, 5)
|
| 539 |
-
|
| 540 |
-
tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
|
| 541 |
-
embedding = tsne.fit_transform(x_scaled)
|
| 542 |
-
|
| 543 |
-
class_names = sort_snr_labels(list(np.unique(labels)))
|
| 544 |
-
colors = plt.cm.Set3(np.linspace(0, 1, len(class_names)))
|
| 545 |
-
for color, class_name in zip(colors, class_names):
|
| 546 |
-
mask = labels == class_name
|
| 547 |
-
ax.scatter(embedding[mask, 0], embedding[mask, 1], c=[color], s=18, alpha=0.7, label=class_name)
|
| 548 |
-
|
| 549 |
-
# ax.set_title(title, fontsize=14, fontweight="bold") # Title removed for paper
|
| 550 |
-
ax.set_xlabel("t-SNE Component 1", fontsize=16)
|
| 551 |
-
ax.set_ylabel("t-SNE Component 2", fontsize=16)
|
| 552 |
-
ax.tick_params(labelsize=14) # Increase tick label size
|
| 553 |
-
ax.grid(True, alpha=0.3)
|
| 554 |
-
ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=12)
|
| 555 |
-
|
| 556 |
-
|
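Note: the perplexity clamp above, spot-checked in isolation for a few sample counts (self-contained; not part of the repository).

def safe_perplexity(n_samples: int) -> int:
    perplexity = max(5, min(30, n_samples // 10))
    return min(perplexity, n_samples - 1)

assert [safe_perplexity(n) for n in (10, 50, 300, 5000)] == [5, 5, 30, 30]
assert all(safe_perplexity(n) < n for n in range(7, 1000))  # always < n_samples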
-def compute_metrics(name: str, features: np.ndarray, labels: np.ndarray) -> None:
-    if len(np.unique(labels)) < 2:
-        print(f"[METRIC] {name}: skipped (only one class present)")
-        return
-
-    scaler = StandardScaler()
-    features_scaled = scaler.fit_transform(features)
-
-    silhouette = silhouette_score(features_scaled, labels)
-
-    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
-    scores: List[float] = []
-    for train_idx, test_idx in skf.split(features_scaled, labels):
-        clf = KNeighborsClassifier(n_neighbors=5)
-        clf.fit(features_scaled[train_idx], labels[train_idx])
-        scores.append(clf.score(features_scaled[test_idx], labels[test_idx]))
-
-    mean_acc = float(np.mean(scores))
-    std_acc = float(np.std(scores))
-    print(
-        f"[METRIC] {name}: silhouette={silhouette:.3f}, "
-        f"5-NN accuracy={mean_acc:.3f} ± {std_acc:.3f}"
-    )
-
-
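Note: a self-contained version of the two embedding-quality metrics above (silhouette plus stratified 5-fold 5-NN accuracy), demonstrated on toy blobs rather than spectrogram embeddings.

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

features, labels = make_blobs(n_samples=200, centers=3, random_state=42)
features = StandardScaler().fit_transform(features)
print("silhouette:", silhouette_score(features, labels))

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = []
for train_idx, test_idx in skf.split(features, labels):
    clf = KNeighborsClassifier(n_neighbors=5).fit(features[train_idx], labels[train_idx])
    scores.append(clf.score(features[test_idx], labels[test_idx]))
print(f"5-NN accuracy: {np.mean(scores):.3f} ± {np.std(scores):.3f}")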
-# ---------------------------------------------------------------------------
-# Main execution
-# ---------------------------------------------------------------------------
-
-
-def main() -> None:
-    args = parse_args()
-
-    if args.profile:
-        preset = PROFILE_PRESETS.get(args.profile)
-        if not preset:
-            raise ValueError(f"Unknown profile requested: {args.profile}")
-        if args.data_root == DEFAULT_DATA_ROOT:
-            args.data_root = preset["data_root"]
-        if args.models_root == DEFAULT_MODELS_ROOT:
-            args.models_root = preset["models_root"]
-
-    if args.profile:
-        print(f"[INFO] Profile preset active: {args.profile}")
-
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-
-    data_root = Path(args.data_root)
-    if not data_root.exists():
-        raise FileNotFoundError(f"Data root not found: {data_root}")
-
-    allowed_snrs = parse_snr_list(args.snrs)
-
-    mobility_filter: set[str] | None = None
-    if args.mobility:
-        mobility_values: List[str] = []
-        for value in args.mobility:
-            mobility_values.extend([item.strip() for item in value.split(",") if item.strip()])
-        mobility_values = [value for value in mobility_values if value]
-        if mobility_values and not (len(mobility_values) == 1 and mobility_values[0].lower() == "all"):
-            mobility_filter = {value.lower() for value in mobility_values}
-            print(
-                "[INFO] Mobility filter active: "
-                + ", ".join(sorted(mobility_filter))
-            )
-
-    class_samples = list_snr_samples(
-        data_root,
-        args.modulation,
-        allowed_snrs,
-        mobility_filter,
-        args.fft_folder,
-        args.samples_per_snr,
-        random,
-        args.sampling_mode,
-        args.complex_mode,
-    )
-    samples, snr_labels, modulation_labels, mobility_labels, _ = sample_balanced_dataset(class_samples)
-
-    if args.label_field == "snr":
-        labels = snr_labels
-        label_name = "SNR"
-        label_display = "SNR"
-    elif args.label_field == "modulation":
-        labels = modulation_labels
-        label_name = "modulation"
-        label_display = "Modulation"
-    else:  # mobility
-        labels = mobility_labels
-        label_name = "mobility"
-        label_display = "Mobility"
-
-    unique_labels = np.unique(labels)
-    print(
-        f"[INFO] Loaded {samples.shape[0]} spectrograms across {len(unique_labels)} {label_name} buckets"
-    )
-    class_counts = Counter(labels)
-    print(f"[INFO] Samples per {label_name}:")
-    for name, count in sorted(class_counts.items()):
-        print(f"  {name}: {count}")
-
-    if args.label_field != "snr":
-        snr_counts = Counter(snr_labels)
-        print("[INFO] SNR distribution (sampling classes):")
-        for name, count in sorted(snr_counts.items()):
-            print(f"  {name}: {count}")
-        if args.label_field == "mobility":
-            modulation_counts = Counter(modulation_labels)
-            print("[INFO] Modulation distribution:")
-            for name, count in sorted(modulation_counts.items()):
-                print(f"  {name}: {count}")
-
-    normalization_mode = args.normalization
-    if normalization_mode == "per-sample":
-        normalized_samples = normalize_per_sample(samples)
-    else:
-        normalized_samples = normalize_dataset(samples)
-    print(f"[INFO] Normalisation mode: {normalization_mode}")
-
-    # Flatten spectrograms (after optional normalization) for the raw t-SNE view.
-    raw_vectors = normalized_samples.reshape(normalized_samples.shape[0], -1)
-
-    # Prepare LWM model and embeddings for the right subplot.
-    if args.checkpoint:
-        checkpoint_path = Path(args.checkpoint)
-        if not checkpoint_path.exists():
-            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
-    else:
-        checkpoint_path = find_latest_checkpoint(Path(args.models_root))
-    print(f"[INFO] Using checkpoint: {checkpoint_path}")
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(f"[INFO] Using device: {device}")
-    print(f"[INFO] Pooling strategy: {args.pooling}")
-    # Determine complex handling strategy for model/patching
-    use_interleaved = False
-    if args.complex_mode == "interleaved":
-        use_interleaved = True
-    elif args.complex_mode == "auto":
-        # Heuristic: if any sample contains width > 128, assume interleaved (e.g., 128x256)
-        sample_shape = tuple(normalized_samples.shape[1:])
-        if len(sample_shape) == 2 and sample_shape[1] > 128:
-            use_interleaved = True
-
-    element_length = 32 if use_interleaved else 16
-
-    model = lwm_model(element_length=element_length, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
-    state_dict = torch.load(checkpoint_path, map_location=device)
-    if any(k.startswith("module.") for k in state_dict):
-        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
-    try:
-        model.load_state_dict(state_dict, strict=False)
-    except RuntimeError as e:
-        msg = str(e)
-        # Fallback: checkpoint expects element_length=16 (magnitude), but we constructed 32 (interleaved)
-        mismatch16 = "[128, 16]" in msg or "[16]" in msg
-        mismatch32 = "[128, 32]" in msg or "[32]" in msg
-        if mismatch16 and not mismatch32:
-            print("[WARN] Checkpoint expects token dimension 16. Falling back to magnitude embedding.")
-            use_interleaved = False
-            element_length = 16
-            # Recreate model and reload
-            model = lwm_model(element_length=element_length, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
-            model.load_state_dict(state_dict, strict=False)
-        else:
-            raise
-    model = model.to(device).eval()
-
-    def collapse_interleaved_to_magnitude(spec: np.ndarray) -> np.ndarray:
-        # spec: [H, 2W] with interleaved real/imag along width -> [H, W] magnitude
-        h, w2 = spec.shape
-        if w2 % 2 != 0:
-            return spec  # cannot collapse; return as-is
-        real = spec[:, 0::2]
-        imag = spec[:, 1::2]
-        return np.sqrt(np.maximum(real * real + imag * imag, 0.0, dtype=np.float32))
-
-    # If we fell back to magnitude (use_interleaved False) but inputs are interleaved, collapse for embeddings only
-    embed_inputs = normalized_samples
-    if not use_interleaved and normalized_samples.shape[2] > 128:
-        collapsed = []
-        for spec in normalized_samples:
-            collapsed.append(collapse_interleaved_to_magnitude(spec))
-        embed_inputs = np.stack(collapsed).astype(np.float32, copy=False)
-
-    embeddings: List[np.ndarray] = []
-    for spec in embed_inputs:
-        tokens = extract_tokens(spec, device, interleaved=use_interleaved)
-        embedding = pool_embeddings(tokens, model, args.pooling)
-        embeddings.append(embedding.squeeze(0))
-
-    embeddings_np = np.vstack(embeddings)
-    print(f"[INFO] Generated embeddings with shape {embeddings_np.shape}")
-
-    if args.report_metrics:
-        compute_metrics("Raw spectrogram", raw_vectors, labels)
-        pool_label = "LWM mean" if args.pooling == "mean" else "LWM CLS"
-        compute_metrics(pool_label, embeddings_np, labels)
-        if args.metrics_only:
-            return
-
-    # Plot results (two subplots matching the original figure format).
-    fig, axes = plt.subplots(1, 2, figsize=(18, 7))
-    raw_title = f"Raw Spectrogram t-SNE (by {label_display})"
-    pooling_label = "Mean Pool" if args.pooling == "mean" else "CLS Token"
-    embedding_title = f"LWM Embedding t-SNE ({pooling_label}, by {label_display})"
-    run_tsne(raw_vectors, labels, raw_title, axes[0])
-    run_tsne(embeddings_np, labels, embedding_title, axes[1])
-
-    fig.tight_layout()
-    save_path = Path(args.save_path)
-
-    communication_tag: str | None = None
-    if args.profile:
-        communication_tag = args.profile
-    else:
-        root_name = Path(args.data_root).name
-        if root_name:
-            communication_tag = root_name
-
-    def ensure_suffix(stem: str, suffix: str) -> str:
-        return stem if stem.endswith(suffix) else f"{stem}_{suffix}"
-
-    updated_stem = save_path.stem
-    if communication_tag:
-        updated_stem = ensure_suffix(updated_stem, communication_tag)
-    if args.label_field != "snr":
-        label_suffix = f"by_{args.label_field}"
-        updated_stem = ensure_suffix(updated_stem, label_suffix)
-
-    if updated_stem != save_path.stem:
-        save_path = save_path.with_name(f"{updated_stem}{save_path.suffix}")
-    save_path.parent.mkdir(parents=True, exist_ok=True)
-    plt.savefig(save_path, dpi=600, bbox_inches="tight")
-    print(f"[INFO] Figure saved to {save_path}")
-
-    # Also save PDF version for paper (vector format, no resolution limit)
-    pdf_path = save_path.with_suffix('.pdf')
-    plt.savefig(pdf_path, format='pdf', bbox_inches="tight")
-    print(f"[INFO] PDF version saved to {pdf_path}")
-
-
-if __name__ == "__main__":
-    main()
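Note: the save-path suffix logic at the end of main(), traced on a hypothetical example (self-contained; the file names below are made up, not taken from the repository).

from pathlib import Path

def ensure_suffix(stem: str, suffix: str) -> str:
    return stem if stem.endswith(suffix) else f"{stem}_{suffix}"

save_path = Path("figures/tsne.png")              # hypothetical save path
stem = ensure_suffix(save_path.stem, "cityA")     # communication tag appended once
stem = ensure_suffix(stem, "by_mobility")         # non-SNR label field appended once
save_path = save_path.with_name(f"{stem}{save_path.suffix}")
assert save_path == Path("figures/tsne_cityA_by_mobility.png")
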
task1/train_mcs_models.py
DELETED

The diff for this file is too large to render. See the raw diff.

task2/mobility_utils.py
DELETED
@@ -1,414 +0,0 @@
-#!/usr/bin/env python3
-"""Shared mobility-classification utilities used across Task 2 helpers.
-
-This module provides the lightweight LWM classifier head plus supporting
-sampling and normalization helpers that were previously bundled inside the
-stand-alone mobility fine-tuning scripts. They remain available so that
-benchmarking, router training, and visualisation pipelines can reuse the same
-logic without depending on a separate CLI.
-"""
-
-from __future__ import annotations
-
-import glob
-import json
-from collections import defaultdict
-from pathlib import Path
-from typing import Any, Dict, Iterable, List, Sequence, Tuple
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from pretraining.pretrained_model import lwm as lwm_model
-from task1.train_mcs_models import (
-    _extract_metadata,
-    identify_modulation,
-    load_all_samples,
-)
-
-MOBILITY_LABELS = ["static", "pedestrian", "vehicular"]
-BINARY_MOBILITY_LABELS = ["vehicular", "pedestrian"]
-
-
-def load_dataset_stats(models_root: Path) -> Dict[str, float | str]:
-    """Load dataset statistics (mean/std/normalization mode) from a models directory."""
-    stats_path = models_root / "dataset_stats.json"
-    if not stats_path.exists():
-        print(
-            f"[WARN] dataset_stats.json not found under {models_root}; "
-            "falling back to per-sample normalization with mean=0/std=1.",
-            flush=True,
-        )
-        return {"mean": 0.0, "std": 1.0, "normalization": "per_sample"}
-    with open(stats_path, "r", encoding="utf-8") as f:
-        stats = json.load(f)
-    mean = float(stats.get("mean", 0.0))
-    std = float(stats.get("std", 1.0))
-    if std == 0.0:
-        std = 1.0
-    normalization = str(stats.get("normalization", stats.get("mode", "dataset")))
-    return {
-        "mean": mean,
-        "std": std,
-        "normalization": normalization,
-    }
-
-
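Note: a quick round-trip through the loader above, showing the dataset_stats.json layout it expects. The numeric values are illustrative, and the import assumes a checkout from before this commit (the module is deleted here).

import json
import tempfile
from pathlib import Path

from task2.mobility_utils import load_dataset_stats

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    stats = {"mean": -42.7, "std": 11.3, "normalization": "dataset"}  # illustrative values
    (root / "dataset_stats.json").write_text(json.dumps(stats))
    print(load_dataset_stats(root))  # {'mean': -42.7, 'std': 11.3, 'normalization': 'dataset'}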
-def gather_controlled_groups(
-    data_root: Path,
-    cities: Sequence[str],
-    comm: str,
-    mobilities: Sequence[str],
-    snrs: Sequence[str] | None,
-    fft_whitelist: Sequence[str] | None,
-) -> Dict[Tuple[str, str, str, str, str], Dict[str, List[str]]]:
-    """Group spectrogram paths by (city, modulation, rate, SNR, FFT) while balancing mobilities."""
-    groups: Dict[Tuple[str, str, str, str, str], Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list))
-    mobility_set = set(mobilities)
-    snr_set = set(snrs) if snrs else None
-    fft_set = set(fft_whitelist) if fft_whitelist else None
-
-    for city in cities:
-        base = data_root / city / comm
-        if not base.exists():
-            continue
-        pattern = str(base / "**" / "spectrograms" / "*.pkl")
-        for path_str in glob.iglob(pattern, recursive=True):
-            path = Path(path_str)
-            rate, snr, mobility = _extract_metadata(path.parts)
-            if mobility not in mobility_set:
-                continue
-            if snr_set is not None and snr not in snr_set:
-                continue
-            fft = next((part for part in path.parts if part.startswith("win")), "fft_unknown")
-            if fft_set is not None and fft not in fft_set:
-                continue
-            _, modulation = identify_modulation(path_str)
-            if modulation is None:
-                continue
-            key = (city, modulation, rate, snr, fft)
-            groups[key][mobility].append(str(path))
-    return {key: dict(mob_map) for key, mob_map in groups.items()}
-
-
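Note: a self-contained sketch of the grouping contract used above. Every (city, modulation, rate, SNR, FFT) key maps mobility to a list of spectrogram paths, and only keys that cover every requested mobility survive the balancing pass below. The key and paths here are made up.

from collections import defaultdict

groups = defaultdict(lambda: defaultdict(list))
key = ("cityA", "qpsk", "rate1", "SNR10dB", "win256")          # illustrative key
groups[key]["static"].append("cityA/comm/win256/spectrograms/a.pkl")
groups[key]["vehicular"].append("cityA/comm/win256/spectrograms/b.pkl")

mobilities = ["static", "pedestrian", "vehicular"]
usable = [k for k, mob_map in groups.items() if all(m in mob_map for m in mobilities)]
assert usable == []  # "pedestrian" is missing, so this key would be skipped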
-def _collect_balanced_arrays(
-    groups: Dict[Tuple[str, str, str, str, str], Dict[str, List[str]]],
-    mobilities: Sequence[str],
-    max_per_config: int,
-    rng: np.random.Generator,
-) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
-    """Load spectrogram arrays with per-configuration balance across mobilities."""
-    features: List[np.ndarray] = []
-    labels: List[np.ndarray] = []
-    mobility_to_idx = {mob: idx for idx, mob in enumerate(mobilities)}
-    per_mobility_totals = {mob: 0 for mob in mobilities}
-    matched_configs = 0
-    preview_configs: List[Tuple[str, str, str, str, str]] = []
-
-    for key, mobility_map in groups.items():
-        if not all(mob in mobility_map for mob in mobilities):
-            continue
-
-        cached_arrays: Dict[str, np.ndarray] = {}
-        per_mobility_counts: List[int] = []
-        for mobility in mobilities:
-            paths = mobility_map[mobility]
-            collected: List[np.ndarray] = []
-            for path in paths:
-                arr = load_all_samples(path)
-                if arr.size == 0:
-                    continue
-                collected.append(arr)
-            if not collected:
-                cached_arrays = {}
-                break
-            stacked = np.concatenate(collected, axis=0)
-            cached_arrays[mobility] = stacked
-            per_mobility_counts.append(stacked.shape[0])
-
-        if len(cached_arrays) != len(mobilities):
-            continue
-
-        limit = min(per_mobility_counts)
-        if max_per_config > 0:
-            limit = min(limit, max_per_config)
-        if limit == 0:
-            continue
-
-        for mobility in mobilities:
-            arr = cached_arrays[mobility]
-            if arr.shape[0] > limit:
-                indices = rng.permutation(arr.shape[0])[:limit]
-                arr = arr[indices]
-            features.append(arr)
-            labels.append(np.full(arr.shape[0], mobility_to_idx[mobility], dtype=np.int64))
-            per_mobility_totals[mobility] += arr.shape[0]
-
-        if matched_configs < 5:
-            preview_configs.append(key)
-        matched_configs += 1
-
-    if not features:
-        return (
-            np.empty((0, 128, 128), dtype=np.float32),
-            np.empty((0,), dtype=np.int64),
-            {"per_mobility": per_mobility_totals, "matched_configs": matched_configs, "preview_configs": preview_configs},
-        )
-
-    stacked_features = np.concatenate(features, axis=0).astype(np.float32, copy=False)
-    stacked_labels = np.concatenate(labels, axis=0).astype(np.int64, copy=False)
-    return stacked_features, stacked_labels, {
-        "per_mobility": per_mobility_totals,
-        "matched_configs": matched_configs,
-        "preview_configs": preview_configs,
-    }
-
-
-class ResidualBlock1D(nn.Module):
-    """1D residual block used by the Res1DCNN classification head."""
-
-    def __init__(self, in_channels: int, out_channels: int) -> None:
-        super().__init__()
-        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
-        self.bn1 = nn.BatchNorm1d(out_channels)
-        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
-        self.bn2 = nn.BatchNorm1d(out_channels)
-        self.shortcut = nn.Sequential()
-        if in_channels != out_channels:
-            self.shortcut = nn.Sequential(
-                nn.Conv1d(in_channels, out_channels, kernel_size=1),
-                nn.BatchNorm1d(out_channels),
-            )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        residual = x
-        x = F.relu(self.bn1(self.conv1(x)))
-        x = self.bn2(self.conv2(x))
-        x += self.shortcut(residual)
-        return F.relu(x)
-
-
-class Res1DCNNHead(nn.Module):
-    """Compact ResNet-style 1D head for classifying 128-d embeddings."""
-
-    def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.5) -> None:
-        super().__init__()
-        hidden_dim = 64
-        self.conv1 = nn.Conv1d(1, hidden_dim, kernel_size=3, padding=1)
-        self.bn1 = nn.BatchNorm1d(hidden_dim)
-        self.res_block = ResidualBlock1D(hidden_dim, hidden_dim)
-        self.fc = nn.Linear(hidden_dim, num_classes)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = x.unsqueeze(1)
-        x = F.relu(self.bn1(self.conv1(x)))
-        x = self.res_block(x)
-        x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
-        x = self.dropout(x)
-        return self.fc(x)
-
-
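Note: a shape check for the head above; the import assumes a checkout from before this commit. Each 128-d embedding is treated as a 1-channel sequence and pooled down to class logits.

import torch

from task2.mobility_utils import Res1DCNNHead

head = Res1DCNNHead(input_dim=128, num_classes=3, dropout=0.5).eval()  # eval: use BN running stats
logits = head(torch.randn(4, 128))  # [4, 128] embeddings -> [4, 3] logits
assert logits.shape == (4, 3)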
-class LWMClassifierMinimal(nn.Module):
-    """LWM backbone wrapper with configurable classifier and optional projection head."""
-
-    def __init__(
-        self,
-        backbone: nn.Module,
-        num_classes: int,
-        classifier_dim: int,
-        dropout: float,
-        trainable_layers: int,
-        projection_dim: int,
-        append_input_stats: bool,
-        normalization_stats: Dict[str, object] | None,
-        head_type: str = "mlp",
-    ) -> None:
-        super().__init__()
-        self.backbone = backbone
-        self.patch_size = 4
-        self.unfold = nn.Unfold(kernel_size=self.patch_size, stride=self.patch_size)
-        self.head_type = head_type
-
-        self.append_input_stats = bool(append_input_stats)
-        stats_info = normalization_stats or {}
-        self.normalization_mode = str(stats_info.get("normalization", "dataset")).lower()
-        self.dataset_mean = float(stats_info.get("mean", 0.0))
-        self.dataset_std = float(stats_info.get("std", 1.0))
-        if abs(self.dataset_std) < 1e-6:
-            self.dataset_std = 1e-6
-        base_dim = 128
-        stats_dim = 2 if self.append_input_stats else 0
-        input_dim = base_dim + stats_dim
-
-        classifier_dim = max(32, int(classifier_dim))
-        dropout = max(0.0, float(dropout))
-
-        if head_type == "linear":
-            self.classifier = nn.Sequential(
-                nn.LayerNorm(input_dim),
-                nn.Linear(input_dim, num_classes),
-            )
-        elif head_type == "res1dcnn":
-            self.classifier = nn.Sequential(
-                nn.LayerNorm(input_dim),
-                Res1DCNNHead(input_dim, num_classes, dropout=dropout),
-            )
-        else:
-            head_layers: List[nn.Module] = [
-                nn.LayerNorm(input_dim),
-                nn.Linear(input_dim, classifier_dim),
-                nn.GELU(),
-            ]
-            if dropout > 0:
-                head_layers.append(nn.Dropout(dropout))
-            head_layers.append(nn.Linear(classifier_dim, num_classes))
-            self.classifier = nn.Sequential(*head_layers)
-
-        proj_dim = int(projection_dim)
-        if proj_dim > 0:
-            self.projection_head = nn.Sequential(
-                nn.Linear(128, proj_dim),
-                nn.ReLU(inplace=True),
-                nn.Linear(proj_dim, proj_dim),
-            )
-        else:
-            self.projection_head = None
-
-        for param in self.backbone.parameters():
-            param.requires_grad = False
-
-        if trainable_layers > 0:
-            layers = getattr(self.backbone, "layers", None)
-            if layers is not None:
-                trainable_layers = min(trainable_layers, len(layers))
-                for layer in layers[-trainable_layers:]:
-                    for param in layer.parameters():
-                        param.requires_grad = True
-
-    def spectrogram_to_tokens(self, x: torch.Tensor) -> torch.Tensor:
-        x = x.unsqueeze(1)
-        patches = self.unfold(x).transpose(1, 2)
-        cls_token = torch.full(
-            (patches.size(0), 1, patches.size(-1)),
-            0.2,
-            dtype=patches.dtype,
-            device=patches.device,
-        )
-        return torch.cat([cls_token, patches], dim=1)
-
-    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-        tokens = self.spectrogram_to_tokens(x)
-        outputs = self.backbone(tokens)
-        if outputs.size(1) <= 1:
-            return outputs[:, 0, :]
-        return outputs[:, 1:, :].mean(dim=1)
-
-    def _collect_input_stats(self, x: torch.Tensor) -> torch.Tensor:
-        mean = x.mean(dim=(1, 2))
-        std = x.std(dim=(1, 2), unbiased=False)
-        if self.normalization_mode == "dataset":
-            mean = mean * self.dataset_std + self.dataset_mean
-            std = std * self.dataset_std
-        return torch.stack([mean, std], dim=1)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        *,
-        input_stats: torch.Tensor | None = None,
-        return_projection: bool = False,
-    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
-        features = self.forward_features(x)
-        classifier_input = features
-        if self.append_input_stats:
-            stats = input_stats if input_stats is not None else self._collect_input_stats(x)
-            if stats.dtype != classifier_input.dtype:
-                stats = stats.to(classifier_input.dtype)
-            stats = stats.to(classifier_input.device)
-            classifier_input = torch.cat([classifier_input, stats], dim=1)
-        logits = self.classifier(classifier_input)
-        if return_projection:
-            projection = self.projection_head(features) if self.projection_head is not None else None
-            return logits, projection
-        return logits
-
-
-def prepare_model(
-    checkpoint: Path,
-    num_classes: int,
-    classifier_dim: int,
-    dropout: float,
-    trainable_layers: int,
-    projection_dim: int,
-    *,
-    append_input_stats: bool = False,
-    normalization_stats: Dict[str, object] | None = None,
-    head_type: str = "mlp",
-) -> nn.Module:
-    """Instantiate an LWM backbone with the minimal classifier head."""
-    backbone = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
-    state = torch.load(checkpoint, map_location="cpu")
-    if any(k.startswith("module.") for k in state):
-        state = {k.replace("module.", ""): v for k, v in state.items()}
-    backbone.load_state_dict(state, strict=False)
-    return LWMClassifierMinimal(
-        backbone,
-        num_classes=num_classes,
-        classifier_dim=classifier_dim,
-        dropout=dropout,
-        trainable_layers=trainable_layers,
-        projection_dim=projection_dim,
-        append_input_stats=append_input_stats,
-        normalization_stats=normalization_stats,
-        head_type=head_type,
-    )
-
-
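Note: hypothetical wiring of prepare_model for binary mobility classification. The imports assume a checkout from before this commit, and the checkpoint path is a placeholder.

from pathlib import Path

from task2.mobility_utils import BINARY_MOBILITY_LABELS, load_dataset_stats, prepare_model

stats = load_dataset_stats(Path("models"))         # reads models/dataset_stats.json (or falls back)
model = prepare_model(
    checkpoint=Path("models/lwm_pretrained.pth"),  # placeholder checkpoint path
    num_classes=len(BINARY_MOBILITY_LABELS),
    classifier_dim=128,
    dropout=0.3,
    trainable_layers=0,                            # keep the LWM backbone fully frozen
    projection_dim=0,                              # no contrastive projection head
    append_input_stats=True,                       # concatenate per-sample mean/std to the embedding
    normalization_stats=stats,
    head_type="mlp",
)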
-def supervised_contrastive_loss(
-    features: torch.Tensor,
-    labels: torch.Tensor,
-    temperature: float,
-) -> torch.Tensor:
-    """Supervised contrastive loss over a batch of feature embeddings."""
-    batch_size = features.size(0)
-    if batch_size < 2:
-        return features.new_tensor(0.0)
-
-    features = F.normalize(features, dim=1)
-    similarity = torch.div(torch.matmul(features, features.T), max(temperature, 1e-6))
-    logits_max, _ = similarity.max(dim=1, keepdim=True)
-    similarity = similarity - logits_max.detach()
-
-    device = features.device
-    labels = labels.contiguous().view(-1, 1)
-    mask = torch.eq(labels, labels.T).float().to(device)
-    logits_mask = torch.ones_like(mask) - torch.eye(batch_size, device=device)
-    mask = mask * logits_mask
-
-    exp_logits = torch.exp(similarity) * logits_mask
-    log_prob = similarity - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-12)
-
-    mask_sum = mask.sum(dim=1)
-    valid = mask_sum > 0
-    if not torch.any(valid):
-        return features.new_tensor(0.0)
-
-    mean_log_prob_pos = (mask * log_prob).sum(dim=1) / mask_sum.clamp_min(1e-12)
-    loss = -mean_log_prob_pos[valid].mean()
-    return loss
-
-
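Note: a quick sanity check for the loss above: near-duplicate positive pairs should yield a clearly smaller loss than randomly scrambled labels. The import assumes a checkout from before this commit.

import torch

from task2.mobility_utils import supervised_contrastive_loss

torch.manual_seed(0)
anchor = torch.randn(4, 16)
features = torch.cat([anchor, anchor + 0.01 * torch.randn(4, 16)], dim=0)  # 4 tight positive pairs
labels = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])

tight = supervised_contrastive_loss(features, labels, temperature=0.1)
loose = supervised_contrastive_loss(features, labels[torch.randperm(8)], temperature=0.1)
print(float(tight), float(loose))  # tight is expected to be much smaller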
-__all__ = [
-    "BINARY_MOBILITY_LABELS",
-    "LWMClassifierMinimal",
-    "MOBILITY_LABELS",
-    "Res1DCNNHead",
-    "_collect_balanced_arrays",
-    "gather_controlled_groups",
-    "load_dataset_stats",
-    "prepare_model",
-    "supervised_contrastive_loss",
-]