“Namhyun-Kim” committed on
Commit
2b1a1e3
·
1 Parent(s): aebafe2

Sync app to fetch data from wi-lab/lwm-spectro

Browse files
README.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: LWM-Spectro Lab
3
+ emoji: 🔍
4
+ colorFrom: purple
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: "6.0.1"
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # LWM-Spectro Lab
13
+
14
+ One-stop lab for exploring spectrograms, LWM embeddings, and lightweight evaluation baselines.
15
+
16
+ ## Features
17
+ - Visualize LWM embeddings or raw spectrograms with customizable filters.
18
+ - Inspect joint SNR/Doppler performance using cached MoE embeddings and an adaptive k-NN classifier.
19
+ - Upload your own datasets to compare raw channels vs. model embeddings.
20
+
21
+ ## Usage
22
+ 1. Select the **Spectrograms** and **t-SNE Analysis** tabs to explore embeddings.
23
+ 2. Switch to **Modulation Classification** or **Joint SNR/Doppler Evaluation** to run the k-NN prototype with adjustable train/test splits.
24
+ 3. Provide custom data (optional) to benchmark against bundled samples.
app.py CHANGED
@@ -1,52 +1,55 @@
1
-
2
- import json
3
- import sys
4
  from pathlib import Path
5
- from typing import Dict, List, Optional, Sequence, Tuple
6
 
7
  import gradio as gr
 
 
8
  import numpy as np
9
  import pandas as pd
10
  import plotly.express as px
11
  import plotly.graph_objects as go
12
  import torch
 
13
  from sklearn.decomposition import PCA
14
  from sklearn.manifold import TSNE
15
- from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
 
 
16
 
17
- REPO_ROOT = Path(__file__).resolve().parents[1]
18
  APP_DIR = Path(__file__).resolve().parent
19
  DEMO_DATA_PATH = APP_DIR / "demo_data.pt"
20
  MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt"
21
- MOE_CHECKPOINT = REPO_ROOT / "mixture" / "runs" / "embedding_router" / "moe_checkpoint.pth"
22
- SNR_MOB_MAPPING_PATH = REPO_ROOT / "mixture" / "runs" / "embedding_router" / "snr_mobility_mapping.json"
23
-
24
- if str(REPO_ROOT) not in sys.path:
25
- sys.path.append(str(REPO_ROOT))
26
-
27
- from mixture.train_embedding_router import ( # type: ignore
28
- MoEPredictor,
29
- compute_selected_expert_embeddings,
30
- normalize_per_sample_tensor,
31
- stack_expert_embeddings,
32
- )
33
-
34
-
35
- def load_joint_mapping() -> Optional[Dict[str, object]]:
36
- if not SNR_MOB_MAPPING_PATH.exists():
37
- print(f"[WARN] Mapping file not found at {SNR_MOB_MAPPING_PATH}")
38
- return None
39
- raw = json.loads(SNR_MOB_MAPPING_PATH.read_text())
40
- ordered_pairs: List[Tuple[str, str]] = []
41
- for key in sorted(raw.keys(), key=lambda k: int(k)):
42
- snr, mob = raw[key]
43
- ordered_pairs.append((snr, mob))
44
- label_names = [f"{snr} | {mob}" for snr, mob in ordered_pairs]
45
- pair_to_name = {pair: name for pair, name in zip(ordered_pairs, label_names)}
46
  name_to_id = {name: idx for idx, name in enumerate(label_names)}
47
- pair_to_id = {pair: idx for idx, pair in enumerate(ordered_pairs)}
48
  return {
49
- "pairs": ordered_pairs,
50
  "label_names": label_names,
51
  "pair_to_name": pair_to_name,
52
  "name_to_id": name_to_id,
@@ -54,77 +57,43 @@ def load_joint_mapping() -> Optional[Dict[str, object]]:
54
  }
55
 
56
 
57
- def compute_moe_embeddings(
58
- samples: Sequence[Dict[str, object]],
59
- predictor: MoEPredictor,
60
- batch_size: int = 64,
61
- ) -> torch.Tensor:
62
- router = predictor.router
63
- experts = predictor.experts
64
- device = predictor.device
65
- embeddings: List[torch.Tensor] = []
66
-
67
- with torch.no_grad():
68
- for start in range(0, len(samples), batch_size):
69
- batch = samples[start : start + batch_size]
70
- specs = torch.cat([sample["data"] for sample in batch], dim=0).to(device)
71
- specs_norm = normalize_per_sample_tensor(specs)
72
-
73
- if router is not None:
74
- router_logits = router(specs_norm)
75
- probs = torch.softmax(router_logits, dim=1)
76
- topk_vals, topk_idx = probs.topk(k=predictor.topk, dim=1)
77
- weights = topk_vals / torch.clamp(topk_vals.sum(dim=1, keepdim=True), min=1e-6)
78
- selected_embeddings = compute_selected_expert_embeddings(
79
- experts,
80
- specs_norm,
81
- topk_idx,
82
- allow_grad=False,
83
- )
84
- weighted = (weights.unsqueeze(-1) * selected_embeddings).sum(dim=1)
85
- else:
86
- stacked = stack_expert_embeddings(experts, specs_norm)
87
- weighted = stacked.mean(dim=1)
88
 
89
- embeddings.append(weighted.cpu())
90
 
91
- return torch.cat(embeddings, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
 
94
- def ensure_moe_embeddings(samples: List[Dict[str, object]]) -> Tuple[List[Dict[str, object]], bool]:
 
 
95
  if MOE_DATA_PATH.exists():
96
- cached = torch.load(MOE_DATA_PATH)
97
- if len(cached) == len(samples):
98
- print(f"[INFO] Loaded cached MoE embeddings from {MOE_DATA_PATH}")
99
- return cached, True
100
- print("[WARN] Cached MoE embeddings length mismatch. Recomputing...")
101
-
102
- if not MOE_CHECKPOINT.exists():
103
- print(f"[WARN] MoE checkpoint not found at {MOE_CHECKPOINT}. Skipping MoE embeddings.")
104
- return samples, False
105
-
106
- print("[INFO] Computing MoE embeddings using router checkpoint...")
107
- predictor = MoEPredictor.from_checkpoint(MOE_CHECKPOINT)
108
- moe_embeddings = compute_moe_embeddings(samples, predictor)
109
- for sample, emb in zip(samples, moe_embeddings):
110
- sample["moe_embedding"] = emb.detach().cpu()
111
-
112
- torch.save(samples, MOE_DATA_PATH)
113
- print(f"[INFO] Saved MoE-augmented dataset to {MOE_DATA_PATH}")
114
- return samples, True
115
-
116
-
117
- def load_data(mapping: Optional[Dict[str, object]]):
118
  if not DEMO_DATA_PATH.exists():
119
  raise FileNotFoundError(f"Dataset not found at {DEMO_DATA_PATH}")
 
 
120
 
121
- print(f"[INFO] Loading base dataset from {DEMO_DATA_PATH}")
122
- data: List[Dict[str, object]] = torch.load(DEMO_DATA_PATH)
123
- data, has_moe = ensure_moe_embeddings(data)
124
-
125
- pair_to_name = mapping["pair_to_name"] if mapping else {}
126
- pair_to_id = mapping["pair_to_id"] if mapping else {}
127
 
 
 
 
 
 
128
  records = []
129
  for i, sample in enumerate(data):
130
  embedding = sample["embedding"]
@@ -149,9 +118,14 @@ def load_data(mapping: Optional[Dict[str, object]]):
149
  joint_label = pair_to_name.get(pair)
150
  joint_label_id = pair_to_id.get(pair)
151
 
 
 
 
 
 
152
  records.append(
153
  {
154
- "index": i,
155
  "tech": sample["tech"],
156
  "snr": sample["snr"],
157
  "mod": sample["mod"],
@@ -161,11 +135,15 @@ def load_data(mapping: Optional[Dict[str, object]]):
161
  "spectrogram": flat_spec,
162
  "joint_label": joint_label,
163
  "joint_label_id": joint_label_id,
 
 
 
 
164
  }
165
  )
166
 
167
  df = pd.DataFrame(records)
168
- print(f"[INFO] Loaded {len(df)} samples.")
169
  return df, has_moe
170
 
171
 
@@ -188,50 +166,94 @@ def apply_filters(
188
  return filtered
189
 
190
 
191
- def plot_tsne(tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter):
192
  filtered_df = apply_filters(df, tech_filter, snr_filter, mod_filter, mob_filter)
193
  if len(filtered_df) < 5:
194
- return None, f"Not enough data points ({len(filtered_df)}). Need at least 5."
 
 
195
 
196
- if representation == "LWM Embedding":
197
- features = np.stack(filtered_df["embedding"].values)
 
 
198
  else:
199
- features = np.stack(filtered_df["spectrogram"].values)
200
- if features.shape[1] > 50:
201
- pca = PCA(n_components=50, random_state=42)
202
- features = pca.fit_transform(features)
203
-
204
- eff_perplexity = min(perplexity, len(filtered_df) - 1)
205
- tsne = TSNE(
206
- n_components=2,
207
- perplexity=eff_perplexity,
208
- n_iter=n_iter,
209
- random_state=42,
210
- init="pca",
211
- learning_rate="auto",
212
- )
213
- projections = tsne.fit_transform(features)
214
- filtered_df = filtered_df.copy()
215
- filtered_df["x"] = projections[:, 0]
216
- filtered_df["y"] = projections[:, 1]
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  fig = px.scatter(
219
  filtered_df,
220
  x="x",
221
  y="y",
222
- color=color_by,
223
  hover_data=["tech", "snr", "mod", "mob"],
224
  title=f"t-SNE of {representation} ({len(filtered_df)} samples)",
225
  template="plotly_white",
226
  )
227
- fig.update_layout(legend_title_text=color_by.capitalize())
228
- return fig, f"Displayed {len(filtered_df)} samples."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
 
231
  def stratified_split(filtered_df: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
232
  rng = np.random.default_rng(int(seed))
233
- train_indices: List[int] = []
234
- test_indices: List[int] = []
235
 
236
  for label_id, group in filtered_df.groupby("joint_label_id"):
237
  indices = group.index.to_numpy()
@@ -247,47 +269,22 @@ def stratified_split(filtered_df: pd.DataFrame, train_ratio: float, seed: int) -
247
  return np.array(train_indices), np.array(test_indices)
248
 
249
 
250
- def compute_centroid_metrics(filtered_df: pd.DataFrame, train_idx: np.ndarray, test_idx: np.ndarray) -> Dict[str, object]:
251
- train_subset = filtered_df.loc[train_idx]
252
- test_subset = filtered_df.loc[test_idx]
253
-
254
- train_embeddings = np.stack(train_subset["moe_embedding"].values)
255
- test_embeddings = np.stack(test_subset["moe_embedding"].values)
256
- train_labels = train_subset["joint_label_id"].to_numpy(dtype=int)
257
- test_labels = test_subset["joint_label_id"].to_numpy(dtype=int)
258
-
259
- unique_labels = np.unique(train_labels)
260
- centroids = []
261
- centroid_ids: List[int] = []
262
- for label_id in unique_labels:
263
- mask = train_labels == label_id
264
- centroids.append(train_embeddings[mask].mean(axis=0))
265
- centroid_ids.append(int(label_id))
266
-
267
- centroids = np.stack(centroids)
268
- centroid_ids = np.array(centroid_ids, dtype=int)
269
-
270
- dists = ((test_embeddings[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
271
- preds = centroid_ids[np.argmin(dists, axis=1)]
272
-
273
- accuracy = accuracy_score(test_labels, preds)
274
- macro_f1 = f1_score(test_labels, preds, average="macro", labels=centroid_ids, zero_division=0)
275
-
276
- active_ids = sorted(np.unique(np.concatenate([test_labels, preds])))
277
- label_names = [CLASS_LABELS[i] for i in active_ids]
278
- cm = confusion_matrix(test_labels, preds, labels=active_ids)
279
 
280
- return {
281
- "accuracy": accuracy,
282
- "macro_f1": macro_f1,
283
- "confusion": cm,
284
- "label_names": label_names,
285
- "train_size": len(train_idx),
286
- "test_size": len(test_idx),
287
- }
288
 
289
-
290
- def plot_confusion_heatmap(confusion: np.ndarray, label_names: List[str]) -> go.Figure:
 
291
  fig = go.Figure(
292
  data=go.Heatmap(
293
  z=confusion,
@@ -298,7 +295,7 @@ def plot_confusion_heatmap(confusion: np.ndarray, label_names: List[str]) -> go.
298
  )
299
  )
300
  fig.update_layout(
301
- title="Prototype Classifier Confusion Matrix",
302
  xaxis_title="Predicted",
303
  yaxis_title="True",
304
  xaxis=dict(tickangle=45),
@@ -307,70 +304,312 @@ def plot_confusion_heatmap(confusion: np.ndarray, label_names: List[str]) -> go.
307
 
308
 
309
  def run_joint_evaluation(train_pct, seed, tech_filter, snr_filter, mod_filter, mob_filter):
310
- if joint_eval_df.empty:
311
  fig = go.Figure()
312
  fig.update_layout(title="MoE embeddings unavailable", xaxis=dict(visible=False), yaxis=dict(visible=False))
313
- return fig, "MoE embeddings are not available for evaluation."
314
 
315
  filtered = apply_filters(joint_eval_df, tech_filter, snr_filter, mod_filter, mob_filter)
316
  if filtered.empty:
317
  fig = go.Figure()
318
  fig.update_layout(title="No samples after filtering", xaxis=dict(visible=False), yaxis=dict(visible=False))
319
- return fig, "No samples match the selected filters."
320
 
321
  if filtered["joint_label_id"].nunique() < 2:
322
  fig = go.Figure()
323
  fig.update_layout(title="Need at least two classes", xaxis=dict(visible=False), yaxis=dict(visible=False))
324
- return fig, "Need at least two joint SNR/Doppler classes to evaluate."
 
 
325
 
326
  try:
327
  train_idx, test_idx = stratified_split(filtered, train_pct / 100.0, seed)
328
  except ValueError as exc:
329
  fig = go.Figure()
330
  fig.update_layout(title="Unable to split dataset", xaxis=dict(visible=False), yaxis=dict(visible=False))
331
- return fig, str(exc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
- metrics = compute_centroid_metrics(filtered, train_idx, test_idx)
334
- fig = plot_confusion_heatmap(metrics["confusion"], metrics["label_names"])
335
  status = (
336
- f"Train samples: {metrics['train_size']}\n"
337
- f"Test samples: {metrics['test_size']}\n"
338
- f"Accuracy: {metrics['accuracy'] * 100:.2f}%\n"
339
- f"Macro F1: {metrics['macro_f1']:.3f}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  )
341
- return fig, status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
 
344
  mapping_info = load_joint_mapping()
345
  df, has_moe_embeddings = load_data(mapping_info)
346
- CLASS_LABELS: List[str] = mapping_info["label_names"] if mapping_info else []
347
 
348
- joint_eval_df = df.copy()
349
- joint_eval_df = joint_eval_df[joint_eval_df["joint_label_id"].notna()]
350
- joint_eval_df = joint_eval_df[joint_eval_df["moe_embedding"].notna()]
351
 
352
  tech_choices = sorted(df["tech"].unique())
353
  snr_choices = sorted(df["snr"].unique())
354
  mod_choices = sorted(df["mod"].unique())
355
  mob_choices = sorted(df["mob"].unique())
356
 
357
- evaluation_disabled = joint_eval_df.empty
 
 
 
 
 
 
 
 
 
 
 
358
 
359
- with gr.Blocks(title="LWM-Spectro Demo") as demo:
 
 
 
 
 
 
 
 
 
360
  gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
361
  gr.Markdown(
362
  """
363
- Compare **LWM embeddings** vs **Raw Spectrograms** for visualization, then evaluate **MoE embeddings**
364
- with a lightweight prototype classifier for joint SNR/Doppler recognition.
365
  """
366
  )
367
 
368
  with gr.Tabs():
369
- with gr.Tab("Visualization"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  with gr.Row():
371
  with gr.Column(scale=1, min_width=300):
372
  gr.Markdown("### Filters")
373
- tech_filter = gr.CheckboxGroup(choices=tech_choices, value=tech_choices, label="Technology")
374
  snr_filter = gr.Dropdown(
375
  choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)"
376
  )
@@ -387,14 +626,17 @@ with gr.Blocks(title="LWM-Spectro Demo") as demo:
387
  value="LWM Embedding",
388
  label="Representation",
389
  )
390
- color_by = gr.Dropdown(choices=["tech", "snr", "mod", "mob"], value="tech", label="Color By")
 
 
 
 
391
 
392
  with gr.Accordion("Advanced t-SNE Settings", open=False):
393
  perplexity = gr.Slider(minimum=5, maximum=50, value=30, step=1, label="Perplexity")
394
  n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations")
395
 
396
  btn = gr.Button("Update Plot", variant="primary")
397
- status = gr.Textbox(label="Status", interactive=False)
398
 
399
  with gr.Column(scale=3):
400
  plot = gr.Plot(label="t-SNE Visualization")
@@ -402,19 +644,43 @@ with gr.Blocks(title="LWM-Spectro Demo") as demo:
402
  btn.click(
403
  plot_tsne,
404
  inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
405
- outputs=[plot, status],
406
  )
407
 
408
  demo.load(
409
  plot_tsne,
410
  inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
411
- outputs=[plot, status],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  )
413
 
414
- with gr.Tab("Evaludation (Joint SNR/Doppler)"):
415
  if evaluation_disabled:
416
  gr.Markdown(
417
- "⚠️ MoE embeddings are unavailable. Ensure `demo_data_moe.pt` exists or the checkpoint is present."
418
  )
419
 
420
  with gr.Row():
@@ -422,7 +688,7 @@ with gr.Blocks(title="LWM-Spectro Demo") as demo:
422
  gr.Markdown("### Evaluation Filters")
423
  eval_tech_filter = gr.CheckboxGroup(
424
  choices=tech_choices,
425
- value=tech_choices,
426
  label="Technology",
427
  interactive=not evaluation_disabled,
428
  )
@@ -468,13 +734,15 @@ with gr.Blocks(title="LWM-Spectro Demo") as demo:
468
  eval_btn = gr.Button("Run evaluation", variant="primary", interactive=not evaluation_disabled)
469
 
470
  with gr.Column(scale=3):
471
- eval_plot = gr.Plot(label="Prototype Confusion Matrix")
472
- eval_status = gr.Textbox(label="Metrics", interactive=False)
 
 
473
 
474
  eval_btn.click(
475
  run_joint_evaluation,
476
  inputs=[train_pct, seed, eval_tech_filter, eval_snr_filter, eval_mod_filter, eval_mob_filter],
477
- outputs=[eval_plot, eval_status],
478
  )
479
 
480
  if __name__ == "__main__":
 
1
+ import os
2
+ import shutil
 
3
  from pathlib import Path
4
+ from typing import Dict, List, Tuple, Optional
5
 
6
  import gradio as gr
7
+ import matplotlib.pyplot as plt
8
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
9
  import numpy as np
10
  import pandas as pd
11
  import plotly.express as px
12
  import plotly.graph_objects as go
13
  import torch
14
+ from huggingface_hub import hf_hub_download
15
  from sklearn.decomposition import PCA
16
  from sklearn.manifold import TSNE
17
+ from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
18
+ from sklearn.neighbors import KNeighborsClassifier
19
+ from sklearn.preprocessing import StandardScaler
20
 
 
21
  APP_DIR = Path(__file__).resolve().parent
22
  DEMO_DATA_PATH = APP_DIR / "demo_data.pt"
23
  MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt"
24
+ HUB_REPO_ID = "wi-lab/lwm-spectro"
25
+ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_HUB_TOKEN")
26
+
27
# Fixed ordering for the 14 joint SNR/Doppler labels: seven SNR levels,
# each paired with pedestrian then vehicular mobility.
_SNR_LEVELS = ("SNR-5dB", "SNR0dB", "SNR5dB", "SNR10dB", "SNR15dB", "SNR20dB", "SNR25dB")
_MOBILITY_PROFILES = ("pedestrian", "vehicular")
JOINT_LABELS = [(snr, mobility) for snr in _SNR_LEVELS for mobility in _MOBILITY_PROFILES]
44
+
45
+
46
+ def load_joint_mapping() -> Dict[str, object]:
47
+ label_names = [f"{snr} | {mob}" for snr, mob in JOINT_LABELS]
48
+ pair_to_name = {pair: name for pair, name in zip(JOINT_LABELS, label_names)}
49
  name_to_id = {name: idx for idx, name in enumerate(label_names)}
50
+ pair_to_id = {pair: idx for idx, pair in enumerate(JOINT_LABELS)}
51
  return {
52
+ "pairs": JOINT_LABELS,
53
  "label_names": label_names,
54
  "pair_to_name": pair_to_name,
55
  "name_to_id": name_to_id,
 
57
  }
58
 
59
 
60
def _safe_load_tensor(path: Path):
    """Deserialize a saved dataset with full (non-weights-only) unpickling.

    Torch 2.6 flipped the ``torch.load`` default to ``weights_only=True``,
    which rejects the plain-Python dicts stored in these demo files, so we
    opt out explicitly.
    """
    loaded = torch.load(path, weights_only=False)
    return loaded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
 
64
 
65
+ def _ensure_local_file(local_path: Path, hub_filename: str) -> Optional[Path]:
66
+ """Ensure a file exists locally; try Hub download if missing."""
67
+ if local_path.exists():
68
+ return local_path
69
+ try:
70
+ cached = hf_hub_download(repo_id=HUB_REPO_ID, filename=hub_filename, token=HF_TOKEN)
71
+ cached_path = Path(cached)
72
+ shutil.copyfile(cached_path, local_path)
73
+ print(f"[INFO] Downloaded {hub_filename} from Hub to {local_path}")
74
+ return local_path
75
+ except Exception as exc:
76
+ print(f"[WARN] Could not download {hub_filename} from Hub ({exc}); continuing without it.")
77
+ return None
78
 
79
 
80
def load_augmented_samples() -> Tuple[List[Dict[str, object]], bool]:
    """Load demo samples, preferring the MoE-augmented dataset.

    Returns the sample list plus a flag indicating whether MoE embeddings
    are present. Missing files are first fetched from the Hub when
    possible; if no dataset exists at all, ``FileNotFoundError`` is raised.
    """
    for target, hub_name in ((MOE_DATA_PATH, "demo_data_moe.pt"), (DEMO_DATA_PATH, "demo_data.pt")):
        _ensure_local_file(target, hub_name)

    if MOE_DATA_PATH.exists():
        print(f"[INFO] Loading MoE-augmented dataset from {MOE_DATA_PATH}")
        return _safe_load_tensor(MOE_DATA_PATH), True

    if DEMO_DATA_PATH.exists():
        print(f"[WARN] {MOE_DATA_PATH} missing; falling back to base data only")
        return _safe_load_tensor(DEMO_DATA_PATH), False

    raise FileNotFoundError(f"Dataset not found at {DEMO_DATA_PATH}")
90
 
 
 
 
 
 
 
91
 
92
+ def load_data(mapping: Dict[str, object]):
93
+ data, has_moe = load_augmented_samples()
94
+ pair_to_name = mapping["pair_to_name"]
95
+ pair_to_id = mapping["pair_to_id"]
96
+
97
  records = []
98
  for i, sample in enumerate(data):
99
  embedding = sample["embedding"]
 
118
  joint_label = pair_to_name.get(pair)
119
  joint_label_id = pair_to_id.get(pair)
120
 
121
+ tsne_x = sample.get("tsne_x")
122
+ tsne_y = sample.get("tsne_y")
123
+ tsne_raw_x = sample.get("tsne_raw_x")
124
+ tsne_raw_y = sample.get("tsne_raw_y")
125
+
126
  records.append(
127
  {
128
+ "index": i,
129
  "tech": sample["tech"],
130
  "snr": sample["snr"],
131
  "mod": sample["mod"],
 
135
  "spectrogram": flat_spec,
136
  "joint_label": joint_label,
137
  "joint_label_id": joint_label_id,
138
+ "tsne_x": tsne_x,
139
+ "tsne_y": tsne_y,
140
+ "tsne_raw_x": tsne_raw_x,
141
+ "tsne_raw_y": tsne_raw_y,
142
  }
143
  )
144
 
145
  df = pd.DataFrame(records)
146
+ print(f"[INFO] Loaded {len(df)} samples (MoE embeddings: {has_moe})")
147
  return df, has_moe
148
 
149
 
 
166
  return filtered
167
 
168
 
169
def plot_tsne(tech_filter, snr_filter, mod_filter, mob_filter, representation, color_label, perplexity, n_iter):
    """Build the 2-D t-SNE scatter plot for the currently filtered samples.

    Prefers pre-computed coordinates cached in the dataframe
    (``tsne_x``/``tsne_y`` for embeddings, ``tsne_raw_x``/``tsne_raw_y`` for
    raw spectrograms); otherwise recomputes the projection on up to 1200
    sampled rows. Returns a Plotly figure, or ``None`` when fewer than 5
    samples survive the filters.
    """
    filtered_df = apply_filters(df, tech_filter, snr_filter, mod_filter, mob_filter)
    if len(filtered_df) < 5:
        # Too few points for a meaningful projection.
        return None

    # COLOR_OPTIONS maps UI labels to dataframe column names; "snr" is the fallback.
    color_column = COLOR_OPTIONS.get(color_label, "snr")

    # Pick the cached-coordinate columns matching the chosen representation.
    tsne_cols = ("tsne_x", "tsne_y") if representation == "LWM Embedding" else ("tsne_raw_x", "tsne_raw_y")
    has_cached = all(col in filtered_df.columns for col in tsne_cols)
    if has_cached:
        # Cache is usable only if every filtered row has both coordinates.
        valid = filtered_df[tsne_cols[0]].notna().all() and filtered_df[tsne_cols[1]].notna().all()
    else:
        valid = False

    if valid:
        filtered_df = filtered_df.copy()
        filtered_df["x"] = filtered_df[tsne_cols[0]]
        filtered_df["y"] = filtered_df[tsne_cols[1]]
    else:
        # No usable cache: recompute t-SNE on a capped sample to bound runtime.
        sampled_df = filtered_df
        if len(sampled_df) > 1200:
            sampled_df = sampled_df.sample(n=1200, random_state=42)
        sampled_df = sampled_df.copy()

        if representation == "LWM Embedding":
            features = np.stack(sampled_df["embedding"].values)
        else:
            features = np.stack(sampled_df["spectrogram"].values)
            # Raw spectrograms are high-dimensional; reduce to 50 PCA
            # components before t-SNE for speed and stability.
            if features.shape[1] > 50:
                pca = PCA(n_components=50, random_state=42)
                features = pca.fit_transform(features)

        # Clamp perplexity to [5, n-1]. NOTE(review): when n <= 5 the lower
        # bound can push perplexity to >= n, which makes TSNE raise — the
        # except below then silently falls back to PCA; confirm intended.
        eff_perplexity = min(perplexity, len(sampled_df) - 1)
        eff_perplexity = max(5, eff_perplexity)
        tsne = TSNE(
            n_components=2,
            perplexity=eff_perplexity,
            n_iter=n_iter,
            random_state=42,
            init="pca",
            learning_rate="auto",
        )
        try:
            projections = tsne.fit_transform(features)
        except Exception as exc:
            # NOTE(review): `exc` is discarded — failures silently degrade
            # to a plain 2-component PCA projection; consider logging.
            pca = PCA(n_components=2, random_state=42)
            projections = pca.fit_transform(features)
        sampled_df["x"] = projections[:, 0]
        sampled_df["y"] = projections[:, 1]
        filtered_df = sampled_df
    fig = px.scatter(
        filtered_df,
        x="x",
        y="y",
        color=color_column,
        hover_data=["tech", "snr", "mod", "mob"],
        title=f"t-SNE of {representation} ({len(filtered_df)} samples)",
        template="plotly_white",
    )
    # Extra height when coloring by SNR (longer legend).
    height = 680 if color_label == "SNR" else 640
    fig.update_layout(
        legend_title_text=color_label,
        width=640,
        height=height,
    )
    # Keep the axes on the same scale so distances are not distorted.
    fig.update_yaxes(scaleanchor="x", scaleratio=1)
    return fig
236
+
237
+
238
def build_raw_feature_matrix(samples: pd.Series, max_components: int = 256) -> np.ndarray:
    """Flatten, standardize, and (optionally) PCA-reduce raw spectrograms.

    Args:
        samples: Series whose entries are array-like spectrograms of any shape.
        max_components: Upper bound on output dimensionality; pass a falsy
            value to skip the PCA step entirely.

    Returns:
        A 2-D float matrix with one standardized feature row per sample.
    """
    # Flatten every spectrogram into a 1-D float32 row; NaNs become zeros so
    # the scaler and PCA below cannot propagate them.
    matrix = np.stack([np.asarray(spec, dtype=np.float32).reshape(-1) for spec in samples])
    matrix = np.nan_to_num(matrix, copy=False)
    matrix = StandardScaler().fit_transform(matrix)
    if max_components and matrix.shape[1] > max_components:
        matrix = PCA(n_components=max_components, random_state=42).fit_transform(matrix)
    return matrix
251
 
252
 
253
  def stratified_split(filtered_df: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
254
  rng = np.random.default_rng(int(seed))
255
+ train_indices = []
256
+ test_indices = []
257
 
258
  for label_id, group in filtered_df.groupby("joint_label_id"):
259
  indices = group.index.to_numpy()
 
269
  return np.array(train_indices), np.array(test_indices)
270
 
271
 
272
def select_knn_k(train_labels: np.ndarray, max_k: int = 9) -> int:
    """Pick an odd neighbor count for k-NN from the training labels.

    Starts from the sqrt-of-n heuristic, caps it at ``max_k`` and at the
    size of the rarest class, then rounds even results down to the nearest
    odd number. Empty input yields 1.
    """
    n_train = train_labels.size
    if n_train == 0:
        return 1
    rarest = int(pd.Series(train_labels).value_counts().min())
    sqrt_guess = int(np.sqrt(n_train))
    bounded = min(max_k, sqrt_guess)
    if bounded < 1:
        bounded = 1
    k = min(bounded, rarest)
    if k < 1:
        k = 1
    # Prefer odd k so nearest-neighbor votes are less likely to tie.
    return k - 1 if (k > 1 and k % 2 == 0) else k
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
 
 
 
 
 
 
 
 
284
 
285
+ def plot_confusion_heatmap(
286
+ confusion: np.ndarray, label_names: List[str], title: str = "Prototype Classifier Confusion Matrix"
287
+ ) -> go.Figure:
288
  fig = go.Figure(
289
  data=go.Heatmap(
290
  z=confusion,
 
295
  )
296
  )
297
  fig.update_layout(
298
+ title=title,
299
  xaxis_title="Predicted",
300
  yaxis_title="True",
301
  xaxis=dict(tickangle=45),
 
304
 
305
 
306
def run_joint_evaluation(train_pct, seed, tech_filter, snr_filter, mod_filter, mob_filter):
    """Run the joint SNR/Doppler k-NN benchmark on the filtered dataset.

    Compares MoE embeddings against raw-spectrogram features with the same
    stratified split and the same k. Returns a 3-tuple of (MoE confusion
    figure, raw confusion figure, markdown status string); every error
    path returns one placeholder figure twice plus an explanatory message.
    Reads the module-level globals ``evaluation_disabled``,
    ``joint_eval_df``, and ``CLASS_LABELS``.
    """
    if evaluation_disabled:
        fig = go.Figure()
        fig.update_layout(title="MoE embeddings unavailable", xaxis=dict(visible=False), yaxis=dict(visible=False))
        return fig, fig, "MoE embeddings are not available in this Space build."

    filtered = apply_filters(joint_eval_df, tech_filter, snr_filter, mod_filter, mob_filter)
    if filtered.empty:
        fig = go.Figure()
        fig.update_layout(title="No samples after filtering", xaxis=dict(visible=False), yaxis=dict(visible=False))
        return fig, fig, "No samples match the selected filters."

    if filtered["joint_label_id"].nunique() < 2:
        fig = go.Figure()
        fig.update_layout(title="Need at least two classes", xaxis=dict(visible=False), yaxis=dict(visible=False))
        return fig, fig, "Need at least two joint SNR/Doppler classes to evaluate."

    # Reset to positional indices so the split indices below line up with
    # the numpy arrays built from this frame.
    filtered = filtered.reset_index(drop=True)

    try:
        train_idx, test_idx = stratified_split(filtered, train_pct / 100.0, seed)
    except ValueError as exc:
        fig = go.Figure()
        fig.update_layout(title="Unable to split dataset", xaxis=dict(visible=False), yaxis=dict(visible=False))
        return fig, fig, str(exc)

    labels = filtered["joint_label_id"].to_numpy(dtype=int)
    moe_features = np.stack(filtered["moe_embedding"].values)
    raw_features = build_raw_feature_matrix(filtered["spectrogram"], max_components=256)

    # One k chosen from the training labels, applied to both feature sets so
    # the comparison is apples-to-apples.
    train_labels = labels[train_idx]
    knn_k = select_knn_k(train_labels)

    moe_metrics = compute_knn_metrics(moe_features, labels, train_idx, test_idx, knn_k, label_lookup=CLASS_LABELS)
    raw_metrics = compute_knn_metrics(raw_features, labels, train_idx, test_idx, knn_k, label_lookup=CLASS_LABELS)

    moe_fig = plot_confusion_heatmap(
        moe_metrics["confusion"], moe_metrics["label_names"], title=f"MoE Embedding Confusion (k={moe_metrics['k']})"
    )
    raw_fig = plot_confusion_heatmap(
        raw_metrics["confusion"], raw_metrics["label_names"], title=f"Raw Spectrogram Confusion (k={raw_metrics['k']})"
    )

    # Markdown summary rendered in the UI next to the two heatmaps.
    status = (
        f"### Joint SNR/Doppler Metrics\n"
        f"**Train/Test Samples:** {len(train_idx)} / {len(test_idx)} | **Train %:** {train_pct}% | **Seed:** {seed} | **k-NN k:** {knn_k}\n\n"
        "| Representation | Accuracy | Macro F1 |\n"
        "| --- | --- | --- |\n"
        f"| **MoE Embedding** | {moe_metrics['accuracy'] * 100:.2f}% | {moe_metrics['macro_f1']:.3f} |\n"
        f"| **Raw Spectrogram** | {raw_metrics['accuracy'] * 100:.2f}% | {raw_metrics['macro_f1']:.3f} |"
    )
    return moe_fig, raw_fig, status
358
+
359
+
360
def stratified_split_mod(df_subset: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
    """Shuffle-split dataframe indices stratified by the "mod" column.

    Every modulation group keeps at least one sample on each side of the
    split; a group with fewer than two samples cannot satisfy that and
    raises ``ValueError``.
    """
    rng = np.random.default_rng(int(seed))
    picked_train: List[int] = []
    picked_test: List[int] = []
    for _, group in df_subset.groupby("mod"):
        idx = group.index.to_numpy()
        if len(idx) < 2:
            raise ValueError("Each modulation needs at least 2 samples.")
        rng.shuffle(idx)
        # Round to the requested ratio, then clamp so both sides are non-empty.
        cut = min(max(int(round(len(idx) * train_ratio)), 1), len(idx) - 1)
        picked_train.extend(idx[:cut])
        picked_test.extend(idx[cut:])
    return np.array(picked_train), np.array(picked_test)
374
+
375
+
376
def compute_knn_metrics(
    features: np.ndarray,
    labels: np.ndarray,
    train_idx: np.ndarray,
    test_idx: np.ndarray,
    knn_k: int,
    label_lookup: List[str] | None = None,
) -> Dict[str, object]:
    """Fit a k-NN classifier on the train split and score the test split.

    The requested ``knn_k`` is clamped to the training-set size and rounded
    down to an odd value; the effective k is reported under the ``"k"`` key
    alongside accuracy, macro-F1, the confusion matrix, and display names
    for every label that actually occurs.
    """
    tr_x, te_x = features[train_idx], features[test_idx]
    tr_y, te_y = labels[train_idx], labels[test_idx]

    # Clamp k into [1, n_train] and prefer an odd value to reduce vote ties.
    k_eff = min(int(knn_k), len(tr_y))
    if k_eff < 1:
        k_eff = 1
    if k_eff > 1 and k_eff % 2 == 0:
        k_eff -= 1

    model = KNeighborsClassifier(n_neighbors=k_eff, metric="euclidean")
    model.fit(tr_x, tr_y)
    predictions = model.predict(te_x)

    # Restrict metrics to labels that appear somewhere in train/test/preds.
    seen = np.unique(np.concatenate([tr_y, te_y, predictions]))
    names = (
        [str(lbl) for lbl in seen]
        if label_lookup is None
        else [label_lookup[int(lbl)] for lbl in seen]
    )

    return {
        "accuracy": accuracy_score(te_y, predictions),
        "macro_f1": f1_score(te_y, predictions, labels=seen, average="macro", zero_division=0),
        "confusion": confusion_matrix(te_y, predictions, labels=seen),
        "label_names": names,
        "k": k_eff,
    }
413
+
414
+
415
def evaluate_modulation(tech: str, train_pct: int, seed: int):
    """Benchmark modulation classification for one technology with k-NN.

    Compares LWM embeddings against raw-spectrogram features on the same
    stratified split and the same k. Returns (embedding confusion figure,
    raw confusion figure, markdown summary); error paths return one
    placeholder figure twice plus a message. Reads the module-level
    dataframe ``df``.
    """
    if not tech:
        fig = go.Figure()
        fig.update_layout(title="Select a technology to evaluate.", xaxis=dict(visible=False), yaxis=dict(visible=False))
        return fig, fig, "No technology selected."

    # Positional indices so the split below lines up with the numpy arrays.
    subset = df[df["tech"] == tech].copy().reset_index(drop=True)
    if subset.empty or subset["mod"].nunique() < 2:
        fig = go.Figure()
        fig.update_layout(
            title="Need at least two modulation classes for this technology.",
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
        )
        return fig, fig, "Not enough modulation classes."

    try:
        train_idx, test_idx = stratified_split_mod(subset, train_pct / 100.0, seed)
    except ValueError as exc:
        fig = go.Figure()
        fig.update_layout(title=str(exc), xaxis=dict(visible=False), yaxis=dict(visible=False))
        return fig, fig, str(exc)

    labels = subset["mod"].astype(str).to_numpy()
    emb_features = np.stack(subset["embedding"].values)

    raw_features = build_raw_feature_matrix(subset["spectrogram"], max_components=256)

    train_labels = labels[train_idx]
    class_counts = pd.Series(train_labels).value_counts()
    if class_counts.empty:
        fig = go.Figure()
        fig.update_layout(title="No modulation classes found.", xaxis=dict(visible=False), yaxis=dict(visible=False))
        return fig, fig, "No modulation classes found."

    # Same k for both representations so the comparison is fair.
    knn_k = select_knn_k(train_labels)

    emb_metrics = compute_knn_metrics(emb_features, labels, train_idx, test_idx, knn_k)
    raw_metrics = compute_knn_metrics(raw_features, labels, train_idx, test_idx, knn_k)

    emb_fig = plot_confusion_heatmap(emb_metrics["confusion"], emb_metrics["label_names"], title="Embedding Confusion")
    raw_fig = plot_confusion_heatmap(raw_metrics["confusion"], raw_metrics["label_names"], title="Raw Confusion")

    # Markdown summary rendered next to the two heatmaps in the UI.
    summary = (
        f"### {tech} Modulation Metrics\n"
        f"**Train/Test Samples:** {len(train_idx)} / {len(test_idx)} | **Classifier:** k-NN (k = {emb_metrics['k']})\n\n"
        "| Representation | Accuracy | Macro F1 |\n"
        "| --- | --- | --- |\n"
        f"| **LWM Embedding** | {emb_metrics['accuracy'] * 100:.2f}% | {emb_metrics['macro_f1']:.3f} |\n"
        f"| **Raw Spectrogram** | {raw_metrics['accuracy'] * 100:.2f}% | {raw_metrics['macro_f1']:.3f} |"
    )
    return emb_fig, raw_fig, summary
467
+
468
+
469
+ def _reshape_spectrogram(spec: np.ndarray) -> np.ndarray:
470
+ arr = np.asarray(spec)
471
+ if arr.ndim == 1:
472
+ side = int(round(arr.size ** 0.5))
473
+ if side * side == arr.size:
474
+ arr = arr.reshape(side, side)
475
+ else:
476
+ arr = arr.reshape(-1, side)
477
+ elif arr.ndim == 3:
478
+ arr = arr.squeeze()
479
+ return arr
480
+
481
+
482
def _spectrogram_to_image(spec: np.ndarray, title: str) -> np.ndarray:
    """Render a 2-D spectrogram into an RGB image array for the Gradio gallery.

    NaNs are replaced with zeros, the data is min-max scaled to [0, 1]
    (constant arrays are left unscaled to avoid division by zero), then drawn
    with the "turbo" colormap plus a slim colorbar and rasterized off-screen
    through the Agg canvas, so no GUI matplotlib backend is required.

    Args:
        spec: 2-D array of spectrogram values.
        title: Caption drawn above the image.

    Returns:
        An ``(H, W, 3)`` uint8 RGB array.
    """
    normalized = spec.astype(np.float32)
    if np.isnan(normalized).any():
        normalized = np.nan_to_num(normalized)
    vmin, vmax = normalized.min(), normalized.max()
    if vmax - vmin > 0:
        # Min-max normalize; skipped for constant images (zero range).
        normalized = (normalized - vmin) / (vmax - vmin)
    fig, ax = plt.subplots(figsize=(3, 3))
    im = ax.imshow(normalized, cmap="turbo", aspect="auto", origin="lower")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title, fontsize=8)
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=6)
    fig.tight_layout(pad=0.5)
    # Rasterize via Agg and copy the RGBA buffer before closing the figure —
    # the buffer is only valid while the figure is alive.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    width, height = canvas.get_width_height()
    buf = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(height, width, 4)
    image = buf[..., :3].copy()  # drop the alpha channel
    plt.close(fig)  # release figure memory; important for repeated gallery calls
    return image
504
+
505
+
506
def render_spectrogram_gallery(tech, snr, mod, mob, sample_count, seed):
    """Sample matching spectrograms and render them for the gallery widget.

    Each non-empty filter value is wrapped in a single-element list before
    being handed to ``apply_filters``; at most ``sample_count`` rows are drawn
    without replacement, seeded by ``seed`` for reproducibility.

    Returns:
        ``(gallery_items, status)`` — a list of ``(image, caption)`` tuples
        and a human-readable status string.
    """
    def _wrap(value):
        # apply_filters expects a list per facet, or None for "no filter".
        return [value] if value else None

    filtered = apply_filters(df, _wrap(tech), _wrap(snr), _wrap(mod), _wrap(mob))
    if filtered.empty:
        return [], "No spectrograms match the selected filters."

    limit = max(1, int(sample_count))
    rng = np.random.default_rng(int(seed))
    subset = filtered
    if len(filtered) > limit:
        picked = rng.choice(filtered.index.to_numpy(), size=limit, replace=False)
        subset = filtered.loc[picked]

    gallery_items = []
    for _, row in subset.iterrows():
        caption = f"{row['tech']} | {row['mod']} | {row['snr']} | {row['mob']}"
        image = _spectrogram_to_image(_reshape_spectrogram(row["spectrogram"]), caption)
        gallery_items.append((image, caption))

    return gallery_items, f"Showing {len(subset)} spectrograms (seed={seed})."
533
 
534
 
535
# ---------------------------------------------------------------------------
# Module-level data loading and UI option tables (executed once at import).
# ---------------------------------------------------------------------------
mapping_info = load_joint_mapping()
df, has_moe_embeddings = load_data(mapping_info)
# Joint SNR/Doppler class names, indexed by label id.
CLASS_LABELS = mapping_info["label_names"]

# Rows usable for the joint evaluation: they must carry a precomputed MoE
# embedding AND a valid joint label id.
has_moe_column = df["moe_embedding"].apply(lambda x: x is not None)
joint_eval_df = df[has_moe_column & df["joint_label_id"].notna()]


# Distinct values backing the filter dropdowns / checkbox groups.
tech_choices = sorted(df["tech"].unique())
snr_choices = sorted(df["snr"].unique())
mod_choices = sorted(df["mod"].unique())
mob_choices = sorted(df["mob"].unique())

# Per-technology modulation lists, used to narrow the modulation dropdown
# whenever a technology is selected.
TECH_TO_MODS: Dict[str, List[str]] = {
    tech: sorted(df.loc[df["tech"] == tech, "mod"].unique().tolist()) for tech in tech_choices
}

# Display label -> dataframe column for the t-SNE "Color By" selector.
COLOR_OPTIONS: Dict[str, str] = {
    "SNR": "snr",
    "Modulation": "mod",
    "Mobility": "mob",
}

# Default to the first technology (if any) so the app opens with data visible.
default_tech = tech_choices[:1] if tech_choices else []
initial_spec_mod_choices = TECH_TO_MODS.get(default_tech[0], mod_choices) if default_tech else mod_choices

# Joint SNR/Doppler tab is disabled when the MoE embeddings are not bundled.
evaluation_disabled = (not has_moe_embeddings) or joint_eval_df.empty
561
+
562
+
563
def update_modulation_choices(selected_tech: Optional[str]):
    """Return a Dropdown update narrowing modulation choices to *selected_tech*.

    Falls back to the full modulation list when no technology is selected or
    the technology is unknown. The current value is cleared because it may not
    be valid for the new choice set.

    Args:
        selected_tech: Technology name from the technology dropdown, or None.

    Returns:
        A Gradio update payload for the modulation dropdown.
    """
    choices = mod_choices
    if selected_tech:
        choices = TECH_TO_MODS.get(selected_tech, mod_choices)
    # Bug fix: `gr.Dropdown.update(...)` was removed in Gradio 4.x and this
    # Space pins sdk_version 6.0.1, so the old call raised AttributeError at
    # runtime. `gr.update(...)` is the supported replacement.
    return gr.update(choices=choices, value=None)
568
+
569
+ with gr.Blocks(title="LWM-Spectro Lab") as demo:
570
  gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
571
  gr.Markdown(
572
  """
573
+ Compare **LWM embeddings** vs **Raw Spectrograms** for visualization, then evaluate **precomputed MoE embeddings**
574
+ with a lightweight k-NN prototype classifier for joint SNR/Doppler recognition.
575
  """
576
  )
577
 
578
  with gr.Tabs():
579
+ with gr.Tab("Spectrograms"):
580
+ gr.Markdown("Visualize raw 128×128 spectrograms with optional filters.")
581
+ with gr.Row():
582
+ with gr.Column(scale=1, min_width=320):
583
+ spec_tech = gr.Dropdown(
584
+ choices=tech_choices,
585
+ value=default_tech[0] if default_tech else None,
586
+ label="Technology",
587
+ )
588
+ spec_snr = gr.Dropdown(choices=snr_choices, value=None, label="SNR (optional)")
589
+ spec_mod = gr.Dropdown(choices=initial_spec_mod_choices, value=None, label="Modulation (optional)")
590
+ spec_mob = gr.Dropdown(choices=mob_choices, value=None, label="Mobility (optional)")
591
+ spec_count = gr.Slider(minimum=1, maximum=12, step=1, value=6, label="Samples to show")
592
+ spec_seed = gr.Slider(minimum=0, maximum=9999, step=1, value=0, label="Random seed")
593
+ spec_btn = gr.Button("Show spectrograms", variant="primary")
594
+ with gr.Column(scale=3):
595
+ gallery = gr.Gallery(
596
+ label="Spectrogram Samples",
597
+ columns=[3],
598
+ rows=[3],
599
+ height=560,
600
+ preview=True,
601
+ )
602
+ gallery_status = gr.Textbox(label="Status", interactive=False)
603
+ spec_inputs = [spec_tech, spec_snr, spec_mod, spec_mob, spec_count, spec_seed]
604
+ spec_btn.click(render_spectrogram_gallery, inputs=spec_inputs, outputs=[gallery, gallery_status])
605
+ demo.load(render_spectrogram_gallery, inputs=spec_inputs, outputs=[gallery, gallery_status])
606
+ spec_tech.change(update_modulation_choices, inputs=spec_tech, outputs=spec_mod)
607
+
608
+ with gr.Tab("t-SNE Analysis"):
609
  with gr.Row():
610
  with gr.Column(scale=1, min_width=300):
611
  gr.Markdown("### Filters")
612
+ tech_filter = gr.CheckboxGroup(choices=tech_choices, value=default_tech, label="Technology")
613
  snr_filter = gr.Dropdown(
614
  choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)"
615
  )
 
626
  value="LWM Embedding",
627
  label="Representation",
628
  )
629
+ color_by = gr.Dropdown(
630
+ choices=list(COLOR_OPTIONS.keys()),
631
+ value="SNR",
632
+ label="Color By",
633
+ )
634
 
635
  with gr.Accordion("Advanced t-SNE Settings", open=False):
636
  perplexity = gr.Slider(minimum=5, maximum=50, value=30, step=1, label="Perplexity")
637
  n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations")
638
 
639
  btn = gr.Button("Update Plot", variant="primary")
 
640
 
641
  with gr.Column(scale=3):
642
  plot = gr.Plot(label="t-SNE Visualization")
 
644
  btn.click(
645
  plot_tsne,
646
  inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
647
+ outputs=[plot],
648
  )
649
 
650
  demo.load(
651
  plot_tsne,
652
  inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
653
+ outputs=[plot],
654
+ )
655
+
656
+ with gr.Tab("Modulation Classification"):
657
+ gr.Markdown("Compare LWM embeddings vs raw spectrograms for per-technology modulation classification.")
658
+ with gr.Row():
659
+ with gr.Column(scale=1, min_width=320):
660
+ mod_tech = gr.Dropdown(
661
+ choices=tech_choices,
662
+ value=default_tech[0] if default_tech else None,
663
+ label="Technology",
664
+ )
665
+ mod_train = gr.Slider(minimum=50, maximum=90, step=5, value=70, label="Training Percentage (%)")
666
+ mod_seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
667
+ gr.Markdown("k-NN uses an adaptive k based on the number of modulation classes and available training samples.")
668
+ mod_btn = gr.Button("Run modulation evaluation", variant="primary")
669
+ with gr.Column(scale=3):
670
+ with gr.Row():
671
+ emb_plot = gr.Plot(label="Embedding Confusion Matrix")
672
+ raw_plot = gr.Plot(label="Raw Confusion Matrix")
673
+ mod_summary = gr.Markdown(value="Select a technology and run the evaluation to view metrics.")
674
+ mod_btn.click(
675
+ evaluate_modulation,
676
+ inputs=[mod_tech, mod_train, mod_seed],
677
+ outputs=[emb_plot, raw_plot, mod_summary],
678
  )
679
 
680
+ with gr.Tab("Joint SNR/Doppler Evaluation"):
681
  if evaluation_disabled:
682
  gr.Markdown(
683
+ "⚠️ Precomputed MoE embeddings are not bundled in this Space build. Upload a dataset locally to run evaluations."
684
  )
685
 
686
  with gr.Row():
 
688
  gr.Markdown("### Evaluation Filters")
689
  eval_tech_filter = gr.CheckboxGroup(
690
  choices=tech_choices,
691
+ value=default_tech,
692
  label="Technology",
693
  interactive=not evaluation_disabled,
694
  )
 
734
  eval_btn = gr.Button("Run evaluation", variant="primary", interactive=not evaluation_disabled)
735
 
736
  with gr.Column(scale=3):
737
+ with gr.Row():
738
+ eval_plot = gr.Plot(label="MoE Prototype Confusion")
739
+ eval_plot_raw = gr.Plot(label="Raw Prototype Confusion")
740
+ eval_status = gr.Markdown(value="Run an evaluation to compare MoE vs raw baselines.")
741
 
742
  eval_btn.click(
743
  run_joint_evaluation,
744
  inputs=[train_pct, seed, eval_tech_filter, eval_snr_filter, eval_mod_filter, eval_mob_filter],
745
+ outputs=[eval_plot, eval_plot_raw, eval_status],
746
  )
747
 
748
  if __name__ == "__main__":
pretraining/README.md DELETED
@@ -1,44 +0,0 @@
1
- # 🔬 Pretraining Scripts
2
-
3
- This folder contains scripts for **Large Wireless Model (LWM)** pre-training.
4
-
5
- ## 📁 File Descriptions
6
-
7
- ### `train_lwm_spectro.py`
8
- - **Purpose**: Pre-train LWM model with spectrogram data
9
- - **Features**:
10
- - Self-supervised learning through masked patch prediction
11
- - Multi-size spectrogram support (32x32, 128x128)
12
- - MSE loss-based reconstruction
13
- - Real-time training monitoring and result storage
14
-
15
- ### `pretrained_model.py`
16
- - **Purpose**: Define structure of pre-trained LWM model
17
- - **Features**: LWM model architecture implementation
18
-
19
- ## 🚀 Usage
20
-
21
- ### Basic Training Execution
22
- ```bash
23
- cd pretraining
24
- python train_lwm_spectro.py
25
- ```
26
-
27
- ### GPU Memory Optimization
28
- ```bash
29
- cd pretraining
30
- python train_lwm_spectro.py # GPU 메모리에 맞춰 batch_size 조정
31
- ```
32
-
33
- ### Check Results
34
- Training results are automatically saved in `models/` folder:
35
- - `*_checkpoint.pth`: Model checkpoint
36
- - `*_training_history.json`: Training history
37
- - `*_training_curves.png`: Training curve graphs
38
-
39
- ## 📊 Research Perspective
40
-
41
- These scripts are used to study **LWM's representation learning capabilities**:
42
- - Extract meaningful features from spectrograms
43
- - Generalized representation learning through unsupervised learning
44
- - Validate transfer learning effectiveness in downstream tasks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pretraining/__init__.py DELETED
File without changes
pretraining/pretrained_model.py DELETED
@@ -1,180 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import numpy as np
5
-
6
-
7
- class LayerNormalization(nn.Module):
8
- """Layer norm with learnable scale and bias."""
9
-
10
- def __init__(self, d_model: int, eps: float = 1e-6) -> None:
11
- super().__init__()
12
- self.eps = eps
13
- self.alpha = nn.Parameter(torch.ones(d_model))
14
- self.bias = nn.Parameter(torch.zeros(d_model))
15
-
16
- def forward(self, x: torch.Tensor) -> torch.Tensor:
17
- mean = x.mean(dim=-1, keepdim=True)
18
- std = x.std(dim=-1, keepdim=True)
19
- return self.alpha * (x - mean) / (std + self.eps) + self.bias
20
-
21
-
22
- class Embedding(nn.Module):
23
- """Linear projection + positional embedding with optional max_len override."""
24
-
25
- def __init__(self, element_length: int, d_model: int, max_len: int | None = None) -> None:
26
- super().__init__()
27
- self.element_length = element_length
28
- self.d_model = d_model
29
- self.max_len = max_len if max_len is not None else 1025
30
-
31
- self.proj = nn.Linear(element_length, d_model)
32
- self.pos_embed = nn.Embedding(self.max_len, d_model)
33
- self.norm = LayerNormalization(d_model)
34
-
35
- def forward(self, x: torch.Tensor) -> torch.Tensor:
36
- seq_len = x.size(1)
37
- if seq_len > self.max_len:
38
- raise ValueError(f"Sequence length {seq_len} exceeds max_len {self.max_len}.")
39
-
40
- pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
41
- pos_encodings = self.pos_embed(pos)
42
- tok_emb = self.proj(x.float())
43
- return self.norm(tok_emb + pos_encodings)
44
-
45
-
46
- class ScaledDotProductAttention(nn.Module):
47
- """Scaled dot-product attention."""
48
-
49
- def __init__(self, d_k: int) -> None:
50
- super().__init__()
51
- self.d_k = d_k
52
-
53
- def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
54
- scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
55
- attn = F.softmax(scores, dim=-1)
56
- context = torch.matmul(attn, V)
57
- return context, attn
58
-
59
-
60
- class MultiHeadAttention(nn.Module):
61
- """Multi-head self-attention module."""
62
-
63
- def __init__(self, d_model: int, n_heads: int, dropout: float) -> None:
64
- super().__init__()
65
- if d_model % n_heads != 0:
66
- raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads}).")
67
-
68
- self.d_k = d_model // n_heads
69
- self.d_v = d_model // n_heads
70
- self.n_heads = n_heads
71
-
72
- self.W_Q = nn.Linear(d_model, self.d_k * n_heads)
73
- self.W_K = nn.Linear(d_model, self.d_k * n_heads)
74
- self.W_V = nn.Linear(d_model, self.d_v * n_heads)
75
- self.linear = nn.Linear(n_heads * self.d_v, d_model)
76
- self.dropout = nn.Dropout(dropout)
77
- self.scaled_dot_attn = ScaledDotProductAttention(self.d_k)
78
-
79
- def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
80
- residual = Q
81
- batch_size = Q.size(0)
82
-
83
- q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
84
- k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
85
- v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
86
-
87
- context, attn = self.scaled_dot_attn(q_s, k_s, v_s)
88
- output = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
89
- output = self.linear(output)
90
- return residual + self.dropout(output), attn
91
-
92
-
93
- class PoswiseFeedForwardNet(nn.Module):
94
- """Position-wise feed-forward network."""
95
-
96
- def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
97
- super().__init__()
98
- self.fc1 = nn.Linear(d_model, d_ff)
99
- self.fc2 = nn.Linear(d_ff, d_model)
100
- self.dropout = nn.Dropout(dropout)
101
-
102
- def forward(self, x: torch.Tensor) -> torch.Tensor:
103
- return self.fc2(self.dropout(F.relu(self.fc1(x))))
104
-
105
-
106
- class EncoderLayer(nn.Module):
107
- """Transformer encoder block."""
108
-
109
- def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float) -> None:
110
- super().__init__()
111
- self.enc_self_attn = MultiHeadAttention(d_model, n_heads, dropout)
112
- self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff, dropout)
113
- self.norm1 = LayerNormalization(d_model)
114
- self.norm2 = LayerNormalization(d_model)
115
-
116
- def forward(self, enc_inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
117
- attn_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
118
- attn_outputs = self.norm1(attn_outputs)
119
- ff_outputs = self.pos_ffn(attn_outputs)
120
- enc_outputs = self.norm2(attn_outputs + ff_outputs)
121
- return enc_outputs, attn
122
-
123
-
124
- class LWM(nn.Module):
125
- """Large Wireless Model (Transformer encoder)."""
126
-
127
- def __init__(
128
- self,
129
- element_length: int = 32,
130
- d_model: int = 128,
131
- n_layers: int = 12,
132
- max_len: int | None = None,
133
- n_heads: int = 8,
134
- dropout: float = 0.1,
135
- ) -> None:
136
- super().__init__()
137
-
138
- self.element_length = element_length
139
- self.d_model = d_model
140
- self.n_layers = n_layers
141
- self.max_len = max_len if max_len is not None else 1025
142
- self.n_heads = n_heads
143
- self.dropout = dropout
144
-
145
- self.embedding = Embedding(element_length, d_model, self.max_len)
146
- self.layers = nn.ModuleList(
147
- [EncoderLayer(d_model, n_heads, d_model * 4, dropout) for _ in range(n_layers)]
148
- )
149
- self.linear = nn.Linear(d_model, d_model)
150
- self.norm = LayerNormalization(d_model)
151
-
152
- embed_weight = self.embedding.proj.weight
153
- _, n_dim = embed_weight.size()
154
- self.decoder = nn.Linear(d_model, n_dim, bias=False)
155
- self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
156
-
157
- def forward(
158
- self,
159
- input_ids: torch.Tensor,
160
- masked_pos: torch.Tensor | None = None,
161
- ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
162
- output = self.embedding(input_ids)
163
-
164
- for layer in self.layers:
165
- output, attn = layer(output)
166
-
167
- if masked_pos is not None:
168
- masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
169
- h_masked = torch.gather(output, 1, masked_pos)
170
- h_masked = self.norm(F.relu(self.linear(h_masked)))
171
- logits_lm = self.decoder(h_masked) + self.decoder_bias
172
- return logits_lm, output
173
-
174
- return output
175
-
176
-
177
- def lwm(*args, **kwargs) -> LWM:
178
- """Factory to preserve backward compatibility with older imports."""
179
-
180
- return LWM(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pretraining/train_lwm_spectro.py DELETED
@@ -1,741 +0,0 @@
1
- #!/usr/bin/env python3
2
- # =============================================================================
3
- # train_lwm_spectro.py - LWM Pretraining with Complex-Valued Spectrogram Support
4
- # Modified from train_lwm_spectro_no_contrast.py to handle complex spectrograms
5
- # by separating real and imaginary parts and flattening them (similar to train_lwm.py)
6
- # =============================================================================
7
-
8
- # =============================================================================
9
- # 1. IMPORTS AND WARNINGS SETUP
10
- # - Load necessary PyTorch modules, utilities, and suppress UserWarnings
11
- # =============================================================================
12
- import sys
13
- import os
14
- import argparse
15
- # Add project root to path (Windows compatible)
16
- project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17
- sys.path.insert(0, project_root)
18
- import torch
19
- import torch.nn as nn
20
- import torch.nn.functional as F
21
- from torch.utils.data import DataLoader, random_split, TensorDataset
22
- import torch.optim as optim
23
- from utils import (generate_spectrograms_and_labels, tokenizer_train,
24
- create_train_dataloader, count_parameters, train_lwm)
25
- import numpy as np
26
- import pretrained_model # Assuming this contains the LWM model definition
27
- from torch.optim.lr_scheduler import LambdaLR
28
- from torch.optim import AdamW
29
- import warnings
30
- import platform
31
- import re
32
- from tqdm import tqdm
33
- from datetime import datetime
34
- import concurrent.futures
35
- import multiprocessing
36
- from collections import Counter
37
- from functools import lru_cache
38
- import json
39
-
40
- SNR_PATTERN = re.compile(r"SNR(-?\d+)dB")
41
- DOPPLER_MAP = {"static": 0, "pedestrian": 1, "vehicular": 2}
42
- DOPPLER_INV = {v: k for k, v in DOPPLER_MAP.items()}
43
-
44
-
45
- def _parse_snr_and_doppler(path: str) -> tuple[float, int]:
46
- snr_db = 0.0
47
- doppler_id = 0
48
-
49
- matches = SNR_PATTERN.findall(path)
50
- if matches:
51
- try:
52
- snr_db = float(matches[-1])
53
- except ValueError:
54
- snr_db = 0.0
55
-
56
- normalized_path = os.path.normpath(path)
57
- parts = normalized_path.split(os.sep)
58
- for part in parts:
59
- if part in DOPPLER_MAP:
60
- doppler_id = DOPPLER_MAP[part]
61
- break
62
-
63
- return snr_db, doppler_id
64
-
65
- warnings.filterwarnings("ignore", category=UserWarning)
66
-
67
- # Use simple progress display instead of tqdm on Windows
68
- USE_TQDM = platform.system() != 'Windows'
69
-
70
- # CPU 코어 수 계산 (메모리 사용량 고려하여 보수적으로 설정)
71
- total_cores = multiprocessing.cpu_count()
72
- if total_cores >= 16:
73
- MAX_WORKERS = min(8, total_cores // 2) # 고성능 서버의 경우 8코어로 제한
74
- else:
75
- MAX_WORKERS = max(2, total_cores // 2) # 일반 시스템의 경우 절반 사용
76
- print(f"🚀 Using {MAX_WORKERS}/{total_cores} CPU cores for parallel processing")
77
-
78
- PRINT_CONVERSION_STATS = os.environ.get("LWM_PRINT_CONVERSION_STATS", "").strip().lower() in {"1", "true", "yes"}
79
-
80
-
81
- def convert_complex_to_interleaved(spectrograms):
82
- """
83
- Convert complex-valued spectrograms to real-imaginary interleaved format.
84
-
85
- Similar to patch_maker() in train_lwm.py, this function:
86
- 1. Extracts real and imaginary parts
87
- 2. Interleaves them along the last dimension
88
-
89
- Args:
90
- spectrograms (np.ndarray): Complex-valued array of shape (n_samples, n_rows, n_cols)
91
- or (n_samples, 1, n_rows, n_cols)
92
-
93
- Returns:
94
- np.ndarray: Real-valued array with interleaved real/imag parts
95
- Shape: (n_samples, n_rows, n_cols * 2)
96
- """
97
- # Handle different input shapes
98
- if spectrograms.ndim == 4:
99
- # Remove channel dimension if present: (n_samples, 1, n_rows, n_cols) -> (n_samples, n_rows, n_cols)
100
- spectrograms = spectrograms[:, 0, :, :]
101
-
102
- # Check if data is complex
103
- if np.iscomplexobj(spectrograms):
104
- n_samples, n_rows, n_cols = spectrograms.shape
105
-
106
- # Extract real and imaginary parts
107
- flat_real = spectrograms.real
108
- flat_imag = spectrograms.imag
109
-
110
- # Interleave real and imaginary parts along the last axis
111
- # Output shape: (n_samples, n_rows, n_cols * 2)
112
- interleaved = np.empty((n_samples, n_rows, n_cols * 2), dtype=np.float32)
113
- interleaved[:, :, 0::2] = flat_real # Even indices: real parts
114
- interleaved[:, :, 1::2] = flat_imag # Odd indices: imaginary parts
115
-
116
- if PRINT_CONVERSION_STATS:
117
- print(f" ℹ️ Converted complex spectrograms: {spectrograms.shape} -> {interleaved.shape}")
118
- print(f" Real part range: [{flat_real.min():.2e}, {flat_real.max():.2e}]")
119
- print(f" Imag part range: [{flat_imag.min():.2e}, {flat_imag.max():.2e}]")
120
-
121
- return interleaved
122
- else:
123
- # Already real-valued, just ensure correct shape
124
- if spectrograms.ndim == 3:
125
- if PRINT_CONVERSION_STATS:
126
- print(f" ℹ️ Data is already real-valued: {spectrograms.shape}")
127
- return spectrograms
128
- else:
129
- raise ValueError(f"Unexpected spectrogram shape: {spectrograms.shape}")
130
-
131
-
132
- def process_single_scenario(scenario_info):
133
- """단일 시나리오를 처리하는 함수 (멀티프로세싱용)"""
134
- scenario_name, spectrogram_path = scenario_info
135
-
136
- try:
137
- # 메모리 효율성을 위해 필요한 데이터만 로드
138
- scenario_spectrograms, scenario_labels = generate_spectrograms_and_labels(
139
- scenario_name=scenario_name,
140
- spectrogram_path=spectrogram_path,
141
- cache_path=None, # 메모리 문제로 캐시 비활성화
142
- )
143
-
144
- # Validate load
145
- if scenario_spectrograms is None or (hasattr(scenario_spectrograms, 'size') and scenario_spectrograms.size == 0):
146
- print(f" ⚠️ No data loaded from: {spectrogram_path}")
147
- return None
148
-
149
- # Convert complex spectrograms to interleaved real-imaginary format
150
- scenario_spectrograms = convert_complex_to_interleaved(scenario_spectrograms)
151
-
152
- snr_db, doppler_id = _parse_snr_and_doppler(spectrogram_path)
153
-
154
- # 데이터 분할 (인덱스만 계산)
155
- total_samples = len(scenario_spectrograms)
156
- train_size = int(0.8 * total_samples)
157
- val_size = total_samples - train_size
158
-
159
- # 메모리 절약을 위해 numpy array로 유지 (필요할 때만 tensor로 변환)
160
- train_data = np.array(scenario_spectrograms[:train_size], dtype=np.float32)
161
- val_data = np.array(scenario_spectrograms[train_size:], dtype=np.float32)
162
-
163
- snr_array = np.full(total_samples, snr_db, dtype=np.float32)
164
- doppler_array = np.full(total_samples, doppler_id, dtype=np.int64)
165
- train_meta = {
166
- 'snr_db': snr_array[:train_size],
167
- 'doppler_id': doppler_array[:train_size],
168
- }
169
- val_meta = {
170
- 'snr_db': snr_array[train_size:],
171
- 'doppler_id': doppler_array[train_size:],
172
- }
173
-
174
- # 불필요한 데이터 즉시 삭제
175
- del scenario_spectrograms
176
-
177
- return {
178
- 'scenario': scenario_name,
179
- 'train_data': train_data,
180
- 'val_data': val_data,
181
- 'train_meta': train_meta,
182
- 'val_meta': val_meta,
183
- 'train_size': len(train_data),
184
- 'val_size': len(val_data)
185
- }
186
- except Exception as e:
187
- print(f"❌ Error processing scenario {scenario_name}: {e}")
188
- import traceback
189
- traceback.print_exc()
190
- return None
191
-
192
- # GPU Memory Monitor import (for Lambda) - Removed
193
-
194
- # =============================================================================
195
- # 2. SCENARIO LIST DEFINITION
196
- # - Define the list of scenario names to iterate over for data generation
197
- # =============================================================================
198
-
199
- # Supported communications; can be limited via CLI
200
- SUPPORTED_COMM_TYPES = {"LTE", "WiFi", "5G"}
201
-
202
-
203
- def _parse_standard_args():
204
- parser = argparse.ArgumentParser(add_help=False)
205
- parser.add_argument('--standards', nargs='+', choices=SUPPORTED_COMM_TYPES,
206
- help='Specify one or more communication types to include (default: all).')
207
- for comm in SUPPORTED_COMM_TYPES:
208
- parser.add_argument(f'--{comm}', dest=f'flag_{comm}', action='store_true',
209
- help=f'Include only {comm} data (can be combined).')
210
- parser.add_argument('--city', '--cities', dest='cities', nargs='+',
211
- help='Limit scenarios to one or more city prefixes (e.g., "0" or "city_0").')
212
- parser.add_argument(
213
- '--normalization',
214
- choices=('per_sample', 'dataset'),
215
- default='per_sample',
216
- help='Normalization mode applied during tokenization (default: %(default)s).'
217
- )
218
- parser.add_argument('--help', action='help')
219
-
220
- args, remaining = parser.parse_known_args()
221
-
222
- enabled = set(SUPPORTED_COMM_TYPES)
223
- if args.standards:
224
- enabled = set(args.standards)
225
- else:
226
- flagged = {comm for comm in SUPPORTED_COMM_TYPES if getattr(args, f'flag_{comm}', False)}
227
- if flagged:
228
- enabled = flagged
229
-
230
- selected_cities: list[str] | None = None
231
- if args.cities:
232
- selected_cities = []
233
- for city_token in args.cities:
234
- token = str(city_token).strip()
235
- if not token:
236
- continue
237
- if token.startswith('city_'):
238
- selected_cities.append(token)
239
- else:
240
- selected_cities.append(f'city_{token}')
241
- if not selected_cities:
242
- selected_cities = None
243
-
244
- # Return remaining args to allow downstream parsing if needed
245
- sys.argv = [sys.argv[0]] + remaining
246
- return enabled, selected_cities, args.normalization
247
-
248
-
249
- ENABLED_COMM_TYPES, ENABLED_CITY_PREFIXES, NORMALIZATION_MODE = _parse_standard_args()
250
- MAX_SCENARIOS = int(os.environ.get("LWM_MAX_SCENARIOS", "0")) or None
251
-
252
-
253
- def _extract_scenario_token(file_path):
254
- """Derive the base scenario token (without city) from the file path."""
255
- normalized_path = os.path.normpath(file_path)
256
- parts = normalized_path.split(os.sep)
257
-
258
- scenario_parts = []
259
- for i, part in enumerate(parts):
260
- if part in SUPPORTED_COMM_TYPES:
261
- trailing = parts[i:i + 5]
262
- if trailing:
263
- scenario_parts = trailing[:5]
264
- break
265
-
266
- if not scenario_parts:
267
- # Fallback for datasets where the communication type is only captured in the filename
268
- base_name = os.path.splitext(os.path.basename(file_path))[0]
269
- if base_name.startswith('spectrogram_'):
270
- tokens = base_name.split('_')[1:] # drop 'spectrogram'
271
- if tokens and tokens[0] in SUPPORTED_COMM_TYPES:
272
- scenario_parts = tokens[:5] if len(tokens) >= 5 else tokens
273
-
274
- return '_'.join(scenario_parts) if scenario_parts else None
275
-
276
-
277
- @lru_cache(maxsize=1)
278
- def _collect_scenario_file_info():
279
- import glob
280
-
281
- scenario_entries = []
282
-
283
- # New MATLAB receiver pipeline output
284
- new_base = os.path.join('ls_data', 'MATLAB', 'receiver_pipeline')
285
- if os.path.isdir(new_base):
286
- patterns = [os.path.join(new_base, '*', '**', 'spectrogram_*.mat')]
287
- for pattern in patterns:
288
- for file_path in sorted(glob.glob(pattern, recursive=True)):
289
- norm = os.path.normpath(file_path)
290
- parts = norm.split(os.sep)
291
- # Determine a grouping token similar to city_name; use the standard folder name
292
- try:
293
- idx = parts.index('receiver_pipeline')
294
- city_name = parts[idx + 1] if idx + 1 < len(parts) else 'receiver_pipeline'
295
- except ValueError:
296
- city_name = 'receiver_pipeline'
297
-
298
- base_token = _extract_scenario_token(file_path)
299
- if not base_token:
300
- continue
301
- comm_type = base_token.split('_', 1)[0]
302
- if comm_type not in ENABLED_COMM_TYPES:
303
- continue
304
- scenario_id = f"{city_name}::{base_token}"
305
- scenario_entries.append((scenario_id, file_path, city_name, base_token))
306
-
307
- # Legacy repo layouts under spectrograms/city_*
308
- import glob as _glob
309
- for city_dir in sorted(_glob.glob(os.path.join('spectrograms', 'city_*'))):
310
- if not os.path.isdir(city_dir):
311
- continue
312
- city_name = os.path.basename(city_dir)
313
- if ENABLED_CITY_PREFIXES:
314
- if not any(city_name.startswith(prefix) for prefix in ENABLED_CITY_PREFIXES):
315
- continue
316
- # Look for complex spectrogram outputs; support both nested and flat layouts
317
- candidate_patterns = [
318
- os.path.join(city_dir, '**', 'complex_raw', '**', 'spectrogram_*.mat'),
319
- os.path.join(city_dir, '**', 'spectrogram_*.mat'),
320
- ]
321
- city_files = []
322
- seen_paths = set()
323
- for pattern in candidate_patterns:
324
- for file_path in sorted(_glob.glob(pattern, recursive=True)):
325
- if not file_path.lower().endswith('.mat'):
326
- continue
327
- if file_path in seen_paths:
328
- continue
329
- seen_paths.add(file_path)
330
- city_files.append(file_path)
331
-
332
- # Fallback: 512FFT pattern (기존 호환성)
333
- if not city_files:
334
- pattern = os.path.join(city_dir, '**', '512FFT', '**', 'spectrograms', '*.pkl')
335
- city_files = sorted(_glob.glob(pattern, recursive=True))
336
-
337
- for file_path in city_files:
338
- base_token = _extract_scenario_token(file_path)
339
- if not base_token:
340
- continue
341
- comm_type = base_token.split('_', 1)[0]
342
- if comm_type not in ENABLED_COMM_TYPES:
343
- continue
344
- scenario_id = f"{city_name}::{base_token}"
345
- scenario_entries.append((scenario_id, file_path, city_name, base_token))
346
-
347
- if MAX_SCENARIOS:
348
- scenario_entries = scenario_entries[:MAX_SCENARIOS]
349
-
350
- return scenario_entries
351
-
352
-
353
def scenarios_list():
    """Build and report the array of scenario identifiers selected for pretraining."""
    entries = _collect_scenario_file_info()

    if not entries:
        print("⚠️ No spectrogram files found for pretraining.")
        return np.array([])

    print(f"Enabled communication types: {sorted(ENABLED_COMM_TYPES)}")
    if ENABLED_CITY_PREFIXES:
        print(f"Selected city prefixes: {sorted(ENABLED_CITY_PREFIXES)}")

    # Summarize how many files each city dataset contributes.
    per_city = Counter(city for _, _, city, _ in entries)
    print("Using scenarios from the following city datasets:")
    for city, n_files in per_city.items():
        print(f" - {city}: {n_files} files")

    print(f"Total scenarios selected: {len(entries)}")
    return np.array([scenario_id for scenario_id, _, _, _ in entries])
370
-
371
-
372
# =============================================================================
# 3. SCENARIO PROPERTIES MAPPING
#    - Map each scenario name to its corresponding properties
# =============================================================================

def scenario_prop():
    """Map each scenario id to its spectrogram file path and per-city cache path."""
    return {
        scenario_id: {
            'spectrogram_path': file_path,
            'cache_path': os.path.join('spectrograms', city_name, 'spectrogram_cache_128x128.pkl'),
        }
        for scenario_id, file_path, city_name, _ in _collect_scenario_file_info()
    }
388
-
389
# =============================================================================
# 4. TRAINING PARAMETERS AND HYPERPARAMETERS
#    - Set training epochs, batch sizes, learning rates, model dimensions, etc.
# =============================================================================

EPOCHS = 20  # Increased for better convergence
# Optimized batch size for A100 GPU (40GB)
BATCH_SIZE = 16
VAL_BATCH_SIZE = 16
WARMUP_EPOCHS = 5  # linear LR warmup duration (consumed by lr_lambda below)
BASE_LR = 5e-4
MIN_LR = 1e-8
# Updated for 128x128 complex spectrograms with real-imaginary interleaving
N_ROWS = 4
N_COLUMNS = 4
ELEMENT_LENGTH = N_ROWS * N_COLUMNS * 2  # Complex spectrograms: 2x for real+imaginary interleaving
D_MODEL = 128
MAX_LEN = 1025  # (128/4) * (128/4) + 1 = 32 * 32 + 1 = 1024 + 1 for [CLS] token
# Interleaving keeps the same number of spatial patches (32x32) while doubling patch width
# so each token covers 4x4 complex bins (real+imag) and sequence length stays at 1025.
N_LAYERS = 12
device_idx = 0
WEIGHT_DECAY = 0.05
BETA1 = 0.9
BETA2 = 0.999
MASK_PERCENT = 0.6  # fraction of tokens masked for the MLM pretraining objective
N_HEADS = 8
DROPOUT = 0.1

print(f"📊 Model configuration for complex spectrograms:")
print(f" Patch size: {N_ROWS}x{N_COLUMNS}")
print(f" Element length: {ELEMENT_LENGTH} (includes real+imag interleaving)")
print(f" Max sequence length: {MAX_LEN}")
422
-
423
# =============================================================================
# 5. DATA GENERATION LOOP
#    - Iterate over scenarios to generate spectrogram samples and labels
# =============================================================================

scenarios = scenarios_list()
scenario_properties = scenario_prop()

# Collect all training and validation data separately; per-scenario chunks are
# accumulated here and concatenated once at the end.
train_spectrogram_chunks = []
val_spectrogram_chunks = []
train_label_chunks = []
val_label_chunks = []
train_meta_chunks = []
val_meta_chunks = []

print(f"📂 Loading {len(scenarios)} scenarios...")

# TEMP: Modified to not use cache
print("⚠️ TEMPORARY FIX: Skipping cache to avoid memory issues")
cache_path = None  # Disable cache usage

# Single-process scenario handling (multiprocessing disabled)
scenario_info_list = []
missing_props = []
for scenario in scenarios:
    props = scenario_properties.get(scenario)
    if props is None:
        missing_props.append(scenario)
        continue
    scenario_info_list.append((scenario, props["spectrogram_path"]))

if missing_props:
    print("⚠️ Missing metadata for the following scenarios; skipping:")
    for scen in missing_props:
        print(f" - {scen}")

print(f"📂 Loading {len(scenario_info_list)} scenarios using single process...")

# Process sequentially in a single process
successful_scenarios = 0
scenario_results = []

for scenario_info in tqdm(scenario_info_list, desc="Processing scenarios", unit="scenario"):
    scenario_name = scenario_info[0]
    try:
        result = process_single_scenario(scenario_info)
        if result is not None:
            # Accumulate per-scenario train/val arrays; labels are all-zero
            # placeholders (pretraining is self-supervised).
            train_spectrogram_chunks.append(result['train_data'])
            val_spectrogram_chunks.append(result['val_data'])
            train_label_chunks.append(np.zeros(result['train_size'], dtype=np.int64))
            val_label_chunks.append(np.zeros(result['val_size'], dtype=np.int64))
            train_meta_chunks.append(result['train_meta'])
            val_meta_chunks.append(result['val_meta'])
            successful_scenarios += 1
    except Exception as e:
        print(f"❌ Scenario {scenario_name} processing failed: {e}")

print(f"✅ Processing completed! Successful scenarios: {successful_scenarios}/{len(scenario_info_list)}")

if not train_spectrogram_chunks or not val_spectrogram_chunks:
    raise ValueError("No spectrogram data collected; check scenario configuration.")

print("🔄 Collating spectrogram arrays...")
train_spectrograms = np.concatenate(train_spectrogram_chunks, axis=0).astype(np.float32, copy=False)
val_spectrograms = np.concatenate(val_spectrogram_chunks, axis=0).astype(np.float32, copy=False)
train_labels = np.concatenate(train_label_chunks, axis=0)
val_labels = np.concatenate(val_label_chunks, axis=0)
492
-
493
- def _concat_metadata_dicts(dict_list):
494
- if not dict_list:
495
- return {}
496
- keys = dict_list[0].keys()
497
- return {k: np.concatenate([d[k] for d in dict_list], axis=0) for k in keys}
498
-
499
train_metadata = _concat_metadata_dicts(train_meta_chunks)
val_metadata = _concat_metadata_dicts(val_meta_chunks)

# Free the per-scenario chunks now that everything is concatenated.
del train_spectrogram_chunks, val_spectrogram_chunks, train_label_chunks, val_label_chunks
del train_meta_chunks, val_meta_chunks

print(f"Training spectrograms shape: {train_spectrograms.shape}")
print(f"Validation spectrograms shape: {val_spectrograms.shape}")
print(f"Memory usage: {train_spectrograms.nbytes + val_spectrograms.nbytes + train_labels.nbytes + val_labels.nbytes:,} bytes")

# Global normalization statistics are computed on the TRAINING split only and
# reused for validation to avoid leakage.
train_mean = float(train_spectrograms.mean())
train_std = float(train_spectrograms.std())
if abs(train_std) < 1e-6:
    print("⚠️ Training std near zero, using epsilon for stability")
    train_std = 1e-6
dataset_normalization = {'mean': train_mean, 'std': train_std, 'normalization': NORMALIZATION_MODE}
print(f"Dataset normalization stats -> mean: {train_mean:.4f}, std: {train_std:.4f}")
516
-
517
# =============================================================================
# 6. DATA TOKENIZATION
#    - Tokenize spectrogram matrices into input sequences with masking for pretraining
# =============================================================================

# Tokenize training data
print("🔄 Starting tokenization of training data...")
preprocessed_train = tokenizer_train(
    train_spectrograms,
    max_len=MAX_LEN,
    masking_percent=MASK_PERCENT,
    mask=True,
    seed=42,  # fixed seed so mask placement is reproducible across runs
    metadata=train_metadata,
    dataset_stats=dataset_normalization,
    normalization=NORMALIZATION_MODE,
    interleaved=True,
)
print("✅ Training data tokenization completed!")

# Tokenize validation data (with masking for pretraining evaluation)
print("🔄 Starting tokenization of validation data...")
preprocessed_val = tokenizer_train(
    val_spectrograms,
    max_len=MAX_LEN,
    masking_percent=MASK_PERCENT,
    mask=True,  # Apply masking for pretraining evaluation
    seed=42,
    metadata=val_metadata,
    dataset_stats=dataset_normalization,
    normalization=NORMALIZATION_MODE,
    interleaved=True,
)
print("✅ Validation data tokenization completed!")
551
-
552
# =============================================================================
# 7. TRAIN/VALIDATION DATA SETUP
#    - Use pre-split training and validation data
# =============================================================================

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use pre-split data (the 80/20 split already happened per scenario).
train_data = preprocessed_train
val_data = preprocessed_val
564
-
565
# =============================================================================
# 8. DATALOADER CREATION
#    - Build PyTorch DataLoader objects for batched training and validation
# =============================================================================

# Handle different data formats
print("🔧 Creating data loaders...")

if isinstance(train_data, dict):
    # Masked tokenization returns a dict of datasets keyed by sequence length.
    print(f" Training data format: dict with {len(train_data)} sequence lengths")
    # Training data with masking
    train_loaders = create_train_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True)
else:
    print(f" Training data format: tensor with shape {train_data.shape}")
    # Training data without masking (fallback)
    train_dataset = TensorDataset(train_data)
    train_loaders = {'seq_0': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)}

if isinstance(val_data, dict):
    print(f" Validation data format: dict with {len(val_data)} sequence lengths")
    # Validation data with masking
    val_loaders = create_train_dataloader(val_data, batch_size=VAL_BATCH_SIZE, shuffle=False)
else:
    print(f" Validation data format: tensor with shape {val_data.shape}")
    # Validation data without masking
    val_dataset = TensorDataset(val_data)
    val_loaders = {'seq_0': DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)}

print("✅ Data loaders created successfully!")
594
-
595
# =============================================================================
# 9. MODEL INITIALIZATION
#    - Instantiate the LWM transformer model and optionally load pre-trained weights
#    - Wrap with DataParallel for multi-GPU support
# =============================================================================

# Device selection with MPS support for Mac
print("🔧 Setting up device and GPU configuration...")

if torch.cuda.is_available():
    device_count = torch.cuda.device_count()
    print(f" CUDA available: {device_count} GPU(s) detected")

    device = torch.device("cuda:0")

    # On Windows, use only available GPUs
    gpu_ids = list(range(device_count))  # 0, 1, 2... auto-detect
    print(f" Using CUDA GPUs: {gpu_ids}")

    # GPU memory status (per-device totals in GiB)
    for i in gpu_ids:
        try:
            mem_total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            mem_allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f" GPU {i}: Total: {mem_total:.1f}GB, Allocated: {mem_allocated:.1f}GB")
        except Exception as e:
            print(f" GPU {i}: Error getting memory info - {e}")

elif torch.backends.mps.is_available():
    device = torch.device("mps")
    gpu_ids = []  # MPS doesn't support DataParallel
    print(" Using MPS (Apple Silicon GPU)")
else:
    device = torch.device("cpu")
    gpu_ids = []
    print(" Using CPU")

print(f" Final device: {device}")
print(f" GPU IDs for DataParallel: {gpu_ids}")
634
-
635
print("🤖 Initializing LWM model...")
print(f" Model parameters: element_length={ELEMENT_LENGTH}, d_model={D_MODEL}, n_layers={N_LAYERS}, max_len={MAX_LEN}, n_heads={N_HEADS}")

try:
    model = pretrained_model.lwm(
        element_length=ELEMENT_LENGTH,  # Complex spectrograms with real-imag interleaving
        d_model=D_MODEL,
        n_layers=N_LAYERS,
        max_len=MAX_LEN,
        n_heads=N_HEADS,
        dropout=DROPOUT
    )
    print(" ✅ Model created successfully")

    print(f" Moving model to device: {device}")
    # MPS only supports float32, so set dtype
    if 'mps' in str(device):
        model = model.to(device).float()
        print(" ✅ Model moved to MPS device (float32)")
    else:
        model = model.to(device)
        print(" ✅ Model moved to device successfully")

except Exception as e:
    # Model creation/device placement is fatal for this script: report and exit.
    print(f" ❌ Model initialization failed: {e}")
    import traceback
    traceback.print_exc()
    exit(1)

# Optional: Load pre-trained model
load_model = False
if load_model:
    model.load_state_dict(torch.load("models/model_checkpoint.pth", map_location=device))
    print("Pre-trained model loaded successfully.")

# Use DataParallel for multi-GPU support (skip for MPS)
if gpu_ids:
    model = nn.DataParallel(model, device_ids=gpu_ids)
    print(f"Model loaded successfully on GPU {device.index}")
else:
    print(f"Model loaded successfully on {device}")
n_parameters = count_parameters(model)
print(f"Number of trainable parameters: {n_parameters:,}")
678
-
679
# =============================================================================
# 10. OPTIMIZER AND LEARNING RATE SCHEDULER
#    - Configure AdamW optimizer and a cosine-with-warmup LR schedule based on total steps
# =============================================================================

TOTAL_STEPS = sum(len(loader) for loader in train_loaders.values()) * EPOCHS
WARMUP_STEPS = sum(len(loader) for loader in train_loaders.values()) * WARMUP_EPOCHS

optimizer = AdamW(
    model.parameters(),
    lr=BASE_LR,
    betas=(BETA1, BETA2),
    weight_decay=WEIGHT_DECAY
)

def lr_lambda(current_step):
    # Multiplicative LR factor for LambdaLR: linear warmup to BASE_LR over
    # WARMUP_STEPS, then cosine decay towards MIN_LR over the remaining steps.
    if current_step < WARMUP_STEPS:
        return current_step / WARMUP_STEPS
    else:
        scaled_progress = (current_step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
        # Rescale the factor so it spans [MIN_LR/BASE_LR, 1] instead of [0, 1].
        return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
703
-
704
# =============================================================================
# 11. PRE-TRAINING LOOP
#    - Call the train_lwm utility to run the pre-training epochs, logging metrics and saving models
# =============================================================================

# Create timestamp-based save directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = f"models/{timestamp}_complex"
print(f"📁 Models and logs will be saved to: {save_dir}")
os.makedirs(save_dir, exist_ok=True)

# Persist the normalization stats next to the checkpoints so inference can
# reuse exactly the same mean/std.
stats_path = os.path.join(save_dir, "dataset_stats.json")
with open(stats_path, 'w') as f:
    json.dump(dataset_normalization, f, indent=2)
print(f"📝 Saved dataset stats to {stats_path}")

# Encode the enabled communication standards into the checkpoint file names.
comm_selection = sorted(ENABLED_COMM_TYPES) if ENABLED_COMM_TYPES else []
if comm_selection:
    comm_suffix = "_" + "-".join(comm_selection)
else:
    comm_suffix = ""
if comm_selection:
    print(f"[INFO] Communication standards for this run: {', '.join(comm_selection)}")
727
-
728
if __name__ == "__main__":
    # Run pretraining; checkpoints and the CSV metrics log are written under save_dir.
    pretrained_model_output = train_lwm(
        model,
        train_loaders,
        val_loaders,
        optimizer,
        scheduler,
        EPOCHS,
        device=device,
        save_dir=save_dir,
        log_file="training_log.csv",
        checkpoint_suffix=comm_suffix + "_complex",
    )
    print("🎉 Training completed successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pretraining/train_lwm_spectro_contrastive.py DELETED
@@ -1,1450 +0,0 @@
1
- #!/usr/bin/env python3
2
- # =============================================================================
3
- # train_lwm_spectro_contrastive.py - LWM Pretraining with Contrastive Learning
4
- # Extended from train_lwm_spectro.py to add modulation/mobility contrastive learning
5
- #
6
- # Key additions:
7
- # - Contrastive learning module with projection head
8
- # - Multi-task loss: MLM + Contrastive (modulation + mobility)
9
- # - Hard negative mining
10
- # - Supervised contrastive loss (SupCon)
11
- # =============================================================================
12
-
13
- # =============================================================================
14
- # 1. IMPORTS AND WARNINGS SETUP
15
- # =============================================================================
16
- import sys
17
- import os
18
- import argparse
19
- import math
20
- # Add project root to path (Windows compatible)
21
- project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
22
- sys.path.insert(0, project_root)
23
- import torch
24
- import torch.nn as nn
25
- import torch.nn.functional as F
26
- from torch.utils.data import DataLoader, random_split, TensorDataset, Dataset
27
- import torch.optim as optim
28
- from utils import (generate_spectrograms_and_labels, tokenizer_train,
29
- create_train_dataloader, count_parameters)
30
- import numpy as np
31
- import pretrained_model # Assuming this contains the LWM model definition
32
- from torch.optim.lr_scheduler import LambdaLR
33
- from torch.optim import AdamW
34
- import warnings
35
- import platform
36
- import re
37
- from tqdm import tqdm
38
- from datetime import datetime
39
- import concurrent.futures
40
- import multiprocessing
41
- from collections import Counter
42
- from functools import lru_cache
43
- import json
44
- from typing import Dict, Tuple, List, Optional
45
-
46
- SNR_PATTERN = re.compile(r"SNR(-?\d+)dB")
47
- DOPPLER_MAP = {"static": 0, "pedestrian": 1, "vehicular": 2}
48
- DOPPLER_INV = {v: k for k, v in DOPPLER_MAP.items()}
49
-
50
- # Dynamic modulation mapping - will be built from actual data
51
- MODULATION_MAP = {} # Will be populated: {"BPSK": 0, "QPSK": 1, ...}
52
- MODULATION_INV = {} # Will be populated: {0: "BPSK", 1: "QPSK", ...}
53
-
54
- # Standard-to-modulation mapping (for reference only - not used in code)
55
- # Note: Actual modulations are dynamically discovered from file paths
56
- # These match the MCS definitions in MATLAB/receiver_pipeline/getMCSDefinitions.m
57
- STANDARD_MODULATIONS = {
58
- "WiFi": [
59
- "BPSK", "QPSK", "16QAM", "64QAM"
60
- # From getMCSDefinitions.m WiFi MCS table:
61
- # - MCS 0: BPSK rate1-2
62
- # - MCS 1-2: QPSK rate1-2, rate3-4
63
- # - MCS 3-4: 16QAM rate1-2, rate3-4
64
- # - MCS 5-7: 64QAM rate2-3, rate3-4, rate5-6
65
- # Note: Your MATLAB pipeline uses 802.11a/g MCS (no 256QAM/1024QAM)
66
- ],
67
- "LTE": [
68
- "QPSK", "16QAM", "64QAM"
69
- # From getMCSDefinitions.m LTE MCS table:
70
- # - MCS 0-2: QPSK rate1-3, rate1-2, rate3-4
71
- # - MCS 3-4: 16QAM rate1-2, rate3-4
72
- # - MCS 5-6: 64QAM rate2-3, rate3-4
73
- # Note: Your MATLAB pipeline does NOT include 256QAM
74
- ],
75
- "5G": [
76
- "QPSK", "16QAM", "64QAM", "256QAM"
77
- # From getMCSDefinitions.m 5G MCS table:
78
- # - MCS 0-1: QPSK rate1-3, rate1-2
79
- # - MCS 2-3: 16QAM rate1-2, rate3-4
80
- # - MCS 4-5: 64QAM rate2-3, rate3-4
81
- # - MCS 6: 256QAM rate3-4
82
- ],
83
- }
84
-
85
- # Important: This mapping is for documentation only
86
- # The actual modulations used in your dataset may differ
87
- # They will be automatically discovered from file paths
88
-
89
-
90
- def _parse_metadata(path: str) -> Dict[str, any]:
91
- """
92
- Parse SNR, Doppler, and Modulation from file path.
93
- Modulation is dynamically extracted and added to global MODULATION_MAP.
94
-
95
- Returns:
96
- dict with keys: snr_db, doppler_id, modulation_id, modulation_name
97
- """
98
- global MODULATION_MAP, MODULATION_INV
99
-
100
- snr_db = 0.0
101
- doppler_id = 0
102
- modulation_name = "Unknown"
103
-
104
- # Parse SNR
105
- matches = SNR_PATTERN.findall(path)
106
- if matches:
107
- try:
108
- snr_db = float(matches[-1])
109
- except ValueError:
110
- snr_db = 0.0
111
-
112
- # Parse Doppler
113
- normalized_path = os.path.normpath(path)
114
- parts = normalized_path.split(os.sep)
115
- for part in parts:
116
- if part in DOPPLER_MAP:
117
- doppler_id = DOPPLER_MAP[part]
118
- break
119
-
120
- # Parse Modulation (dynamic - look for common modulation patterns)
121
- # Patterns: BPSK, QPSK, 8PSK, 16QAM, 32QAM, 64QAM, 256QAM, 1024QAM, etc.
122
- # Note: We ONLY use explicit modulation names in the path, not code rates
123
- # since the same code rate can be used with different modulations
124
- modulation_patterns = [
125
- r"BPSK",
126
- r"QPSK",
127
- r"8PSK",
128
- r"16QAM",
129
- r"32QAM",
130
- r"64QAM",
131
- r"128QAM",
132
- r"256QAM",
133
- r"512QAM",
134
- r"1024QAM",
135
- ]
136
-
137
- for pattern in modulation_patterns:
138
- if re.search(pattern, path, re.IGNORECASE):
139
- modulation_name = pattern
140
- break
141
-
142
- # Add to global mapping if new
143
- if modulation_name != "Unknown" and modulation_name not in MODULATION_MAP:
144
- modulation_id = len(MODULATION_MAP)
145
- MODULATION_MAP[modulation_name] = modulation_id
146
- MODULATION_INV[modulation_id] = modulation_name
147
- elif modulation_name in MODULATION_MAP:
148
- modulation_id = MODULATION_MAP[modulation_name]
149
- else:
150
- modulation_id = -1 # Unknown
151
-
152
- return {
153
- 'snr_db': snr_db,
154
- 'doppler_id': doppler_id,
155
- 'modulation_id': modulation_id,
156
- 'modulation_name': modulation_name
157
- }
158
-
159
-
160
warnings.filterwarnings("ignore", category=UserWarning)

# Use simple progress display instead of tqdm on Windows
USE_TQDM = platform.system() != 'Windows'

# Worker-count heuristic: kept conservative (half the cores, capped at 8)
# because each worker holds full spectrogram arrays in memory.
total_cores = multiprocessing.cpu_count()
if total_cores >= 16:
    MAX_WORKERS = min(8, total_cores // 2)
else:
    MAX_WORKERS = max(2, total_cores // 2)
print(f"🚀 Using {MAX_WORKERS}/{total_cores} CPU cores for parallel processing")

# Opt-in verbose logging of complex->interleaved conversion statistics,
# controlled via the LWM_PRINT_CONVERSION_STATS environment variable.
PRINT_CONVERSION_STATS = os.environ.get("LWM_PRINT_CONVERSION_STATS", "").strip().lower() in {"1", "true", "yes"}
174
-
175
-
176
# =============================================================================
# 2. CONTRASTIVE LEARNING COMPONENTS
# =============================================================================

class ProjectionHead(nn.Module):
    """
    SimCLR-style projection head for contrastive learning:
    mean-pool the token sequence, apply a 2-layer MLP, L2-normalize.
    """
    def __init__(self, d_model: int, projection_dim: int = 128):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, projection_dim)
        )

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model) encoder output.
        Returns:
            (batch, projection_dim) L2-normalized projected embeddings.
        """
        pooled = x.mean(dim=1)  # collapse the sequence dimension
        return F.normalize(self.projection(pooled), dim=1)


class ContrastiveLWM(nn.Module):
    """
    LWM encoder wrapped with an MLM reconstruction head and two independent
    contrastive projection heads (modulation and mobility).
    """
    def __init__(self, lwm_encoder, projection_dim: int = 128, input_dim: int = 32):
        super().__init__()
        self.encoder = lwm_encoder
        # MLM reconstruction head: map d_model features back to input patches.
        self.mlm_head = nn.Linear(lwm_encoder.d_model, input_dim)
        # Separate projection spaces for the two contrastive tasks.
        self.modulation_projection = ProjectionHead(lwm_encoder.d_model, projection_dim)
        self.mobility_projection = ProjectionHead(lwm_encoder.d_model, projection_dim)

    def forward(self, x, return_projections: bool = False):
        """
        Run the encoder and the MLM head; optionally also compute the
        contrastive projections.

        Returns:
            `mlm_predictions` alone, or the tuple
            `(mlm_predictions, z_modulation, z_mobility)` when
            `return_projections` is True.
        """
        hidden = self.encoder(x)               # (batch, seq_len, d_model)
        reconstructed = self.mlm_head(hidden)  # (batch, seq_len, input_dim)
        if not return_projections:
            return reconstructed
        return (
            reconstructed,
            self.modulation_projection(hidden),
            self.mobility_projection(hidden),
        )
246
-
247
-
248
def supervised_contrastive_loss(
    embeddings: torch.Tensor,
    labels: torch.Tensor,
    temperature: float = 0.07,
    base_temperature: float = 0.07
) -> torch.Tensor:
    """
    Supervised Contrastive Loss (SupCon) from Khosla et al. 2020.

    Args:
        embeddings: (batch, dim) L2-normalized embeddings.
        labels: (batch,) integer class labels.
        temperature: Temperature scaling for the similarity logits.
        base_temperature: Reference temperature used to rescale the loss.

    Returns:
        Scalar SupCon loss (mean over the batch).
    """
    n = embeddings.size(0)
    device = embeddings.device

    # Pairwise scaled similarities (cosine, since embeddings are pre-normalized).
    logits = torch.matmul(embeddings, embeddings.T) / temperature  # (n, n)

    # Positives are pairs sharing a label; the diagonal (self-pairs) is excluded.
    label_col = labels.contiguous().view(-1, 1)
    same_label = torch.eq(label_col, label_col.T).float().to(device)
    off_diagonal = torch.scatter(
        torch.ones_like(same_label),
        1,
        torch.arange(n).view(-1, 1).to(device),
        0
    )
    positives = same_label * off_diagonal

    # log p(j | i) over all non-self candidates j (small epsilon guards log(0)).
    denom = (torch.exp(logits) * off_diagonal).sum(dim=1, keepdim=True)
    log_prob = logits - torch.log(denom + 1e-8)

    # Average log-probability over each anchor's positives (epsilon guards
    # anchors with no positive in the batch).
    mean_log_prob_pos = (positives * log_prob).sum(dim=1) / (positives.sum(dim=1) + 1e-8)

    return (-(temperature / base_temperature) * mean_log_prob_pos).mean()
296
-
297
-
298
class ContrastiveDataset(Dataset):
    """
    Dataset wrapper that pairs each spectrogram sample with the per-sample
    metadata needed for supervised contrastive learning.

    The `indices_by_modulation` / `indices_by_mobility` lookup tables are kept
    on the instance for samplers that draw positives/negatives by class.
    """
    def __init__(
        self,
        spectrograms: np.ndarray,
        labels: np.ndarray,
        metadata: Dict[str, np.ndarray],
        indices_by_modulation: Dict[int, List[int]],
        indices_by_mobility: Dict[int, List[int]]
    ):
        self.spectrograms = spectrograms
        self.labels = labels
        self.metadata = metadata
        self.indices_by_modulation = indices_by_modulation
        self.indices_by_mobility = indices_by_mobility

    def __len__(self):
        return len(self.spectrograms)

    def __getitem__(self, idx):
        """Return (spectrogram, label, metadata-dict) for the anchor at idx."""
        sample_meta = {
            key: self.metadata[key][idx]
            for key in ('snr_db', 'doppler_id', 'modulation_id')
        }
        return self.spectrograms[idx], self.labels[idx], sample_meta
333
-
334
-
335
# =============================================================================
# 3. DATA CONVERSION AND PREPROCESSING
# =============================================================================

def convert_complex_to_interleaved(spectrograms):
    """
    Convert complex-valued spectrograms to real-imaginary interleaved format.

    Args:
        spectrograms (np.ndarray): Complex-valued array of shape
            (n_samples, n_rows, n_cols) or (n_samples, 1, n_rows, n_cols).

    Returns:
        np.ndarray: Real-valued float32 array with real/imag parts interleaved
        along the last axis, shape (n_samples, n_rows, n_cols * 2).
        Already real-valued 3-D input is passed through unchanged.
    """
    # Drop the singleton channel axis of 4-D inputs.
    if spectrograms.ndim == 4:
        spectrograms = spectrograms[:, 0, :, :]

    if not np.iscomplexobj(spectrograms):
        # Already real-valued: only 3-D data is accepted as-is.
        if spectrograms.ndim == 3:
            if PRINT_CONVERSION_STATS:
                print(f" ℹ️ Data is already real-valued: {spectrograms.shape}")
            return spectrograms
        raise ValueError(f"Unexpected spectrogram shape: {spectrograms.shape}")

    n_samples, n_rows, n_cols = spectrograms.shape
    flat_real = spectrograms.real
    flat_imag = spectrograms.imag

    # Even columns carry real parts, odd columns imaginary parts.
    interleaved = np.empty((n_samples, n_rows, n_cols * 2), dtype=np.float32)
    interleaved[:, :, 0::2] = flat_real
    interleaved[:, :, 1::2] = flat_imag

    if PRINT_CONVERSION_STATS:
        print(f" ℹ️ Converted complex spectrograms: {spectrograms.shape} -> {interleaved.shape}")
        print(f" Real part range: [{flat_real.min():.2e}, {flat_real.max():.2e}]")
        print(f" Imag part range: [{flat_imag.min():.2e}, {flat_imag.max():.2e}]")

    return interleaved
382
-
383
-
384
def process_single_scenario(scenario_info):
    """Process a single scenario (worker function for multiprocessing).

    Args:
        scenario_info: Tuple of (scenario_name, spectrogram_path).

    Returns:
        dict | None: Train/val data splits plus per-sample metadata arrays,
        or None when loading fails or the scenario yields no samples.
    """
    scenario_name, spectrogram_path = scenario_info

    try:
        # Parse SNR/Doppler/modulation metadata encoded in the file path.
        path_metadata = _parse_metadata(spectrogram_path)

        # Load only the needed data for memory efficiency.
        scenario_spectrograms, scenario_labels = generate_spectrograms_and_labels(
            scenario_name=scenario_name,
            spectrogram_path=spectrogram_path,
            cache_path=None,  # cache disabled due to memory pressure
        )

        # Validate load before further processing.
        if scenario_spectrograms is None or (hasattr(scenario_spectrograms, 'size') and scenario_spectrograms.size == 0):
            print(f" ⚠️ No data loaded from: {spectrogram_path}")
            return None

        # Convert complex spectrograms to interleaved real-imaginary format.
        scenario_spectrograms = convert_complex_to_interleaved(scenario_spectrograms)

        # 80/20 train/val split along the sample axis (indices only).
        total_samples = len(scenario_spectrograms)
        train_size = int(0.8 * total_samples)

        # Keep numpy float32 arrays to minimize memory overhead.
        train_data = np.array(scenario_spectrograms[:train_size], dtype=np.float32)
        val_data = np.array(scenario_spectrograms[train_size:], dtype=np.float32)

        # Broadcast the path-derived metadata to every sample.
        snr_array = np.full(total_samples, path_metadata['snr_db'], dtype=np.float32)
        doppler_array = np.full(total_samples, path_metadata['doppler_id'], dtype=np.int64)
        modulation_array = np.full(total_samples, path_metadata['modulation_id'], dtype=np.int64)

        train_meta = {
            'snr_db': snr_array[:train_size],
            'doppler_id': doppler_array[:train_size],
            'modulation_id': modulation_array[:train_size],
        }
        val_meta = {
            'snr_db': snr_array[train_size:],
            'doppler_id': doppler_array[train_size:],
            'modulation_id': modulation_array[train_size:],
        }

        # Drop the large source array as soon as it is no longer needed.
        del scenario_spectrograms

        return {
            'scenario': scenario_name,
            'train_data': train_data,
            'val_data': val_data,
            'train_meta': train_meta,
            'val_meta': val_meta,
            'train_size': len(train_data),
            'val_size': len(val_data)
        }
    except Exception as e:
        print(f"❌ Error processing scenario {scenario_name}: {e}")
        import traceback
        traceback.print_exc()
        return None
449
-
450
-
451
- # =============================================================================
452
- # 4. SCENARIO LIST AND PROPERTIES (Same as original)
453
- # =============================================================================
454
-
455
# Communication standards recognized by the pipeline; used both for CLI
# filtering (_parse_standard_args) and for locating scenario tokens in paths.
SUPPORTED_COMM_TYPES = {"LTE", "WiFi", "5G"}
456
-
457
-
458
def _parse_standard_args():
    """Parse CLI options selecting comm standards, cities, and normalization.

    Consumes the recognized flags from ``sys.argv`` (leaving the remainder in
    place for any downstream parser) and returns a tuple of
    (enabled_comm_types, selected_city_prefixes_or_None, normalization_mode).
    """
    parser = argparse.ArgumentParser(add_help=False)
    # Sorted so help text and "invalid choice" errors are deterministic
    # (iterating a set directly yields an arbitrary order).
    parser.add_argument('--standards', nargs='+', choices=sorted(SUPPORTED_COMM_TYPES),
                        help='Specify one or more communication types to include (default: all).')
    for comm in SUPPORTED_COMM_TYPES:
        parser.add_argument(f'--{comm}', dest=f'flag_{comm}', action='store_true',
                            help=f'Include only {comm} data (can be combined).')
    parser.add_argument('--city', '--cities', dest='cities', nargs='+',
                        help='Limit scenarios to one or more city prefixes (e.g., "0" or "city_0").')
    parser.add_argument(
        '--normalization',
        choices=('per_sample', 'dataset'),
        default='per_sample',
        help='Normalization mode applied during tokenization (default: %(default)s).'
    )
    parser.add_argument('--help', action='help')

    args, remaining = parser.parse_known_args()

    # --standards wins; otherwise individual --LTE/--WiFi/--5G flags narrow the set.
    enabled = set(SUPPORTED_COMM_TYPES)
    if args.standards:
        enabled = set(args.standards)
    else:
        flagged = {comm for comm in SUPPORTED_COMM_TYPES if getattr(args, f'flag_{comm}', False)}
        if flagged:
            enabled = flagged

    # Normalize city tokens to the canonical "city_<id>" form.
    selected_cities: list[str] | None = None
    if args.cities:
        selected_cities = []
        for city_token in args.cities:
            token = str(city_token).strip()
            if not token:
                continue
            if token.startswith('city_'):
                selected_cities.append(token)
            else:
                selected_cities.append(f'city_{token}')
        if not selected_cities:
            selected_cities = None

    # Hand unrecognized arguments back to any later parser.
    sys.argv = [sys.argv[0]] + remaining
    return enabled, selected_cities, args.normalization
501
-
502
-
503
# Resolve CLI selections once at import time; scenario discovery below reads
# these module-level settings.
ENABLED_COMM_TYPES, ENABLED_CITY_PREFIXES, NORMALIZATION_MODE = _parse_standard_args()
# Optional cap on the number of scenario files; "0"/unset means no limit.
MAX_SCENARIOS = int(os.environ.get("LWM_MAX_SCENARIOS", "0")) or None
505
-
506
-
507
def _extract_scenario_token(file_path):
    """Derive the base scenario token (without city) from the file path."""
    segments = os.path.normpath(file_path).split(os.sep)

    # Primary strategy: locate the comm-type directory inside the path and
    # take it plus up to four following segments.
    token_parts = []
    for index, segment in enumerate(segments):
        if segment in SUPPORTED_COMM_TYPES:
            candidate = segments[index:index + 5]
            if candidate:
                token_parts = candidate[:5]
            break

    # Fallback: parse a "spectrogram_<COMM>_..." file name directly.
    if not token_parts:
        stem = os.path.splitext(os.path.basename(file_path))[0]
        if stem.startswith('spectrogram_'):
            pieces = stem.split('_')[1:]
            if pieces and pieces[0] in SUPPORTED_COMM_TYPES:
                token_parts = pieces[:5] if len(pieces) >= 5 else pieces

    if not token_parts:
        return None
    return '_'.join(token_parts)
528
-
529
-
530
@lru_cache(maxsize=1)
def _collect_scenario_file_info():
    """Discover spectrogram files and return (scenario_id, path, city, token) tuples.

    Cached (maxsize=1) because discovery walks the filesystem and the result is
    reused by both scenarios_list() and scenario_prop().
    """
    # Single local import (the original imported glob twice, once aliased).
    import glob

    scenario_entries = []

    # New MATLAB receiver pipeline output layout.
    new_base = os.path.join('ls_data', 'MATLAB', 'receiver_pipeline')
    if os.path.isdir(new_base):
        patterns = [os.path.join(new_base, '*', '**', 'spectrogram_*.mat')]
        for pattern in patterns:
            for file_path in sorted(glob.glob(pattern, recursive=True)):
                norm = os.path.normpath(file_path)
                parts = norm.split(os.sep)
                try:
                    # City name is the directory right below receiver_pipeline.
                    idx = parts.index('receiver_pipeline')
                    city_name = parts[idx + 1] if idx + 1 < len(parts) else 'receiver_pipeline'
                except ValueError:
                    city_name = 'receiver_pipeline'

                base_token = _extract_scenario_token(file_path)
                if not base_token:
                    continue
                comm_type = base_token.split('_', 1)[0]
                if comm_type not in ENABLED_COMM_TYPES:
                    continue
                scenario_id = f"{city_name}::{base_token}"
                scenario_entries.append((scenario_id, file_path, city_name, base_token))

    # Legacy repo layouts under spectrograms/city_*.
    for city_dir in sorted(glob.glob(os.path.join('spectrograms', 'city_*'))):
        if not os.path.isdir(city_dir):
            continue
        city_name = os.path.basename(city_dir)
        if ENABLED_CITY_PREFIXES:
            if not any(city_name.startswith(prefix) for prefix in ENABLED_CITY_PREFIXES):
                continue
        # Prefer raw complex .mat files; dedupe across overlapping patterns.
        candidate_patterns = [
            os.path.join(city_dir, '**', 'complex_raw', '**', 'spectrogram_*.mat'),
            os.path.join(city_dir, '**', 'spectrogram_*.mat'),
        ]
        city_files = []
        seen_paths = set()
        for pattern in candidate_patterns:
            for file_path in sorted(glob.glob(pattern, recursive=True)):
                if not file_path.lower().endswith('.mat'):
                    continue
                if file_path in seen_paths:
                    continue
                seen_paths.add(file_path)
                city_files.append(file_path)

        if not city_files:
            # Fall back to pre-rendered pickle spectrograms.
            pattern = os.path.join(city_dir, '**', '512FFT', '**', 'spectrograms', '*.pkl')
            city_files = sorted(glob.glob(pattern, recursive=True))

        for file_path in city_files:
            base_token = _extract_scenario_token(file_path)
            if not base_token:
                continue
            comm_type = base_token.split('_', 1)[0]
            if comm_type not in ENABLED_COMM_TYPES:
                continue
            scenario_id = f"{city_name}::{base_token}"
            scenario_entries.append((scenario_id, file_path, city_name, base_token))

    if MAX_SCENARIOS:
        scenario_entries = scenario_entries[:MAX_SCENARIOS]

    return scenario_entries
601
-
602
-
603
def scenarios_list():
    """Return the numpy array of scenario identifiers discovered on disk."""
    entries = _collect_scenario_file_info()

    if not entries:
        print("⚠️ No spectrogram files found for pretraining.")
        return np.array([])

    print(f"Enabled communication types: {sorted(ENABLED_COMM_TYPES)}")
    if ENABLED_CITY_PREFIXES:
        print(f"Selected city prefixes: {sorted(ENABLED_CITY_PREFIXES)}")

    # Summarize how many files each city contributes.
    per_city = Counter(city for _, _, city, _ in entries)
    print("Using scenarios from the following city datasets:")
    for city_name, count in per_city.items():
        print(f" - {city_name}: {count} files")

    print(f"Total scenarios selected: {len(entries)}")
    scenario_ids = [scenario_id for scenario_id, _, _, _ in entries]
    return np.array(scenario_ids)
620
-
621
-
622
def scenario_prop():
    """Map each scenario id to its spectrogram path and cache-file location."""
    return {
        scenario_id: {
            'spectrogram_path': file_path,
            'cache_path': os.path.join('spectrograms', city_name, 'spectrogram_cache_128x128.pkl')
        }
        for scenario_id, file_path, city_name, _ in _collect_scenario_file_info()
    }
633
-
634
-
635
- # =============================================================================
636
- # 5. TRAINING PARAMETERS AND HYPERPARAMETERS
637
- # =============================================================================
638
-
639
# Core optimization schedule.
EPOCHS = 20
BATCH_SIZE = 64
VAL_BATCH_SIZE = 64
WARMUP_EPOCHS = 5
BASE_LR = 5e-4
MIN_LR = 1e-5  # 1/50 of BASE_LR (was 1e-8, too small for effective learning)

# Gradient accumulation for larger effective batch size
ACCUMULATION_STEPS = 4  # Effective batch size = 64 × 4 = 256

# Model parameters
N_ROWS = 4
N_COLUMNS = 4
ELEMENT_LENGTH = N_ROWS * N_COLUMNS * 2  # Complex spectrograms
D_MODEL = 128
MAX_LEN = 1025
N_LAYERS = 12
device_idx = 0
WEIGHT_DECAY = 0.05
BETA1 = 0.9
BETA2 = 0.999
MASK_PERCENT = 0.6
N_HEADS = 8
DROPOUT = 0.1

# Contrastive learning parameters
PROJECTION_DIM = 128
CONTRASTIVE_TEMPERATURE = 0.07
CONTRASTIVE_WEIGHT_MODULATION = 50.0  # Increased from 0.5 to match MLM loss scale
CONTRASTIVE_WEIGHT_MOBILITY = 30.0  # Increased from 0.3 to match MLM loss scale
MLM_WEIGHT = 1.0

# Echo the effective configuration at import time for reproducibility.
print(f"📊 Model configuration for complex spectrograms with contrastive learning:")
print(f" Patch size: {N_ROWS}x{N_COLUMNS}")
print(f" Element length: {ELEMENT_LENGTH} (includes real+imag interleaving)")
print(f" Max sequence length: {MAX_LEN}")
print(f" Batch size: {BATCH_SIZE} (physical), {BATCH_SIZE * ACCUMULATION_STEPS} (effective)")
print(f" Gradient accumulation steps: {ACCUMULATION_STEPS}")
print(f" Projection dim: {PROJECTION_DIM}")
print(f" Contrastive temperature: {CONTRASTIVE_TEMPERATURE}")
print(f" Loss weights - MLM: {MLM_WEIGHT}, Modulation: {CONTRASTIVE_WEIGHT_MODULATION}, Mobility: {CONTRASTIVE_WEIGHT_MOBILITY}")
680
-
681
-
682
- # =============================================================================
683
- # 6. DATA GENERATION AND LOADING
684
- # =============================================================================
685
-
686
# Discover scenarios and their file-level properties.
scenarios = scenarios_list()
scenario_properties = scenario_prop()

# Per-scenario chunks, concatenated into single arrays after loading.
train_spectrogram_chunks = []
val_spectrogram_chunks = []
train_label_chunks = []
val_label_chunks = []
train_meta_chunks = []
val_meta_chunks = []

print(f"📂 Loading {len(scenarios)} scenarios...")

# Pair each scenario with its spectrogram path; skip ones without metadata.
scenario_info_list = []
missing_props = []
for scenario in scenarios:
    props = scenario_properties.get(scenario)
    if props is None:
        missing_props.append(scenario)
        continue
    scenario_info_list.append((scenario, props["spectrogram_path"]))

if missing_props:
    print("⚠️ Missing metadata for the following scenarios; skipping:")
    for scen in missing_props:
        print(f" - {scen}")

print(f"📂 Loading {len(scenario_info_list)} scenarios using {MAX_WORKERS} workers...")

successful_scenarios = 0

# Parallel processing with progress bar
from multiprocessing import Pool
with Pool(processes=MAX_WORKERS) as pool:
    # imap preserves scenario order while streaming results to tqdm.
    results = list(tqdm(
        pool.imap(process_single_scenario, scenario_info_list),
        total=len(scenario_info_list),
        desc="Processing scenarios",
        unit="scenario"
    ))

# Collect successful results; labels are placeholder zeros (self-supervised).
for result in results:
    if result is not None:
        train_spectrogram_chunks.append(result['train_data'])
        val_spectrogram_chunks.append(result['val_data'])
        train_label_chunks.append(np.zeros(result['train_size'], dtype=np.int64))
        val_label_chunks.append(np.zeros(result['val_size'], dtype=np.int64))
        train_meta_chunks.append(result['train_meta'])
        val_meta_chunks.append(result['val_meta'])
        successful_scenarios += 1

print(f"✅ Processing completed! Successful scenarios: {successful_scenarios}/{len(scenario_info_list)}")

if not train_spectrogram_chunks or not val_spectrogram_chunks:
    raise ValueError("No spectrogram data collected; check scenario configuration.")

print("🔄 Collating spectrogram arrays...")
# copy=False avoids a second allocation when the dtype is already float32.
train_spectrograms = np.concatenate(train_spectrogram_chunks, axis=0).astype(np.float32, copy=False)
val_spectrograms = np.concatenate(val_spectrogram_chunks, axis=0).astype(np.float32, copy=False)
train_labels = np.concatenate(train_label_chunks, axis=0)
val_labels = np.concatenate(val_label_chunks, axis=0)
746
-
747
-
748
- def _concat_metadata_dicts(dict_list):
749
- if not dict_list:
750
- return {}
751
- keys = dict_list[0].keys()
752
- return {k: np.concatenate([d[k] for d in dict_list], axis=0) for k in keys}
753
-
754
-
755
# Merge per-scenario metadata into single arrays aligned with the samples.
train_metadata = _concat_metadata_dicts(train_meta_chunks)
val_metadata = _concat_metadata_dicts(val_meta_chunks)

# Free the chunk lists immediately to reduce peak memory.
del train_spectrogram_chunks, val_spectrogram_chunks, train_label_chunks, val_label_chunks
del train_meta_chunks, val_meta_chunks

print(f"Training spectrograms shape: {train_spectrograms.shape}")
print(f"Validation spectrograms shape: {val_spectrograms.shape}")
print(f"Memory usage: {train_spectrograms.nbytes + val_spectrograms.nbytes:,} bytes")

# Print metadata statistics
print(f"\n📊 Metadata statistics:")
print(f" Discovered modulation schemes: {len(MODULATION_MAP)}")
for mod_name, mod_id in sorted(MODULATION_MAP.items(), key=lambda x: x[1]):
    count_train = np.sum(train_metadata['modulation_id'] == mod_id)
    count_val = np.sum(val_metadata['modulation_id'] == mod_id)
    print(f" {mod_name} (ID={mod_id}): {count_train} train, {count_val} val samples")

print(f"\n Modulation distribution (train):")
for mod_id in np.unique(train_metadata['modulation_id']):
    count = np.sum(train_metadata['modulation_id'] == mod_id)
    mod_name = MODULATION_INV.get(mod_id, f"Unknown({mod_id})")
    print(f" {mod_name}: {count} samples ({100*count/len(train_metadata['modulation_id']):.1f}%)")

print(f" Mobility distribution (train):")
for mob_id in np.unique(train_metadata['doppler_id']):
    count = np.sum(train_metadata['doppler_id'] == mob_id)
    mob_name = DOPPLER_INV.get(mob_id, f"Unknown({mob_id})")
    print(f" {mob_name}: {count} samples ({100*count/len(train_metadata['doppler_id']):.1f}%)")

# Dataset-level normalization stats (used when --normalization dataset).
train_mean = float(train_spectrograms.mean())
train_std = float(train_spectrograms.std())
if abs(train_std) < 1e-6:
    print("⚠️ Training std near zero, using epsilon for stability")
    train_std = 1e-6
dataset_normalization = {'mean': train_mean, 'std': train_std, 'normalization': NORMALIZATION_MODE}
print(f"Dataset normalization stats -> mean: {train_mean:.4f}, std: {train_std:.4f}")
792
-
793
-
794
- # =============================================================================
795
- # 7. BUILD INDEX FOR CONTRASTIVE SAMPLING
796
- # =============================================================================
797
-
798
def build_class_indices(metadata: Dict[str, np.ndarray]) -> Tuple[Dict, Dict]:
    """Build index mappings from class IDs to sample indices.

    Args:
        metadata: Dict with equal-length integer arrays under
            'modulation_id' and 'doppler_id' (one entry per sample).

    Returns:
        Tuple of two plain dicts:
        (modulation_id -> list of sample indices,
         doppler_id -> list of sample indices).
    """
    indices_by_modulation: Dict[int, list] = {}
    indices_by_mobility: Dict[int, list] = {}

    # Single pass over paired labels; setdefault avoids the manual
    # "if key not in dict" dance of the original implementation.
    for idx, (mod_id, mob_id) in enumerate(zip(metadata['modulation_id'], metadata['doppler_id'])):
        indices_by_modulation.setdefault(int(mod_id), []).append(idx)
        indices_by_mobility.setdefault(int(mob_id), []).append(idx)

    return indices_by_modulation, indices_by_mobility
818
-
819
-
820
# Precompute class -> sample-index maps used for contrastive pair sampling.
print("🔍 Building class indices for contrastive learning...")
train_indices_by_modulation, train_indices_by_mobility = build_class_indices(train_metadata)
val_indices_by_modulation, val_indices_by_mobility = build_class_indices(val_metadata)
print("✅ Class indices built successfully!")
824
-
825
-
826
- # =============================================================================
827
- # 8. DATA TOKENIZATION
828
- # =============================================================================
829
-
830
# Tokenize into patch sequences with MLM-style masking; metadata rides along
# so each token batch carries SNR/Doppler/modulation labels.
print("🔄 Starting tokenization of training data...")
preprocessed_train = tokenizer_train(
    train_spectrograms,
    max_len=MAX_LEN,
    masking_percent=MASK_PERCENT,
    mask=True,
    seed=42,  # fixed seed keeps the mask pattern reproducible
    metadata=train_metadata,
    dataset_stats=dataset_normalization,
    normalization=NORMALIZATION_MODE,
    interleaved=True,
)
print("✅ Training data tokenization completed!")

print("🔄 Starting tokenization of validation data...")
preprocessed_val = tokenizer_train(
    val_spectrograms,
    max_len=MAX_LEN,
    masking_percent=MASK_PERCENT,
    mask=True,
    seed=42,
    metadata=val_metadata,
    dataset_stats=dataset_normalization,
    normalization=NORMALIZATION_MODE,
    interleaved=True,
)
print("✅ Validation data tokenization completed!")
857
-
858
-
859
- # =============================================================================
860
- # 9. TRAIN/VALIDATION DATA SETUP
861
- # =============================================================================
862
-
863
# Seed torch and numpy for reproducible shuffling/initialization.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Aliases used by the dataloader-creation section below.
train_data = preprocessed_train
val_data = preprocessed_val
869
-
870
-
871
- # =============================================================================
872
- # 10. DATALOADER CREATION
873
- # =============================================================================
874
-
875
print("🔧 Creating data loaders...")

# Tokenizer may return either a dict keyed by sequence length or one tensor;
# either way, train_loaders/val_loaders end up as dicts of DataLoaders.
if isinstance(train_data, dict):
    print(f" Training data format: dict with {len(train_data)} sequence lengths")
    train_loaders = create_train_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True)
else:
    print(f" Training data format: tensor with shape {train_data.shape}")
    train_dataset = TensorDataset(train_data)
    train_loaders = {'seq_0': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)}

if isinstance(val_data, dict):
    print(f" Validation data format: dict with {len(val_data)} sequence lengths")
    val_loaders = create_train_dataloader(val_data, batch_size=VAL_BATCH_SIZE, shuffle=False)
else:
    print(f" Validation data format: tensor with shape {val_data.shape}")
    val_dataset = TensorDataset(val_data)
    val_loaders = {'seq_0': DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)}

print("✅ Data loaders created successfully!")
894
-
895
-
896
- # =============================================================================
897
- # 11. MODEL INITIALIZATION
898
- # =============================================================================
899
-
900
print("🔧 Setting up device and GPU configuration...")

# Device priority: CUDA (all visible GPUs) > Apple MPS > CPU.
if torch.cuda.is_available():
    device_count = torch.cuda.device_count()
    print(f" CUDA available: {device_count} GPU(s) detected")
    device = torch.device("cuda:0")
    gpu_ids = list(range(device_count))
    print(f" Using CUDA GPUs: {gpu_ids}")

    for i in gpu_ids:
        try:
            mem_total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            mem_allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f" GPU {i}: Total: {mem_total:.1f}GB, Allocated: {mem_allocated:.1f}GB")
        except Exception as e:
            # Memory-info failure is non-fatal; just report it.
            print(f" GPU {i}: Error getting memory info - {e}")

elif torch.backends.mps.is_available():
    device = torch.device("mps")
    gpu_ids = []
    print(" Using MPS (Apple Silicon GPU)")
else:
    device = torch.device("cpu")
    gpu_ids = []
    print(" Using CPU")

print(f" Final device: {device}")
print(f" GPU IDs for DataParallel: {gpu_ids}")

print("🤖 Initializing LWM model with contrastive learning...")
print(f" Model parameters: element_length={ELEMENT_LENGTH}, d_model={D_MODEL}, n_layers={N_LAYERS}, max_len={MAX_LEN}, n_heads={N_HEADS}")

try:
    # Create base LWM encoder
    lwm_encoder = pretrained_model.lwm(
        element_length=ELEMENT_LENGTH,
        d_model=D_MODEL,
        n_layers=N_LAYERS,
        max_len=MAX_LEN,
        n_heads=N_HEADS,
        dropout=DROPOUT
    )

    # Wrap with contrastive learning module
    # MLM head must output patch dimension (ELEMENT_LENGTH), not full spectrogram width
    # Each token represents a 4×4×2 patch = 32 elements
    model = ContrastiveLWM(lwm_encoder, projection_dim=PROJECTION_DIM, input_dim=ELEMENT_LENGTH)
    print(f" ✅ Model created with input_dim={ELEMENT_LENGTH} (patch dimension)")

    print(f" Moving model to device: {device}")
    if 'mps' in str(device):
        # MPS requires float32 tensors.
        model = model.to(device).float()
        print(" ✅ Model moved to MPS device (float32)")
    else:
        model = model.to(device)
        print(" ✅ Model moved to device successfully")

except Exception as e:
    # Model setup is unrecoverable at module level; abort the run.
    print(f" ❌ Model initialization failed: {e}")
    import traceback
    traceback.print_exc()
    exit(1)

# Use DataParallel for multi-GPU support
if gpu_ids:
    model = nn.DataParallel(model, device_ids=gpu_ids)
    print(f"Model loaded successfully on GPU {device.index}")
else:
    print(f"Model loaded successfully on {device}")

n_parameters = count_parameters(model)
print(f"Number of trainable parameters: {n_parameters:,}")
972
-
973
-
974
- # =============================================================================
975
- # 12. OPTIMIZER AND LEARNING RATE SCHEDULER
976
- # =============================================================================
977
-
978
# Account for gradient accumulation: scheduler step is called once per ACCUMULATION_STEPS batches
# So actual optimizer steps = total_batches / ACCUMULATION_STEPS
total_batches_per_epoch = sum(len(loader) for loader in train_loaders.values())
actual_steps_per_epoch = math.ceil(total_batches_per_epoch / ACCUMULATION_STEPS)
TOTAL_STEPS = actual_steps_per_epoch * EPOCHS
WARMUP_STEPS = actual_steps_per_epoch * WARMUP_EPOCHS

print(f"📊 Learning rate schedule:")
print(f" Total batches per epoch: {total_batches_per_epoch}")
print(f" Accumulation steps: {ACCUMULATION_STEPS}")
print(f" Actual optimizer steps per epoch: {actual_steps_per_epoch}")
print(f" Total training steps: {TOTAL_STEPS}")
print(f" Warmup steps: {WARMUP_STEPS}")

# AdamW with decoupled weight decay; hyperparameters from section 5.
optimizer = AdamW(
    model.parameters(),
    lr=BASE_LR,
    betas=(BETA1, BETA2),
    weight_decay=WEIGHT_DECAY
)
998
-
999
-
1000
def lr_lambda(current_step):
    """LR multiplier: linear warmup, then cosine decay from BASE_LR to MIN_LR.

    Returns a factor in [MIN_LR/BASE_LR, 1] that LambdaLR multiplies into
    BASE_LR at each scheduler step.
    """
    if current_step < WARMUP_STEPS:
        # Guard against WARMUP_STEPS == 0 (e.g. WARMUP_EPOCHS set to 0),
        # which would otherwise divide by zero.
        return current_step / max(WARMUP_STEPS, 1)
    # Progress through the decay phase, clamped denominator for the
    # degenerate TOTAL_STEPS == WARMUP_STEPS case.
    scaled_progress = (current_step - WARMUP_STEPS) / max(TOTAL_STEPS - WARMUP_STEPS, 1)
    cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
    # Rescale so the factor spans [MIN_LR/BASE_LR, 1].
    return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR
1007
-
1008
-
1009
# LambdaLR multiplies BASE_LR by lr_lambda(step) on every scheduler.step().
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
1010
-
1011
-
1012
- # =============================================================================
1013
- # 13. TRAINING LOOP WITH CONTRASTIVE LEARNING
1014
- # =============================================================================
1015
-
1016
def train_epoch_contrastive(
    model,
    train_loaders,
    optimizer,
    scheduler,
    device,
    epoch,
    train_metadata
):
    """
    Train one epoch with MLM + Contrastive Learning with Gradient Accumulation.

    Args:
        model: ContrastiveLWM (possibly DataParallel-wrapped); calling it with
            return_projections=True yields (mlm_predictions, z_mod, z_mob).
        train_loaders: dict mapping sequence-length key -> DataLoader.
        optimizer: stepped once every ACCUMULATION_STEPS batches.
        scheduler: stepped in lockstep with the optimizer.
        device: torch device all batch tensors are moved to.
        epoch: zero-based epoch index (used for tqdm label and one-off debug prints).
        train_metadata: currently unused inside this function.

    Returns:
        dict with batch-averaged 'mlm_loss', 'contrastive_mod_loss',
        'contrastive_mob_loss', and 'total_loss'.
    """
    model.train()
    total_mlm_loss = 0.0
    total_contrastive_mod_loss = 0.0
    total_contrastive_mob_loss = 0.0
    total_loss = 0.0
    total_batches = 0

    # Sum reduction: per-sample losses are averaged over the batch manually.
    criterion = nn.MSELoss(reduction='sum')

    # Initialize gradient accumulation
    optimizer.zero_grad()
    accumulation_counter = 0

    for seq_key, loader in train_loaders.items():
        for batch_idx, batch in enumerate(tqdm(loader, desc=f"Epoch {epoch+1}", leave=False)):
            # Unpack batch - expect (input_ids, masked_tokens, masked_pos, snr_db, doppler_id, power_stats, snr_id, modulation_id)
            if len(batch) >= 8:
                input_ids = batch[0].to(device)
                masked_tokens = batch[1].to(device)
                masked_pos = batch[2].to(device)
                snr_db = batch[3].to(device)
                doppler_id = batch[4].to(device)
                power_stats = batch[5].to(device)
                snr_id = batch[6].to(device)
                modulation_id = batch[7].to(device)
                has_metadata = True
            elif len(batch) == 3:
                # MLM-only batch without metadata labels.
                input_ids = batch[0].to(device)
                masked_tokens = batch[1].to(device)
                masked_pos = batch[2].to(device)
                has_metadata = False
            else:
                input_ids = batch[0].to(device)
                has_metadata = False

            # Forward pass with projections
            mlm_predictions, z_mod, z_mob = model(input_ids, return_projections=True)

            # MLM Loss (reconstruction)
            # NOTE: len(batch) >= 3 short-circuits before masked_tokens is
            # referenced, so the 1-tuple batch case is safe here.
            if len(batch) >= 3 and masked_tokens.numel() > 0:
                batch_size = input_ids.size(0)
                mlm_loss = 0.0

                for i in range(batch_size):
                    # Get masked positions for this sample
                    sample_masked_pos = masked_pos[i]
                    sample_masked_tokens = masked_tokens[i]

                    # Skip if no masked positions
                    if sample_masked_pos.numel() == 0:
                        continue

                    # Get predictions at masked positions
                    predictions = mlm_predictions[i, sample_masked_pos, :]
                    targets = sample_masked_tokens

                    # Ensure shapes match
                    if predictions.size(0) != targets.size(0):
                        # Adjust if needed
                        min_len = min(predictions.size(0), targets.size(0))
                        predictions = predictions[:min_len]
                        targets = targets[:min_len]

                    # MSE loss
                    mlm_loss += criterion(predictions, targets)

                # NOTE(review): if every sample in the batch has zero masked
                # positions, mlm_loss stays a Python float and the .item()
                # call below would raise — confirm the tokenizer guarantees
                # at least one masked position per sample.
                mlm_loss = mlm_loss / batch_size if batch_size > 0 else torch.tensor(0.0, device=device)
            else:
                mlm_loss = torch.zeros(1, device=device)

            # Contrastive losses (only if we have metadata)
            if has_metadata:
                # DEBUG: Print batch statistics
                if batch_idx == 0 and epoch == 0:  # Only first batch of first epoch
                    print(f"\n🔍 DEBUG - Batch analysis:")
                    print(f" Batch size: {modulation_id.size(0)}")
                    print(f" Modulation IDs: {modulation_id.cpu().numpy()}")
                    print(f" Unique modulations: {torch.unique(modulation_id).cpu().numpy()}")
                    print(f" Doppler IDs: {doppler_id.cpu().numpy()}")
                    print(f" Unique doppler: {torch.unique(doppler_id).cpu().numpy()}")

                # Modulation contrastive loss
                # Filter out unknown modulations (-1)
                valid_mod_mask = modulation_id >= 0
                if valid_mod_mask.sum() > 1:  # Need at least 2 samples
                    z_mod_valid = z_mod[valid_mod_mask]
                    mod_labels_valid = modulation_id[valid_mod_mask]

                    # Check if we have positive pairs
                    unique_mods, counts = torch.unique(mod_labels_valid, return_counts=True)
                    has_positive_pairs = (counts > 1).any()

                    if has_positive_pairs:
                        contrastive_mod_loss = supervised_contrastive_loss(
                            z_mod_valid,
                            mod_labels_valid,
                            temperature=CONTRASTIVE_TEMPERATURE
                        )
                        if batch_idx == 0 and epoch == 0:
                            print(f" Modulation contrastive loss: {contrastive_mod_loss.item():.4f}")
                    else:
                        contrastive_mod_loss = torch.zeros(1, device=device)
                        if batch_idx == 0 and epoch == 0:
                            print(f" No positive pairs for modulation - loss set to 0")
                else:
                    contrastive_mod_loss = torch.zeros(1, device=device)
                    if batch_idx == 0 and epoch == 0:
                        print(f" Not enough valid modulation samples - loss set to 0")

                # Mobility contrastive loss
                z_mob_valid = z_mob
                mob_labels_valid = doppler_id
                if mob_labels_valid.numel() > 1:
                    unique_mobs, counts = torch.unique(mob_labels_valid, return_counts=True)
                    has_positive_pairs = (counts > 1).any()

                    if has_positive_pairs:
                        contrastive_mob_loss = supervised_contrastive_loss(
                            z_mob_valid,
                            mob_labels_valid,
                            temperature=CONTRASTIVE_TEMPERATURE
                        )
                        if batch_idx == 0 and epoch == 0:
                            print(f" Mobility contrastive loss: {contrastive_mob_loss.item():.4f}")
                    else:
                        contrastive_mob_loss = torch.zeros(1, device=device)
                        if batch_idx == 0 and epoch == 0:
                            print(f" No positive pairs for mobility - loss set to 0")
                else:
                    contrastive_mob_loss = torch.zeros(1, device=device)
                    if batch_idx == 0 and epoch == 0:
                        print(f" Not enough mobility samples - loss set to 0")
            else:
                contrastive_mod_loss = torch.zeros(1, device=device)
                contrastive_mob_loss = torch.zeros(1, device=device)

            # Combined loss
            loss = (
                MLM_WEIGHT * mlm_loss +
                CONTRASTIVE_WEIGHT_MODULATION * contrastive_mod_loss +
                CONTRASTIVE_WEIGHT_MOBILITY * contrastive_mob_loss
            )

            # Normalize loss by accumulation steps
            loss = loss / ACCUMULATION_STEPS

            # Backward pass (accumulate gradients)
            loss.backward()

            # Accumulate losses (denormalized for logging)
            total_mlm_loss += mlm_loss.item()
            total_contrastive_mod_loss += contrastive_mod_loss.item()
            total_contrastive_mob_loss += contrastive_mob_loss.item()
            total_loss += (loss.item() * ACCUMULATION_STEPS)  # Denormalize for logging
            total_batches += 1
            accumulation_counter += 1

            # Perform optimizer step every ACCUMULATION_STEPS
            if accumulation_counter % ACCUMULATION_STEPS == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                accumulation_counter = 0

    # Handle remaining gradients if total batches not divisible by ACCUMULATION_STEPS
    if accumulation_counter > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    # Average losses
    avg_mlm_loss = total_mlm_loss / total_batches if total_batches > 0 else 0
    avg_contrastive_mod_loss = total_contrastive_mod_loss / total_batches if total_batches > 0 else 0
    avg_contrastive_mob_loss = total_contrastive_mob_loss / total_batches if total_batches > 0 else 0
    avg_total_loss = total_loss / total_batches if total_batches > 0 else 0

    return {
        'mlm_loss': avg_mlm_loss,
        'contrastive_mod_loss': avg_contrastive_mod_loss,
        'contrastive_mob_loss': avg_contrastive_mob_loss,
        'total_loss': avg_total_loss
    }
1212
-
1213
-
1214
- def validate_epoch_contrastive(
1215
- model,
1216
- val_loaders,
1217
- device,
1218
- epoch
1219
- ):
1220
- """
1221
- Validate one epoch with MLM + Contrastive Learning.
1222
- """
1223
- model.eval()
1224
- total_mlm_loss = 0.0
1225
- total_contrastive_mod_loss = 0.0
1226
- total_contrastive_mob_loss = 0.0
1227
- total_loss = 0.0
1228
- total_batches = 0
1229
-
1230
- criterion = nn.MSELoss(reduction='sum')
1231
-
1232
- with torch.no_grad():
1233
- for seq_key, loader in val_loaders.items():
1234
- for batch_idx, batch in enumerate(loader):
1235
- # Unpack batch
1236
- if len(batch) >= 8:
1237
- input_ids = batch[0].to(device)
1238
- masked_tokens = batch[1].to(device)
1239
- masked_pos = batch[2].to(device)
1240
- snr_db = batch[3].to(device)
1241
- doppler_id = batch[4].to(device)
1242
- power_stats = batch[5].to(device)
1243
- snr_id = batch[6].to(device)
1244
- modulation_id = batch[7].to(device)
1245
- has_metadata = True
1246
- elif len(batch) == 3:
1247
- input_ids = batch[0].to(device)
1248
- masked_tokens = batch[1].to(device)
1249
- masked_pos = batch[2].to(device)
1250
- has_metadata = False
1251
- else:
1252
- input_ids = batch[0].to(device)
1253
- has_metadata = False
1254
-
1255
- # Forward pass
1256
- mlm_predictions, z_mod, z_mob = model(input_ids, return_projections=True)
1257
-
1258
- # MLM Loss
1259
- if len(batch) >= 3 and masked_tokens.numel() > 0:
1260
- batch_size = input_ids.size(0)
1261
- mlm_loss = 0.0
1262
-
1263
- for i in range(batch_size):
1264
- sample_masked_pos = masked_pos[i]
1265
- sample_masked_tokens = masked_tokens[i]
1266
-
1267
- if sample_masked_pos.numel() == 0:
1268
- continue
1269
-
1270
- predictions = mlm_predictions[i, sample_masked_pos, :]
1271
- targets = sample_masked_tokens
1272
-
1273
- if predictions.size(0) != targets.size(0):
1274
- min_len = min(predictions.size(0), targets.size(0))
1275
- predictions = predictions[:min_len]
1276
- targets = targets[:min_len]
1277
-
1278
- mlm_loss += criterion(predictions, targets)
1279
-
1280
- mlm_loss = mlm_loss / batch_size if batch_size > 0 else torch.tensor(0.0, device=device)
1281
- else:
1282
- mlm_loss = torch.zeros(1, device=device)
1283
-
1284
- # Contrastive losses
1285
- if has_metadata:
1286
- valid_mod_mask = modulation_id >= 0
1287
- if valid_mod_mask.sum() > 1:
1288
- z_mod_valid = z_mod[valid_mod_mask]
1289
- mod_labels_valid = modulation_id[valid_mod_mask]
1290
- contrastive_mod_loss = supervised_contrastive_loss(
1291
- z_mod_valid,
1292
- mod_labels_valid,
1293
- temperature=CONTRASTIVE_TEMPERATURE
1294
- )
1295
- else:
1296
- contrastive_mod_loss = torch.zeros(1, device=device)
1297
-
1298
- if doppler_id.numel() > 1:
1299
- contrastive_mob_loss = supervised_contrastive_loss(
1300
- z_mob,
1301
- doppler_id,
1302
- temperature=CONTRASTIVE_TEMPERATURE
1303
- )
1304
- else:
1305
- contrastive_mob_loss = torch.zeros(1, device=device)
1306
- else:
1307
- contrastive_mod_loss = torch.zeros(1, device=device)
1308
- contrastive_mob_loss = torch.zeros(1, device=device)
1309
-
1310
- loss = (
1311
- MLM_WEIGHT * mlm_loss +
1312
- CONTRASTIVE_WEIGHT_MODULATION * contrastive_mod_loss +
1313
- CONTRASTIVE_WEIGHT_MOBILITY * contrastive_mob_loss
1314
- )
1315
-
1316
- total_mlm_loss += mlm_loss.item()
1317
- total_contrastive_mod_loss += contrastive_mod_loss.item()
1318
- total_contrastive_mob_loss += contrastive_mob_loss.item()
1319
- total_loss += loss.item()
1320
- total_batches += 1
1321
-
1322
- avg_mlm_loss = total_mlm_loss / total_batches if total_batches > 0 else 0
1323
- avg_contrastive_mod_loss = total_contrastive_mod_loss / total_batches if total_batches > 0 else 0
1324
- avg_contrastive_mob_loss = total_contrastive_mob_loss / total_batches if total_batches > 0 else 0
1325
- avg_total_loss = total_loss / total_batches if total_batches > 0 else 0
1326
-
1327
- return {
1328
- 'mlm_loss': avg_mlm_loss,
1329
- 'contrastive_mod_loss': avg_contrastive_mod_loss,
1330
- 'contrastive_mob_loss': avg_contrastive_mob_loss,
1331
- 'total_loss': avg_total_loss
1332
- }
1333
-
1334
-
1335
- # =============================================================================
1336
- # 14. MAIN TRAINING LOOP
1337
- # =============================================================================
1338
-
1339
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
1340
- save_dir = f"models/{timestamp}_contrastive"
1341
- print(f"📁 Models and logs will be saved to: {save_dir}")
1342
- os.makedirs(save_dir, exist_ok=True)
1343
-
1344
- stats_path = os.path.join(save_dir, "dataset_stats.json")
1345
- with open(stats_path, 'w') as f:
1346
- json.dump(dataset_normalization, f, indent=2)
1347
- print(f"📝 Saved dataset stats to {stats_path}")
1348
-
1349
- # Save training configuration
1350
- config = {
1351
- 'epochs': EPOCHS,
1352
- 'batch_size': BATCH_SIZE,
1353
- 'effective_batch_size': BATCH_SIZE * ACCUMULATION_STEPS,
1354
- 'accumulation_steps': ACCUMULATION_STEPS,
1355
- 'learning_rate': BASE_LR,
1356
- 'element_length': ELEMENT_LENGTH,
1357
- 'd_model': D_MODEL,
1358
- 'n_layers': N_LAYERS,
1359
- 'n_heads': N_HEADS,
1360
- 'projection_dim': PROJECTION_DIM,
1361
- 'contrastive_temperature': CONTRASTIVE_TEMPERATURE,
1362
- 'mlm_weight': MLM_WEIGHT,
1363
- 'contrastive_weight_modulation': CONTRASTIVE_WEIGHT_MODULATION,
1364
- 'contrastive_weight_mobility': CONTRASTIVE_WEIGHT_MOBILITY,
1365
- 'modulation_map': MODULATION_MAP,
1366
- 'doppler_map': DOPPLER_MAP,
1367
- 'num_modulations': len(MODULATION_MAP),
1368
- }
1369
- config_path = os.path.join(save_dir, "config.json")
1370
- with open(config_path, 'w') as f:
1371
- json.dump(config, f, indent=2)
1372
- print(f"📝 Saved training config to {config_path}")
1373
-
1374
- # Training log
1375
- log_path = os.path.join(save_dir, "training_log.csv")
1376
- with open(log_path, 'w') as f:
1377
- f.write("epoch,train_mlm_loss,train_contrastive_mod_loss,train_contrastive_mob_loss,train_total_loss,")
1378
- f.write("val_mlm_loss,val_contrastive_mod_loss,val_contrastive_mob_loss,val_total_loss,learning_rate\n")
1379
-
1380
- print("\n" + "="*80)
1381
- print("🚀 Starting training with contrastive learning!")
1382
- print("="*80 + "\n")
1383
-
1384
- if __name__ == "__main__":
1385
- best_val_loss = float('inf')
1386
-
1387
- for epoch in range(EPOCHS):
1388
- print(f"\n{'='*80}")
1389
- print(f"Epoch {epoch+1}/{EPOCHS}")
1390
- print(f"{'='*80}")
1391
-
1392
- # Train
1393
- train_metrics = train_epoch_contrastive(
1394
- model, train_loaders, optimizer, scheduler, device, epoch, train_metadata
1395
- )
1396
-
1397
- # Validate
1398
- val_metrics = validate_epoch_contrastive(
1399
- model, val_loaders, device, epoch
1400
- )
1401
-
1402
- # Log metrics
1403
- current_lr = optimizer.param_groups[0]['lr']
1404
- print(f"\nEpoch {epoch+1} Results:")
1405
- print(f" Train - MLM: {train_metrics['mlm_loss']:.4f}, "
1406
- f"ContrastMod: {train_metrics['contrastive_mod_loss']:.4f}, "
1407
- f"ContrastMob: {train_metrics['contrastive_mob_loss']:.4f}, "
1408
- f"Total: {train_metrics['total_loss']:.4f}")
1409
- print(f" Val - MLM: {val_metrics['mlm_loss']:.4f}, "
1410
- f"ContrastMod: {val_metrics['contrastive_mod_loss']:.4f}, "
1411
- f"ContrastMob: {val_metrics['contrastive_mob_loss']:.4f}, "
1412
- f"Total: {val_metrics['total_loss']:.4f}")
1413
- print(f" Learning Rate: {current_lr:.6f}")
1414
-
1415
- # Save to log
1416
- with open(log_path, 'a') as f:
1417
- f.write(f"{epoch+1},{train_metrics['mlm_loss']:.6f},"
1418
- f"{train_metrics['contrastive_mod_loss']:.6f},"
1419
- f"{train_metrics['contrastive_mob_loss']:.6f},"
1420
- f"{train_metrics['total_loss']:.6f},"
1421
- f"{val_metrics['mlm_loss']:.6f},"
1422
- f"{val_metrics['contrastive_mod_loss']:.6f},"
1423
- f"{val_metrics['contrastive_mob_loss']:.6f},"
1424
- f"{val_metrics['total_loss']:.6f},"
1425
- f"{current_lr:.8f}\n")
1426
-
1427
- # Save best model
1428
- if val_metrics['total_loss'] < best_val_loss:
1429
- best_val_loss = val_metrics['total_loss']
1430
- checkpoint_path = os.path.join(save_dir, "best_model_contrastive.pth")
1431
- if isinstance(model, nn.DataParallel):
1432
- torch.save(model.module.state_dict(), checkpoint_path)
1433
- else:
1434
- torch.save(model.state_dict(), checkpoint_path)
1435
- print(f" ✅ Saved best model to {checkpoint_path}")
1436
-
1437
- # Save periodic checkpoint
1438
- if (epoch + 1) % 5 == 0:
1439
- checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch{epoch+1}_contrastive.pth")
1440
- if isinstance(model, nn.DataParallel):
1441
- torch.save(model.module.state_dict(), checkpoint_path)
1442
- else:
1443
- torch.save(model.state_dict(), checkpoint_path)
1444
- print(f" 💾 Saved checkpoint to {checkpoint_path}")
1445
-
1446
- print("\n" + "="*80)
1447
- print("🎉 Training completed successfully!")
1448
- print(f"📁 Models saved to: {save_dir}")
1449
- print(f"📊 Training log: {log_path}")
1450
- print("="*80 + "\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pretraining/train_lwm_spectro_no_contrast.py DELETED
@@ -1,1136 +0,0 @@
1
- #!/usr/bin/env python3
2
- # =============================================================================
3
- # 1. IMPORTS AND WARNINGS SETUP
4
- # - Load necessary PyTorch modules, utilities, and suppress UserWarnings
5
- # =============================================================================
6
- import sys
7
- import os
8
- import argparse
9
- # Add project root to path (Windows compatible)
10
- project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11
- sys.path.insert(0, project_root)
12
- import torch
13
- import torch.nn as nn
14
- import torch.nn.functional as F
15
- from torch.utils.data import DataLoader, IterableDataset
16
- import torch.distributed as dist
17
- import torch.optim as optim
18
- from utils import (generate_spectrograms_and_labels, tokenizer_train,
19
- count_parameters, train_lwm)
20
- import numpy as np
21
- import pretrained_model # Assuming this contains the LWM model definition
22
- from torch.optim.lr_scheduler import LambdaLR
23
- from torch.optim import AdamW
24
- import warnings
25
- import platform
26
- import re
27
- from tqdm import tqdm
28
- from datetime import datetime
29
- import concurrent.futures
30
- import multiprocessing
31
- from collections import Counter
32
- from functools import lru_cache
33
- import json
34
- import random
35
- import math
36
- from typing import Any, Dict, Optional, List, Tuple
37
- import time
38
-
39
-
40
- LOG_ALL_RANKS = False
41
-
42
- SNR_PATTERN = re.compile(r"SNR(-?\d+)dB")
43
- DOPPLER_MAP = {"static": 0, "pedestrian": 1, "vehicular": 2}
44
- DOPPLER_INV = {v: k for k, v in DOPPLER_MAP.items()}
45
-
46
-
47
- def _is_hpu_available() -> bool:
48
- hpu = getattr(torch, "hpu", None)
49
- if hpu is None:
50
- return False
51
- is_available = getattr(hpu, "is_available", None)
52
- available = False
53
- if callable(is_available):
54
- try:
55
- available = bool(is_available())
56
- except Exception:
57
- available = False
58
- if not available:
59
- # Try initializing the Habana runtime lazily
60
- try:
61
- import habana_frameworks.torch.core as htcore # type: ignore
62
-
63
- init_fn = getattr(htcore, "hpu_initialize", None)
64
- if callable(init_fn):
65
- init_fn()
66
- else:
67
- inference_init = getattr(htcore, "hpu_inference_initialize", None)
68
- if callable(inference_init):
69
- inference_init()
70
- available = bool(is_available())
71
- except Exception:
72
- available = False
73
- return available
74
-
75
-
76
- def _get_hpu_device_count() -> int:
77
- hpu = getattr(torch, "hpu", None)
78
- if hpu is None:
79
- return 0
80
- device_count_fn = getattr(hpu, "device_count", None)
81
- if callable(device_count_fn):
82
- try:
83
- return int(device_count_fn())
84
- except Exception:
85
- return 0
86
- return 1 if _is_hpu_available() else 0
87
-
88
-
89
- def _initialize_distributed(hpu_available: bool, backend_override: Optional[str] = None) -> Dict[str, Any]:
90
- context: Dict[str, Any] = {
91
- "is_distributed": False,
92
- "backend": None,
93
- "rank": 0,
94
- "world_size": 1,
95
- "local_rank": 0,
96
- "is_primary": True,
97
- }
98
- if not dist.is_available():
99
- return context
100
-
101
- required_env = ("RANK", "WORLD_SIZE")
102
- if not all(key in os.environ for key in required_env):
103
- return context
104
-
105
- if dist.is_initialized():
106
- context["is_distributed"] = True
107
- context["backend"] = dist.get_backend()
108
- context["rank"] = dist.get_rank()
109
- context["world_size"] = dist.get_world_size()
110
- context["local_rank"] = int(os.environ.get("LOCAL_RANK", context["rank"]))
111
- context["is_primary"] = context["rank"] == 0
112
- return context
113
-
114
- backend = backend_override or os.environ.get("LWM_DISTRIBUTED_BACKEND")
115
- if not backend:
116
- if hpu_available:
117
- backend = "hccl"
118
- elif torch.cuda.is_available():
119
- backend = "nccl"
120
- else:
121
- backend = "gloo"
122
-
123
- dist.init_process_group(backend=backend, init_method="env://")
124
-
125
- context["is_distributed"] = True
126
- context["backend"] = backend
127
- context["rank"] = dist.get_rank()
128
- context["world_size"] = dist.get_world_size()
129
- context["local_rank"] = int(os.environ.get("LOCAL_RANK", context["rank"]))
130
- context["is_primary"] = context["rank"] == 0
131
- return context
132
-
133
-
134
- def _broadcast_object(obj: Any, src: int = 0) -> Any:
135
- if not dist.is_available() or not dist.is_initialized():
136
- return obj
137
- object_list = [obj]
138
- dist.broadcast_object_list(object_list, src=src)
139
- return object_list[0]
140
-
141
-
142
- def _should_log(context: Dict[str, Any]) -> bool:
143
- return LOG_ALL_RANKS or (not context.get("is_distributed")) or context.get("is_primary", True)
144
-
145
-
146
- def _barrier(context: Dict[str, Any]) -> None:
147
- if context.get("is_distributed") and dist.is_available() and dist.is_initialized():
148
- dist.barrier()
149
-
150
-
151
- def _parse_snr_and_doppler(path: str) -> tuple[float, int]:
152
- snr_db = 0.0
153
- doppler_id = 0
154
-
155
- matches = SNR_PATTERN.findall(path)
156
- if matches:
157
- try:
158
- snr_db = float(matches[-1])
159
- except ValueError:
160
- snr_db = 0.0
161
-
162
- normalized_path = os.path.normpath(path)
163
- parts = normalized_path.split(os.sep)
164
- for part in parts:
165
- if part in DOPPLER_MAP:
166
- doppler_id = DOPPLER_MAP[part]
167
- break
168
-
169
- return snr_db, doppler_id
170
-
171
- def _parse_runtime_args():
172
- parser = argparse.ArgumentParser(add_help=False)
173
- parser.add_argument(
174
- "--device",
175
- default=os.environ.get("LWM_DEVICE", "auto"),
176
- choices=("auto", "cpu", "cuda", "hpu", "mps"),
177
- help="Select accelerator device (default: auto)."
178
- )
179
- parser.add_argument(
180
- "--dist-backend",
181
- dest="dist_backend",
182
- default=os.environ.get("LWM_DIST_BACKEND"),
183
- help="Override torch.distributed backend."
184
- )
185
- parser.add_argument(
186
- "--log-all-ranks",
187
- action="store_true",
188
- help="If set, every rank prints logs instead of rank 0 only."
189
- )
190
- args, remaining = parser.parse_known_args()
191
- sys.argv = [sys.argv[0]] + remaining
192
- return args
193
-
194
-
195
- warnings.filterwarnings("ignore", category=UserWarning)
196
-
197
- RUNTIME_ARGS = _parse_runtime_args()
198
- if getattr(RUNTIME_ARGS, "dist_backend", None) and RUNTIME_ARGS.dist_backend not in {"gloo", "nccl", "hccl"}:
199
- raise ValueError(f"Unsupported dist backend override: {RUNTIME_ARGS.dist_backend}")
200
- LOG_ALL_RANKS = bool(getattr(RUNTIME_ARGS, "log_all_ranks", False))
201
-
202
- TRAIN_SPLIT_FRACTION = 0.8
203
- VAL_SPLIT_FRACTION = 1.0 - TRAIN_SPLIT_FRACTION
204
- DEFAULT_SAMPLES_PER_SCENARIO = int(os.environ.get("LWM_SAMPLES_PER_SCENARIO", "1000"))
205
-
206
- # Use simple progress display instead of tqdm on Windows
207
- USE_TQDM = platform.system() != 'Windows'
208
-
209
- HPU_AVAILABLE = _is_hpu_available()
210
- distributed_context = _initialize_distributed(HPU_AVAILABLE, backend_override=getattr(RUNTIME_ARGS, "dist_backend", None))
211
- LOG_PRIMARY = _should_log(distributed_context)
212
- HPU_DEBUG_LOG = os.environ.get("LWM_DEBUG_HPU_INIT", "").lower() in {"1", "true", "yes"}
213
-
214
-
215
- def _debug_hpu(message: str) -> None:
216
- if not HPU_DEBUG_LOG:
217
- return
218
- rank = distributed_context.get("rank", 0)
219
- print(f"[HPU-DEBUG rank {rank}] {message}", flush=True)
220
-
221
- if distributed_context["is_distributed"] and LOG_PRIMARY:
222
- print(
223
- f"🔗 Distributed initialized -> backend={distributed_context['backend']}, "
224
- f"world_size={distributed_context['world_size']}, rank={distributed_context['rank']}"
225
- )
226
-
227
- # CPU 코어 수 계산 (메모리 사용량 고려하여 보수적으로 설정)
228
- total_cores = multiprocessing.cpu_count()
229
- if total_cores >= 16:
230
- MAX_WORKERS = min(8, total_cores // 2) # 고성능 서버의 경우 8코어로 제한
231
- else:
232
- MAX_WORKERS = max(2, total_cores // 2) # 일반 시스템의 경우 절반 사용
233
- if LOG_PRIMARY:
234
- print(f"🚀 Using {MAX_WORKERS}/{total_cores} CPU cores for parallel processing")
235
-
236
- def process_single_scenario(scenario_info):
237
- """단일 시나리오를 처리하는 함수 (멀티프로세싱용)"""
238
- scenario_name, spectrogram_path = scenario_info
239
-
240
- try:
241
- # 메모리 효율성을 위해 필요한 데이터만 로드
242
- scenario_spectrograms, scenario_labels = generate_spectrograms_and_labels(
243
- scenario_name=scenario_name,
244
- spectrogram_path=spectrogram_path,
245
- cache_path=None, # 메모리 문제로 캐시 비활성화
246
- )
247
-
248
- snr_db, doppler_id = _parse_snr_and_doppler(spectrogram_path)
249
-
250
- # 데이터 분할 (인덱스만 계산)
251
- total_samples = len(scenario_spectrograms)
252
- train_size = int(TRAIN_SPLIT_FRACTION * total_samples)
253
- val_size = total_samples - train_size
254
-
255
- # 메모리 절약을 위해 numpy array로 유지 (필요할 때만 tensor로 변환)
256
- train_data = np.array(scenario_spectrograms[:train_size], dtype=np.float32)
257
- val_data = np.array(scenario_spectrograms[train_size:], dtype=np.float32)
258
-
259
- snr_array = np.full(total_samples, snr_db, dtype=np.float32)
260
- doppler_array = np.full(total_samples, doppler_id, dtype=np.int64)
261
- train_meta = {
262
- 'snr_db': snr_array[:train_size],
263
- 'doppler_id': doppler_array[:train_size],
264
- }
265
- val_meta = {
266
- 'snr_db': snr_array[train_size:],
267
- 'doppler_id': doppler_array[train_size:],
268
- }
269
-
270
- # 불필요한 데이터 즉시 삭제
271
- del scenario_spectrograms
272
-
273
- return {
274
- 'scenario': scenario_name,
275
- 'train_data': train_data,
276
- 'val_data': val_data,
277
- 'train_meta': train_meta,
278
- 'val_meta': val_meta,
279
- 'train_size': len(train_data),
280
- 'val_size': len(val_data)
281
- }
282
- except Exception as e:
283
- context = globals().get("distributed_context", {})
284
- if LOG_PRIMARY or not context.get("is_distributed", False):
285
- print(f"❌ Error processing scenario {scenario_name}: {e}")
286
- return None
287
-
288
- # GPU Memory Monitor import (for Lambda) - Removed
289
-
290
- class StreamingMaskedSpectrogramDataset(IterableDataset):
291
- """Stream spectrogram samples scenario-by-scenario to limit peak memory usage."""
292
-
293
- def __init__(
294
- self,
295
- scenario_info_list,
296
- split,
297
- normalization_mode,
298
- dataset_stats,
299
- mask_percent,
300
- max_len,
301
- seed=42,
302
- shuffle=True,
303
- rank: int = 0,
304
- world_size: int = 1,
305
- ):
306
- super().__init__()
307
- if split not in {"train", "val"}:
308
- raise ValueError(f"Unsupported split '{split}'. Expected 'train' or 'val'.")
309
- self.scenario_info_list = list(scenario_info_list)
310
- self.split = split
311
- self.normalization_mode = normalization_mode
312
- self.dataset_stats = dataset_stats or {'mean': 0.0, 'std': 1.0, 'normalization': normalization_mode}
313
- self.mask_percent = mask_percent
314
- self.max_len = max_len
315
- self.seed = seed
316
- self.shuffle = shuffle
317
- self._epoch = 0
318
- self.num_samples = 0 # Populated after dataset summary
319
- self.rank = rank
320
- self.world_size = max(1, world_size)
321
-
322
- def _format_sample(self, sample_dict):
323
- input_ids = torch.from_numpy(sample_dict['input_ids']).float()
324
- masked_tokens = torch.from_numpy(sample_dict['masked_tokens']).float()
325
- masked_pos = torch.from_numpy(sample_dict['masked_pos']).long()
326
- snr_db = torch.tensor(sample_dict.get('snr_db', 0.0), dtype=torch.float32)
327
- doppler_id = torch.tensor(sample_dict.get('doppler_id', 0), dtype=torch.long)
328
- power_stats = torch.tensor(sample_dict.get('power_stats', np.zeros(2, dtype=np.float32)), dtype=torch.float32)
329
- snr_id = torch.tensor(sample_dict.get('snr_id', -1), dtype=torch.long)
330
- modulation_id = torch.tensor(sample_dict.get('modulation_id', -1), dtype=torch.long)
331
- return (
332
- input_ids,
333
- masked_tokens,
334
- masked_pos,
335
- snr_db,
336
- doppler_id,
337
- power_stats,
338
- snr_id,
339
- modulation_id,
340
- )
341
-
342
- def __iter__(self):
343
- order = list(self.scenario_info_list)
344
- if self.shuffle and order:
345
- rng = random.Random(self.seed + self._epoch)
346
- rng.shuffle(order)
347
- epoch_seed = self.seed + self._epoch
348
- self._epoch += 1
349
-
350
- for idx, (scenario_name, spectrogram_path) in enumerate(order):
351
- if self.world_size > 1 and (idx % self.world_size) != self.rank:
352
- continue
353
- result = process_single_scenario((scenario_name, spectrogram_path))
354
- if result is None:
355
- continue
356
-
357
- data_key = 'train_data' if self.split == 'train' else 'val_data'
358
- meta_key = 'train_meta' if self.split == 'train' else 'val_meta'
359
- spectrograms = result.get(data_key)
360
- metadata = result.get(meta_key)
361
-
362
- if spectrograms is None or len(spectrograms) == 0:
363
- continue
364
-
365
- scenario_seed = (epoch_seed + idx) % (2**32)
366
- tokenized = tokenizer_train(
367
- spectrograms,
368
- max_len=self.max_len,
369
- masking_percent=self.mask_percent,
370
- mask=True,
371
- seed=scenario_seed,
372
- metadata=metadata,
373
- dataset_stats=self.dataset_stats,
374
- normalization=self.normalization_mode,
375
- show_progress=False,
376
- )
377
-
378
- for samples in tokenized.values():
379
- for sample_dict in samples:
380
- yield self._format_sample(sample_dict)
381
-
382
- del tokenized, spectrograms, metadata, result
383
-
384
-
385
- def summarize_scenarios(scenario_info_list, normalization_mode):
386
- """Calculate dataset-level normalization stats and sample counts without storing all data in memory."""
387
- total_sum = 0.0
388
- total_sq = 0.0
389
- total_count = 0
390
- train_samples = 0
391
- val_samples = 0
392
-
393
- iterator = scenario_info_list
394
- if USE_TQDM and LOG_PRIMARY:
395
- iterator = tqdm(scenario_info_list, desc="Summarizing scenarios", unit="scenario")
396
-
397
- for scenario_name, spectrogram_path in iterator:
398
- result = process_single_scenario((scenario_name, spectrogram_path))
399
- if result is None:
400
- continue
401
-
402
- train_data = result.get('train_data')
403
- val_data = result.get('val_data')
404
-
405
- if isinstance(train_data, np.ndarray):
406
- train_samples += train_data.shape[0]
407
- if normalization_mode == "dataset" and train_data.size > 0:
408
- arr64 = train_data.astype(np.float64, copy=False)
409
- total_sum += arr64.sum()
410
- total_sq += np.square(arr64).sum(dtype=np.float64)
411
- total_count += arr64.size
412
-
413
- if isinstance(val_data, np.ndarray):
414
- val_samples += val_data.shape[0]
415
-
416
- del result
417
-
418
- if normalization_mode == "dataset":
419
- if total_count == 0:
420
- raise ValueError("Unable to compute dataset statistics: no training samples available.")
421
- mean = float(total_sum / total_count)
422
- variance = max(float(total_sq / total_count - mean ** 2), 1e-12)
423
- std = float(np.sqrt(variance))
424
- else:
425
- mean = 0.0
426
- std = 1.0
427
-
428
- stats = {'mean': mean, 'std': std, 'normalization': normalization_mode}
429
- return stats, train_samples, val_samples
430
-
431
-
432
- # =============================================================================
433
- # 2. SCENARIO LIST DEFINITION
434
- # - Define the list of scenario names to iterate over for data generation
435
- # =============================================================================
436
-
437
- # Supported communications; can be limited via CLI
438
- SUPPORTED_COMM_TYPES = {"LTE", "WiFi", "5G"}
439
-
440
-
441
- def _parse_standard_args():
442
- parser = argparse.ArgumentParser(add_help=False)
443
- parser.add_argument('--standards', nargs='+', choices=SUPPORTED_COMM_TYPES,
444
- help='Specify one or more communication types to include (default: all).')
445
- for comm in SUPPORTED_COMM_TYPES:
446
- parser.add_argument(f'--{comm}', dest=f'flag_{comm}', action='store_true',
447
- help=f'Include only {comm} data (can be combined).')
448
- parser.add_argument('--city', '--cities', dest='cities', nargs='+',
449
- help='Limit scenarios to one or more city prefixes (e.g., "0" or "city_0").')
450
- parser.add_argument(
451
- '--normalization',
452
- choices=('per_sample', 'dataset'),
453
- default='per_sample',
454
- help='Normalization mode applied during tokenization (default: %(default)s).'
455
- )
456
- parser.add_argument('--help', action='help')
457
-
458
- args, remaining = parser.parse_known_args()
459
-
460
- enabled = set(SUPPORTED_COMM_TYPES)
461
- if args.standards:
462
- enabled = set(args.standards)
463
- else:
464
- flagged = {comm for comm in SUPPORTED_COMM_TYPES if getattr(args, f'flag_{comm}', False)}
465
- if flagged:
466
- enabled = flagged
467
-
468
- selected_cities: list[str] | None = None
469
- if args.cities:
470
- selected_cities = []
471
- for city_token in args.cities:
472
- token = str(city_token).strip()
473
- if not token:
474
- continue
475
- if token.startswith('city_'):
476
- selected_cities.append(token)
477
- else:
478
- selected_cities.append(f'city_{token}')
479
- if not selected_cities:
480
- selected_cities = None
481
-
482
- # Return remaining args to allow downstream parsing if needed
483
- sys.argv = [sys.argv[0]] + remaining
484
- return enabled, selected_cities, args.normalization
485
-
486
-
487
- ENABLED_COMM_TYPES, ENABLED_CITY_PREFIXES, NORMALIZATION_MODE = _parse_standard_args()
488
- MAX_SCENARIOS = int(os.environ.get("LWM_MAX_SCENARIOS", "0")) or None
489
-
490
- SCENARIO_ENTRIES: Optional[List[Tuple[str, str, str, str]]] = None
491
-
492
-
493
- def _scenario_manifest_path() -> str:
494
- """Build cache file path based on selected comm types and city filters."""
495
- comm_token = "-".join(sorted(ENABLED_COMM_TYPES)) if ENABLED_COMM_TYPES else "all"
496
- city_token = "-".join(sorted(ENABLED_CITY_PREFIXES)) if ENABLED_CITY_PREFIXES else "all"
497
- limit_token = MAX_SCENARIOS if MAX_SCENARIOS is not None else "all"
498
- filename = f"_scenario_entries_{comm_token}_{city_token}_max{limit_token}.json"
499
- return os.path.join("spectrograms", filename)
500
-
501
-
502
- def _get_scenario_entries() -> List[Tuple[str, str, str, str]]:
503
- """Gather scenario metadata once on rank 0 and share via disk cache. Avoids long-lived collectives."""
504
- global SCENARIO_ENTRIES
505
- if SCENARIO_ENTRIES is not None:
506
- return SCENARIO_ENTRIES
507
-
508
- manifest_path = _scenario_manifest_path()
509
- refresh_requested = os.environ.get("LWM_REFRESH_SCENARIOS", "").lower() in {"1", "true", "yes"}
510
-
511
- def _load_manifest() -> Optional[List[Tuple[str, str, str, str]]]:
512
- try:
513
- with open(manifest_path, "r", encoding="utf-8") as f:
514
- raw_entries = json.load(f)
515
- except FileNotFoundError:
516
- return None
517
- except Exception as exc:
518
- if LOG_PRIMARY:
519
- print(f"⚠️ Unable to read scenario manifest {manifest_path}: {exc}", flush=True)
520
- return None
521
-
522
- entries: List[Tuple[str, str, str, str]] = []
523
- for item in raw_entries:
524
- if isinstance(item, dict):
525
- entries.append(
526
- (
527
- item.get("scenario_id", ""),
528
- item.get("file_path", ""),
529
- item.get("city_name", ""),
530
- item.get("base_token", ""),
531
- )
532
- )
533
- elif isinstance(item, (list, tuple)) and len(item) == 4:
534
- entries.append((str(item[0]), str(item[1]), str(item[2]), str(item[3])))
535
- return entries if entries else None
536
-
537
- def _save_manifest(entries_to_save: List[Tuple[str, str, str, str]]) -> None:
538
- try:
539
- os.makedirs(os.path.dirname(manifest_path), exist_ok=True)
540
- tmp_path = f"{manifest_path}.tmp"
541
- payload = [
542
- {
543
- "scenario_id": scenario_id,
544
- "file_path": file_path,
545
- "city_name": city_name,
546
- "base_token": base_token,
547
- }
548
- for scenario_id, file_path, city_name, base_token in entries_to_save
549
- ]
550
- with open(tmp_path, "w", encoding="utf-8") as f:
551
- json.dump(payload, f)
552
- os.replace(tmp_path, manifest_path)
553
- if LOG_PRIMARY:
554
- print(f"📊 [debug] Scenario manifest saved to {manifest_path}", flush=True)
555
- except Exception as exc:
556
- if LOG_PRIMARY:
557
- print(f"⚠️ Failed to save scenario manifest {manifest_path}: {exc}", flush=True)
558
-
559
- entries: Optional[List[Tuple[str, str, str, str]]] = None
560
- if distributed_context["is_distributed"]:
561
- entries = None if refresh_requested else _load_manifest()
562
- if entries is None:
563
- if distributed_context["is_primary"]:
564
- if LOG_PRIMARY:
565
- print("📊 [debug] Rank0 starting scenario discovery", flush=True)
566
- entries = _collect_scenario_file_info()
567
- if LOG_PRIMARY:
568
- print(f"📊 [debug] Rank0 collected {len(entries)} scenario entries", flush=True)
569
- _save_manifest(entries)
570
- else:
571
- deadline = time.time() + 300.0
572
- while time.time() < deadline:
573
- entries = _load_manifest()
574
- if entries is not None:
575
- break
576
- time.sleep(1.0)
577
- if entries is None:
578
- raise RuntimeError(
579
- f"Scenario manifest {manifest_path} not found after waiting. "
580
- "Run with LWM_REFRESH_SCENARIOS=1 on a single rank to regenerate."
581
- )
582
- elif LOG_PRIMARY and distributed_context["is_primary"]:
583
- print(f"📊 [debug] Rank0 loaded {len(entries)} scenario entries from manifest", flush=True)
584
- else:
585
- entries = None if refresh_requested else _load_manifest()
586
- if entries is None:
587
- if LOG_PRIMARY:
588
- print("📊 [debug] Single-process scenario discovery", flush=True)
589
- entries = _collect_scenario_file_info()
590
- if LOG_PRIMARY:
591
- print(f"📊 [debug] Collected {len(entries)} scenario entries (single process)", flush=True)
592
- _save_manifest(entries)
593
- elif LOG_PRIMARY:
594
- print(f"📊 [debug] Loaded {len(entries)} scenario entries from manifest", flush=True)
595
-
596
- if entries is None:
597
- entries = []
598
- SCENARIO_ENTRIES = entries
599
- return entries
600
-
601
-
602
- def _extract_scenario_token(file_path):
603
- """Derive the base scenario token (without city) from the file path."""
604
- normalized_path = os.path.normpath(file_path)
605
- parts = normalized_path.split(os.sep)
606
-
607
- scenario_parts = []
608
- for i, part in enumerate(parts):
609
- if part in SUPPORTED_COMM_TYPES:
610
- if i + 4 < len(parts):
611
- scenario_parts = [part] + parts[i + 1:i + 5]
612
- break
613
- return '_'.join(scenario_parts) if scenario_parts else None
614
-
615
-
616
- @lru_cache(maxsize=1)
617
- def _collect_scenario_file_info():
618
- import glob
619
-
620
- if LOG_PRIMARY:
621
- print("📊 [debug] _collect_scenario_file_info scanning directories...", flush=True)
622
- city_dirs = []
623
- for d in sorted(glob.glob(os.path.join('spectrograms', 'city_*'))):
624
- if not os.path.isdir(d):
625
- continue
626
- city_dirs.append(d)
627
-
628
- scenario_entries = []
629
- for city_dir in city_dirs:
630
- city_name = os.path.basename(city_dir)
631
- if ENABLED_CITY_PREFIXES:
632
- if not any(city_name.startswith(prefix) for prefix in ENABLED_CITY_PREFIXES):
633
- continue
634
- pattern = os.path.join(city_dir, '**', '512FFT', '**', 'spectrograms', '*.pkl')
635
- city_files = sorted(glob.glob(pattern, recursive=True))
636
- for file_path in city_files:
637
- base_token = _extract_scenario_token(file_path)
638
- if not base_token:
639
- continue
640
- scenario_id = f"{city_name}::{base_token}"
641
- comm_type = base_token.split('_', 1)[0]
642
- if comm_type not in ENABLED_COMM_TYPES:
643
- continue
644
- scenario_entries.append((scenario_id, file_path, city_name, base_token))
645
-
646
- if MAX_SCENARIOS:
647
- scenario_entries = scenario_entries[:MAX_SCENARIOS]
648
-
649
- if LOG_PRIMARY:
650
- print(f"📊 [debug] _collect_scenario_file_info found {len(scenario_entries)} entries", flush=True)
651
- return scenario_entries
652
-
653
-
654
- def scenarios_list():
655
- scenario_entries = _get_scenario_entries()
656
-
657
- if not scenario_entries:
658
- if LOG_PRIMARY:
659
- print("⚠️ No spectrogram files found for pretraining.", flush=True)
660
- return np.array([])
661
-
662
- if LOG_PRIMARY:
663
- print(f"📊 [debug] scenarios_list received {len(scenario_entries)} entries", flush=True)
664
- print(f"Enabled communication types: {sorted(ENABLED_COMM_TYPES)}", flush=True)
665
- if ENABLED_CITY_PREFIXES:
666
- print(f"Selected city prefixes: {sorted(ENABLED_CITY_PREFIXES)}", flush=True)
667
- city_counts = Counter(entry[2] for entry in scenario_entries)
668
- print("Using scenarios from the following city datasets:", flush=True)
669
- for city_name, count in city_counts.items():
670
- print(f" - {city_name}: {count} files", flush=True)
671
-
672
- print(f"Total scenarios selected: {len(scenario_entries)}", flush=True)
673
- return np.array([entry[0] for entry in scenario_entries])
674
-
675
-
676
- # =============================================================================
677
- # 3. SCENARIO PROPERTIES MAPPING
678
- # - Map each scenario name to its corresponding properties
679
- # =============================================================================
680
-
681
- def scenario_prop():
682
- scenario_entries = _get_scenario_entries()
683
-
684
- row_column_users = {}
685
- for scenario_id, file_path, city_name, _ in scenario_entries:
686
- row_column_users[scenario_id] = {
687
- 'spectrogram_path': file_path,
688
- 'cache_path': os.path.join('spectrograms', city_name, 'spectrogram_cache_128x128.pkl')
689
- }
690
-
691
- return row_column_users
692
-
693
- # =============================================================================
694
- # 4. TRAINING PARAMETERS AND HYPERPARAMETERS
695
- # - Set training epochs, batch sizes, learning rates, model dimensions, etc.
696
- # =============================================================================
697
-
698
- EPOCHS = 20 # Increased for better convergence
699
- # Optimized batch size for A100 GPU (40GB)
700
- BATCH_SIZE = 16
701
- VAL_BATCH_SIZE = 16
702
- WARMUP_EPOCHS = 5
703
- BASE_LR = 5e-4
704
- MIN_LR = 1e-8
705
- # Updated for 128x128 spectrograms
706
- N_ROWS = 4
707
- N_COLUMNS = 4
708
- ELEMENT_LENGTH = N_ROWS * N_COLUMNS # Real-valued spectrograms (no complex interleaving)
709
- D_MODEL = 128
710
- MAX_LEN = 1025 # (128/4)^2 + 1 = 1024 + 1 for [CLS] token
711
- N_LAYERS = 12
712
- device_idx = 0
713
- WEIGHT_DECAY = 0.05
714
- BETA1 = 0.9
715
- BETA2 = 0.999
716
- MASK_PERCENT = 0.6
717
- N_HEADS = 8
718
- DROPOUT = 0.1
719
-
720
- # =============================================================================
721
- # 5. DATA GENERATION LOOP
722
- # - Iterate over scenarios to generate spectrogram samples and labels
723
- # =============================================================================
724
-
725
- scenarios = scenarios_list()
726
- scenario_properties = scenario_prop()
727
-
728
- if LOG_PRIMARY:
729
- print(f"📂 Loading {len(scenarios)} scenarios...")
730
-
731
- scenario_info_list = []
732
- missing_props = []
733
- for scenario in scenarios:
734
- props = scenario_properties.get(scenario)
735
- if props is None:
736
- missing_props.append(scenario)
737
- continue
738
- scenario_info_list.append((scenario, props["spectrogram_path"]))
739
-
740
- if distributed_context["is_distributed"] and len(scenario_info_list) < distributed_context["world_size"]:
741
- if LOG_PRIMARY:
742
- print("❌ Distributed configuration requires at least one scenario per process. "
743
- f"Found {len(scenario_info_list)} scenarios for world size {distributed_context['world_size']}.")
744
- raise ValueError("Insufficient scenarios for the requested distributed world size.")
745
-
746
- if missing_props and LOG_PRIMARY:
747
- print("⚠️ Missing metadata for the following scenarios; skipping:")
748
- for scen in missing_props:
749
- print(f" - {scen}")
750
-
751
- if LOG_PRIMARY:
752
- print(f"📂 Preparing {len(scenario_info_list)} scenarios with streaming loaders...")
753
-
754
- if NORMALIZATION_MODE == "dataset":
755
- if distributed_context["is_distributed"] and not distributed_context["is_primary"]:
756
- dataset_normalization = None
757
- train_sample_count = 0
758
- val_sample_count = 0
759
- else:
760
- dataset_normalization, train_sample_count, val_sample_count = summarize_scenarios(
761
- scenario_info_list,
762
- NORMALIZATION_MODE,
763
- )
764
- if distributed_context["is_distributed"]:
765
- payload = [dataset_normalization, train_sample_count, val_sample_count]
766
- dataset_normalization, train_sample_count, val_sample_count = _broadcast_object(payload, src=0)
767
- else:
768
- train_samples_per_scenario = int(TRAIN_SPLIT_FRACTION * DEFAULT_SAMPLES_PER_SCENARIO)
769
- val_samples_per_scenario = max(DEFAULT_SAMPLES_PER_SCENARIO - train_samples_per_scenario, 0)
770
- dataset_normalization = {'mean': 0.0, 'std': 1.0, 'normalization': NORMALIZATION_MODE}
771
- train_sample_count = len(scenario_info_list) * train_samples_per_scenario
772
- val_sample_count = len(scenario_info_list) * val_samples_per_scenario
773
- if LOG_PRIMARY:
774
- print(f" Assuming {DEFAULT_SAMPLES_PER_SCENARIO} samples per scenario ({train_samples_per_scenario} train / {val_samples_per_scenario} val)")
775
-
776
- if LOG_PRIMARY:
777
- print(f" Training samples: {train_sample_count}")
778
- print(f" Validation samples: {val_sample_count}")
779
- if train_sample_count == 0:
780
- raise ValueError("No training samples available after filtering scenarios.")
781
- if NORMALIZATION_MODE == "dataset":
782
- if LOG_PRIMARY:
783
- print(f"Dataset normalization stats -> mean: {dataset_normalization['mean']:.4f}, std: {dataset_normalization['std']:.4f}")
784
- else:
785
- if LOG_PRIMARY:
786
- print("Dataset normalization stats -> using per-sample normalization")
787
-
788
- SEED = 42
789
- torch.manual_seed(SEED)
790
- np.random.seed(SEED)
791
-
792
- world_size = max(1, distributed_context["world_size"])
793
- train_samples_per_rank = math.ceil(train_sample_count / world_size) if distributed_context["is_distributed"] else train_sample_count
794
- val_samples_per_rank = math.ceil(val_sample_count / world_size) if distributed_context["is_distributed"] else val_sample_count
795
-
796
- train_dataset = StreamingMaskedSpectrogramDataset(
797
- scenario_info_list,
798
- split="train",
799
- normalization_mode=NORMALIZATION_MODE,
800
- dataset_stats=dataset_normalization,
801
- mask_percent=MASK_PERCENT,
802
- max_len=MAX_LEN,
803
- seed=SEED,
804
- shuffle=True,
805
- rank=distributed_context["rank"],
806
- world_size=world_size,
807
- )
808
- train_dataset.num_samples = train_samples_per_rank
809
-
810
- val_dataset = StreamingMaskedSpectrogramDataset(
811
- scenario_info_list,
812
- split="val",
813
- normalization_mode=NORMALIZATION_MODE,
814
- dataset_stats=dataset_normalization,
815
- mask_percent=MASK_PERCENT,
816
- max_len=MAX_LEN,
817
- seed=SEED,
818
- shuffle=False,
819
- rank=distributed_context["rank"],
820
- world_size=world_size,
821
- )
822
- val_dataset.num_samples = val_samples_per_rank
823
-
824
- if LOG_PRIMARY:
825
- print("🔧 Creating streaming data loaders...")
826
- train_loaders = {
827
- 'stream': DataLoader(
828
- train_dataset,
829
- batch_size=BATCH_SIZE,
830
- shuffle=False,
831
- num_workers=0,
832
- pin_memory=True,
833
- )
834
- }
835
- val_loaders = {
836
- 'stream': DataLoader(
837
- val_dataset,
838
- batch_size=VAL_BATCH_SIZE,
839
- shuffle=False,
840
- num_workers=0,
841
- pin_memory=True,
842
- )
843
- }
844
- if LOG_PRIMARY:
845
- print("✅ Data loaders created successfully!")
846
-
847
- # =============================================================================
848
- # 9. MODEL INITIALIZATION
849
- # - Instantiate the LWM transformer model and optionally load pre-trained weights
850
- # - Wrap with DataParallel for multi-GPU support
851
- # =============================================================================
852
-
853
- # Device selection with HPU, CUDA, and MPS support
854
- if LOG_PRIMARY:
855
- print("🔧 Setting up device and accelerator configuration...")
856
-
857
- requested_device = getattr(RUNTIME_ARGS, "device", "auto") or "auto"
858
- requested_device = requested_device.lower()
859
- runtime_device = requested_device
860
-
861
- if runtime_device == "auto":
862
- if HPU_AVAILABLE:
863
- runtime_device = "hpu"
864
- elif torch.cuda.is_available():
865
- runtime_device = "cuda"
866
- elif torch.backends.mps.is_available():
867
- runtime_device = "mps"
868
- else:
869
- runtime_device = "cpu"
870
-
871
- if runtime_device in {"hpu", "auto"} and not HPU_AVAILABLE:
872
- if os.environ.get("HABANA_VISIBLE_DEVICES") and LOG_PRIMARY:
873
- print("⚠️ HABANA_VISIBLE_DEVICES is set but Habana PyTorch extensions are not available.")
874
- print(" Install the Habana PyTorch distribution or activate the appropriate environment.")
875
-
876
- device = torch.device("cpu")
877
- gpu_ids: list[int] = []
878
- ddp_device_ids: Optional[list[int]] = None
879
-
880
- if runtime_device == "hpu":
881
- if not HPU_AVAILABLE:
882
- raise RuntimeError("HPU device requested but torch.hpu is not available. "
883
- "Install the Habana PyTorch distribution or select --device cpu.")
884
- hpu_module = getattr(torch, "hpu", None)
885
-
886
- # Get local rank first before any HPU operations
887
- local_rank = distributed_context["local_rank"] if distributed_context["is_distributed"] else 0
888
- _debug_hpu(f"Entering HPU device setup (local_rank={local_rank}, world_size={distributed_context.get('world_size')})")
889
-
890
- # Query device count locally (safe after Habana runtime init)
891
- hpu_count = max(1, _get_hpu_device_count())
892
- if LOG_PRIMARY or HPU_DEBUG_LOG:
893
- _debug_hpu(f"Detected {hpu_count} HPU devices via local query")
894
-
895
- device = torch.device("hpu")
896
- if hpu_module is not None and hasattr(hpu_module, "set_device"):
897
- try:
898
- _debug_hpu(f"Calling torch.hpu.set_device({local_rank})")
899
- hpu_module.set_device(local_rank)
900
- _debug_hpu("torch.hpu.set_device completed successfully")
901
- except Exception as exc:
902
- _debug_hpu(f"set_device raised exception: {exc}")
903
- if LOG_PRIMARY:
904
- print(f" ⚠️ Unable to set HPU device {local_rank}: {exc}")
905
- ddp_device_ids = [local_rank] if distributed_context["is_distributed"] else None
906
- if LOG_PRIMARY:
907
- if hpu_count > 0:
908
- print(f" HPU available: {hpu_count} device(s) detected")
909
- if distributed_context["is_distributed"]:
910
- print(f" Using HPU local rank: {local_rank}")
911
- elif runtime_device == "cuda":
912
- if not torch.cuda.is_available():
913
- raise RuntimeError("CUDA device requested but torch.cuda.is_available() is False.")
914
- device_count = torch.cuda.device_count()
915
- if LOG_PRIMARY:
916
- print(f" CUDA available: {device_count} GPU(s) detected")
917
- if distributed_context["is_distributed"]:
918
- local_rank = distributed_context["local_rank"]
919
- torch.cuda.set_device(local_rank)
920
- device = torch.device("cuda", local_rank)
921
- ddp_device_ids = [local_rank]
922
- if LOG_PRIMARY:
923
- print(f" Using CUDA local rank: {local_rank}")
924
- else:
925
- device = torch.device("cuda:0")
926
- gpu_ids = list(range(device_count))
927
- if LOG_PRIMARY:
928
- print(f" Using CUDA GPUs: {gpu_ids}")
929
- for i in gpu_ids:
930
- try:
931
- mem_total = torch.cuda.get_device_properties(i).total_memory / 1024**3
932
- mem_allocated = torch.cuda.memory_allocated(i) / 1024**3
933
- if LOG_PRIMARY:
934
- print(f" GPU {i}: Total: {mem_total:.1f}GB, Allocated: {mem_allocated:.1f}GB")
935
- except Exception as exc:
936
- if LOG_PRIMARY:
937
- print(f" GPU {i}: Error getting memory info - {exc}")
938
- elif runtime_device == "mps":
939
- if not torch.backends.mps.is_available():
940
- raise RuntimeError("MPS device requested but torch.backends.mps.is_available() is False.")
941
- device = torch.device("mps")
942
- if LOG_PRIMARY:
943
- print(" Using MPS (Apple Silicon GPU)")
944
- elif runtime_device == "cpu":
945
- device = torch.device("cpu")
946
- if LOG_PRIMARY:
947
- print(" Using CPU")
948
- else:
949
- raise ValueError(f"Unsupported device selection: {runtime_device}")
950
-
951
- distributed_context["device_type"] = device.type
952
- if LOG_PRIMARY:
953
- print(f" Final device: {device}")
954
- if gpu_ids:
955
- print(f" GPU IDs for DataParallel: {gpu_ids}")
956
-
957
- if LOG_PRIMARY:
958
- print("🤖 Initializing LWM model...")
959
- print(f" Model parameters: element_length={ELEMENT_LENGTH}, d_model={D_MODEL}, n_layers={N_LAYERS}, max_len={MAX_LEN}, n_heads={N_HEADS}")
960
-
961
- try:
962
- model = pretrained_model.lwm(
963
- element_length=ELEMENT_LENGTH, # Real-valued spectrograms
964
- d_model=D_MODEL,
965
- n_layers=N_LAYERS,
966
- max_len=MAX_LEN, # Use pre-calculated value for safety
967
- n_heads=N_HEADS,
968
- dropout=DROPOUT
969
- )
970
- if LOG_PRIMARY:
971
- print(" ✅ Model created successfully")
972
- print(f" Moving model to device: {device}")
973
- # MPS only supports float32, so set dtype
974
- if 'mps' in str(device):
975
- model = model.to(device).float()
976
- if LOG_PRIMARY:
977
- print(" ✅ Model moved to MPS device (float32)")
978
- else:
979
- model = model.to(device)
980
- if LOG_PRIMARY:
981
- print(" ✅ Model moved to device successfully")
982
-
983
- # Synchronize all processes after moving model to device
984
- # This prevents memory contention issues in multi-HPU/GPU setups
985
- if distributed_context["is_distributed"]:
986
- torch.distributed.barrier()
987
- if LOG_PRIMARY:
988
- print(" ✅ All processes synchronized after model transfer")
989
-
990
- except Exception as e:
991
- print(f" ❌ Model initialization failed: {e}")
992
- import traceback
993
- traceback.print_exc()
994
- exit(1)
995
-
996
- # Optional: Load pre-trained model
997
- load_model = False
998
- if load_model:
999
- model.load_state_dict(torch.load("models/model_checkpoint.pth", map_location=device))
1000
- if LOG_PRIMARY:
1001
- print("Pre-trained model loaded successfully.")
1002
-
1003
- # Wrap model for parallel/distributed execution
1004
- if distributed_context["is_distributed"]:
1005
- # Additional barrier before DDP wrapping to ensure all processes are ready
1006
- torch.distributed.barrier()
1007
-
1008
- ddp_kwargs: Dict[str, Any] = {"broadcast_buffers": False}
1009
- if ddp_device_ids:
1010
- ddp_kwargs["device_ids"] = ddp_device_ids
1011
- ddp_kwargs["output_device"] = ddp_device_ids[0]
1012
- else:
1013
- ddp_kwargs["device_ids"] = None
1014
- model = nn.parallel.DistributedDataParallel(model, **ddp_kwargs)
1015
- if LOG_PRIMARY:
1016
- print(f"Model wrapped with DistributedDataParallel on rank {distributed_context['rank']}")
1017
- elif gpu_ids:
1018
- model = nn.DataParallel(model, device_ids=gpu_ids)
1019
- if LOG_PRIMARY:
1020
- print(f"Model loaded successfully with DataParallel on CUDA devices {gpu_ids}")
1021
- else:
1022
- if LOG_PRIMARY:
1023
- print(f"Model loaded successfully on {device}")
1024
- n_parameters = count_parameters(model, log=LOG_PRIMARY)
1025
- if LOG_PRIMARY:
1026
- print(f"Number of trainable parameters: {n_parameters:,}")
1027
-
1028
- # =============================================================================
1029
- # 10. OPTIMIZER AND LEARNING RATE SCHEDULER
1030
- # - Configure AdamW optimizer and a cosine-with-warmup LR schedule based on total steps
1031
- # =============================================================================
1032
-
1033
- steps_per_epoch = max(1, math.ceil(train_samples_per_rank / BATCH_SIZE))
1034
- TOTAL_STEPS = steps_per_epoch * EPOCHS
1035
- WARMUP_STEPS = steps_per_epoch * WARMUP_EPOCHS
1036
-
1037
- optimizer = AdamW(
1038
- model.parameters(),
1039
- lr=BASE_LR,
1040
- betas=(BETA1, BETA2),
1041
- weight_decay=WEIGHT_DECAY
1042
- )
1043
-
1044
- def lr_lambda(current_step):
1045
- if current_step < WARMUP_STEPS:
1046
- return current_step / WARMUP_STEPS
1047
- else:
1048
- scaled_progress = (current_step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
1049
- cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
1050
- return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR
1051
-
1052
- scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
1053
-
1054
- # =============================================================================
1055
- # 11. PRE-TRAINING LOOP
1056
- # - Call the train_lwm utility to run the pre-training epochs, logging metrics and saving models
1057
- # =============================================================================
1058
-
1059
- # Create timestamp-based save directory
1060
- if distributed_context["is_distributed"]:
1061
- timestamp_source = datetime.now().strftime("%Y%m%d_%H%M%S") if LOG_PRIMARY else None
1062
- timestamp = _broadcast_object(timestamp_source, src=0)
1063
- else:
1064
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
1065
- save_dir = f"models/{timestamp}"
1066
- if LOG_PRIMARY:
1067
- print(f"📁 Models and logs will be saved to: {save_dir}")
1068
- os.makedirs(save_dir, exist_ok=True)
1069
-
1070
- stats_path = os.path.join(save_dir, "dataset_stats.json")
1071
- if LOG_PRIMARY:
1072
- with open(stats_path, 'w') as f:
1073
- json.dump(dataset_normalization, f, indent=2)
1074
- print(f"📝 Saved dataset stats to {stats_path}")
1075
- _barrier(distributed_context)
1076
-
1077
- comm_selection = sorted(ENABLED_COMM_TYPES) if ENABLED_COMM_TYPES else []
1078
- if comm_selection:
1079
- comm_suffix = "_" + "-".join(comm_selection)
1080
- else:
1081
- comm_suffix = ""
1082
- if comm_selection and LOG_PRIMARY:
1083
- print(f"[INFO] Communication standards for this run: {', '.join(comm_selection)}")
1084
-
1085
- if __name__ == "__main__":
1086
- # Patch: Ensure patches is not a dict before converting to tensor
1087
- def safe_tensor_from_patches(patches, device):
1088
- if isinstance(patches, dict):
1089
- key = max(patches.keys())
1090
- patches = patches[key]
1091
- return torch.tensor(patches, dtype=torch.float32).to(device)
1092
-
1093
- # Pass this function to train_lwm if needed, or use inside train_lwm
1094
- pretrained_model = train_lwm(
1095
- model,
1096
- train_loaders,
1097
- val_loaders,
1098
- optimizer,
1099
- scheduler,
1100
- EPOCHS,
1101
- device=device,
1102
- save_dir=save_dir,
1103
- log_file="training_log.csv",
1104
- checkpoint_suffix=comm_suffix,
1105
- distributed_context=distributed_context,
1106
- # If train_lwm needs to convert patches, use safe_tensor_from_patches
1107
- )
1108
- _barrier(distributed_context)
1109
- if LOG_PRIMARY:
1110
- print("🏁 Training run complete.")
1111
- if distributed_context["is_distributed"]:
1112
- dist.destroy_process_group()
1113
- SNR_PATTERN = re.compile(r"SNR(-?\d+)dB")
1114
- DOPPLER_MAP = {"static": 0, "pedestrian": 1, "vehicular": 2}
1115
- DOPPLER_INV = {v: k for k, v in DOPPLER_MAP.items()}
1116
-
1117
-
1118
- def _parse_snr_and_doppler(path: str) -> tuple[float, int]:
1119
- snr_db = 0.0
1120
- doppler_id = 0
1121
-
1122
- matches = SNR_PATTERN.findall(path)
1123
- if matches:
1124
- try:
1125
- snr_db = float(matches[-1])
1126
- except ValueError:
1127
- snr_db = 0.0
1128
-
1129
- normalized_path = os.path.normpath(path)
1130
- parts = normalized_path.split(os.sep)
1131
- for part in parts:
1132
- if part in DOPPLER_MAP:
1133
- doppler_id = DOPPLER_MAP[part]
1134
- break
1135
-
1136
- return snr_db, doppler_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  # UI/Hub
2
- gradio==3.50.2
3
- huggingface_hub==0.23.4
4
 
5
  # Core
6
  torch
 
1
  # UI/Hub
2
+ gradio==6.0.1
3
+ huggingface_hub>=0.33.5,<2.0
4
 
5
  # Core
6
  torch