Spaces:
Running
Running
“Namhyun-Kim”
commited on
Commit
·
2b1a1e3
1
Parent(s):
aebafe2
Sync app to fetch data from wi-lab/lwm-spectro
Browse files- README.md +24 -0
- app.py +464 -196
- pretraining/README.md +0 -44
- pretraining/__init__.py +0 -0
- pretraining/pretrained_model.py +0 -180
- pretraining/train_lwm_spectro.py +0 -741
- pretraining/train_lwm_spectro_contrastive.py +0 -1450
- pretraining/train_lwm_spectro_no_contrast.py +0 -1136
- requirements.txt +2 -2
README.md
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: LWM-Spectro Lab
|
| 3 |
+
emoji: 🔍
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: "6.0.1"
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# LWM-Spectro Lab
|
| 13 |
+
|
| 14 |
+
One-stop lab for exploring spectrograms, LWM embeddings, and lightweight evaluation baselines.
|
| 15 |
+
|
| 16 |
+
## Features
|
| 17 |
+
- Visualize LWM embeddings or raw spectrograms with customizable filters.
|
| 18 |
+
- Inspect joint SNR/Doppler performance using cached MoE embeddings and an adaptive k-NN classifier.
|
| 19 |
+
- Upload your own datasets to compare raw channels vs. model embeddings.
|
| 20 |
+
|
| 21 |
+
## Usage
|
| 22 |
+
1. Select the **Spectrograms** and **t-SNE Analysis** tabs to explore embeddings.
|
| 23 |
+
2. Switch to **Modulation Classification** or **Joint SNR/Doppler Evaluation** to run the k-NN prototype with adjustable train/test splits.
|
| 24 |
+
3. Provide custom data (optional) to benchmark against bundled samples.
|
app.py
CHANGED
|
@@ -1,52 +1,55 @@
|
|
| 1 |
-
|
| 2 |
-
import
|
| 3 |
-
import sys
|
| 4 |
from pathlib import Path
|
| 5 |
-
from typing import Dict, List,
|
| 6 |
|
| 7 |
import gradio as gr
|
|
|
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
| 10 |
import plotly.express as px
|
| 11 |
import plotly.graph_objects as go
|
| 12 |
import torch
|
|
|
|
| 13 |
from sklearn.decomposition import PCA
|
| 14 |
from sklearn.manifold import TSNE
|
| 15 |
-
from sklearn.metrics import accuracy_score,
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 18 |
APP_DIR = Path(__file__).resolve().parent
|
| 19 |
DEMO_DATA_PATH = APP_DIR / "demo_data.pt"
|
| 20 |
MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt"
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
label_names = [f"{snr} | {mob}" for snr, mob in
|
| 45 |
-
pair_to_name = {pair: name for pair, name in zip(
|
| 46 |
name_to_id = {name: idx for idx, name in enumerate(label_names)}
|
| 47 |
-
pair_to_id = {pair: idx for idx, pair in enumerate(
|
| 48 |
return {
|
| 49 |
-
"pairs":
|
| 50 |
"label_names": label_names,
|
| 51 |
"pair_to_name": pair_to_name,
|
| 52 |
"name_to_id": name_to_id,
|
|
@@ -54,77 +57,43 @@ def load_joint_mapping() -> Optional[Dict[str, object]]:
|
|
| 54 |
}
|
| 55 |
|
| 56 |
|
| 57 |
-
def
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
batch_size: int = 64,
|
| 61 |
-
) -> torch.Tensor:
|
| 62 |
-
router = predictor.router
|
| 63 |
-
experts = predictor.experts
|
| 64 |
-
device = predictor.device
|
| 65 |
-
embeddings: List[torch.Tensor] = []
|
| 66 |
-
|
| 67 |
-
with torch.no_grad():
|
| 68 |
-
for start in range(0, len(samples), batch_size):
|
| 69 |
-
batch = samples[start : start + batch_size]
|
| 70 |
-
specs = torch.cat([sample["data"] for sample in batch], dim=0).to(device)
|
| 71 |
-
specs_norm = normalize_per_sample_tensor(specs)
|
| 72 |
-
|
| 73 |
-
if router is not None:
|
| 74 |
-
router_logits = router(specs_norm)
|
| 75 |
-
probs = torch.softmax(router_logits, dim=1)
|
| 76 |
-
topk_vals, topk_idx = probs.topk(k=predictor.topk, dim=1)
|
| 77 |
-
weights = topk_vals / torch.clamp(topk_vals.sum(dim=1, keepdim=True), min=1e-6)
|
| 78 |
-
selected_embeddings = compute_selected_expert_embeddings(
|
| 79 |
-
experts,
|
| 80 |
-
specs_norm,
|
| 81 |
-
topk_idx,
|
| 82 |
-
allow_grad=False,
|
| 83 |
-
)
|
| 84 |
-
weighted = (weights.unsqueeze(-1) * selected_embeddings).sum(dim=1)
|
| 85 |
-
else:
|
| 86 |
-
stacked = stack_expert_embeddings(experts, specs_norm)
|
| 87 |
-
weighted = stacked.mean(dim=1)
|
| 88 |
|
| 89 |
-
embeddings.append(weighted.cpu())
|
| 90 |
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
-
def
|
|
|
|
|
|
|
| 95 |
if MOE_DATA_PATH.exists():
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
print(f"[INFO] Loaded cached MoE embeddings from {MOE_DATA_PATH}")
|
| 99 |
-
return cached, True
|
| 100 |
-
print("[WARN] Cached MoE embeddings length mismatch. Recomputing...")
|
| 101 |
-
|
| 102 |
-
if not MOE_CHECKPOINT.exists():
|
| 103 |
-
print(f"[WARN] MoE checkpoint not found at {MOE_CHECKPOINT}. Skipping MoE embeddings.")
|
| 104 |
-
return samples, False
|
| 105 |
-
|
| 106 |
-
print("[INFO] Computing MoE embeddings using router checkpoint...")
|
| 107 |
-
predictor = MoEPredictor.from_checkpoint(MOE_CHECKPOINT)
|
| 108 |
-
moe_embeddings = compute_moe_embeddings(samples, predictor)
|
| 109 |
-
for sample, emb in zip(samples, moe_embeddings):
|
| 110 |
-
sample["moe_embedding"] = emb.detach().cpu()
|
| 111 |
-
|
| 112 |
-
torch.save(samples, MOE_DATA_PATH)
|
| 113 |
-
print(f"[INFO] Saved MoE-augmented dataset to {MOE_DATA_PATH}")
|
| 114 |
-
return samples, True
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def load_data(mapping: Optional[Dict[str, object]]):
|
| 118 |
if not DEMO_DATA_PATH.exists():
|
| 119 |
raise FileNotFoundError(f"Dataset not found at {DEMO_DATA_PATH}")
|
|
|
|
|
|
|
| 120 |
|
| 121 |
-
print(f"[INFO] Loading base dataset from {DEMO_DATA_PATH}")
|
| 122 |
-
data: List[Dict[str, object]] = torch.load(DEMO_DATA_PATH)
|
| 123 |
-
data, has_moe = ensure_moe_embeddings(data)
|
| 124 |
-
|
| 125 |
-
pair_to_name = mapping["pair_to_name"] if mapping else {}
|
| 126 |
-
pair_to_id = mapping["pair_to_id"] if mapping else {}
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
records = []
|
| 129 |
for i, sample in enumerate(data):
|
| 130 |
embedding = sample["embedding"]
|
|
@@ -149,9 +118,14 @@ def load_data(mapping: Optional[Dict[str, object]]):
|
|
| 149 |
joint_label = pair_to_name.get(pair)
|
| 150 |
joint_label_id = pair_to_id.get(pair)
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
records.append(
|
| 153 |
{
|
| 154 |
-
|
| 155 |
"tech": sample["tech"],
|
| 156 |
"snr": sample["snr"],
|
| 157 |
"mod": sample["mod"],
|
|
@@ -161,11 +135,15 @@ def load_data(mapping: Optional[Dict[str, object]]):
|
|
| 161 |
"spectrogram": flat_spec,
|
| 162 |
"joint_label": joint_label,
|
| 163 |
"joint_label_id": joint_label_id,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
}
|
| 165 |
)
|
| 166 |
|
| 167 |
df = pd.DataFrame(records)
|
| 168 |
-
print(f"[INFO] Loaded {len(df)} samples
|
| 169 |
return df, has_moe
|
| 170 |
|
| 171 |
|
|
@@ -188,50 +166,94 @@ def apply_filters(
|
|
| 188 |
return filtered
|
| 189 |
|
| 190 |
|
| 191 |
-
def plot_tsne(tech_filter, snr_filter, mod_filter, mob_filter, representation,
|
| 192 |
filtered_df = apply_filters(df, tech_filter, snr_filter, mod_filter, mob_filter)
|
| 193 |
if len(filtered_df) < 5:
|
| 194 |
-
return None
|
|
|
|
|
|
|
| 195 |
|
| 196 |
-
if representation == "LWM Embedding"
|
| 197 |
-
|
|
|
|
|
|
|
| 198 |
else:
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
init="pca",
|
| 211 |
-
learning_rate="auto",
|
| 212 |
-
)
|
| 213 |
-
projections = tsne.fit_transform(features)
|
| 214 |
-
filtered_df = filtered_df.copy()
|
| 215 |
-
filtered_df["x"] = projections[:, 0]
|
| 216 |
-
filtered_df["y"] = projections[:, 1]
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
fig = px.scatter(
|
| 219 |
filtered_df,
|
| 220 |
x="x",
|
| 221 |
y="y",
|
| 222 |
-
color=
|
| 223 |
hover_data=["tech", "snr", "mod", "mob"],
|
| 224 |
title=f"t-SNE of {representation} ({len(filtered_df)} samples)",
|
| 225 |
template="plotly_white",
|
| 226 |
)
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
|
| 231 |
def stratified_split(filtered_df: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
|
| 232 |
rng = np.random.default_rng(int(seed))
|
| 233 |
-
train_indices
|
| 234 |
-
test_indices
|
| 235 |
|
| 236 |
for label_id, group in filtered_df.groupby("joint_label_id"):
|
| 237 |
indices = group.index.to_numpy()
|
|
@@ -247,47 +269,22 @@ def stratified_split(filtered_df: pd.DataFrame, train_ratio: float, seed: int) -
|
|
| 247 |
return np.array(train_indices), np.array(test_indices)
|
| 248 |
|
| 249 |
|
| 250 |
-
def
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
centroid_ids: List[int] = []
|
| 262 |
-
for label_id in unique_labels:
|
| 263 |
-
mask = train_labels == label_id
|
| 264 |
-
centroids.append(train_embeddings[mask].mean(axis=0))
|
| 265 |
-
centroid_ids.append(int(label_id))
|
| 266 |
-
|
| 267 |
-
centroids = np.stack(centroids)
|
| 268 |
-
centroid_ids = np.array(centroid_ids, dtype=int)
|
| 269 |
-
|
| 270 |
-
dists = ((test_embeddings[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
|
| 271 |
-
preds = centroid_ids[np.argmin(dists, axis=1)]
|
| 272 |
-
|
| 273 |
-
accuracy = accuracy_score(test_labels, preds)
|
| 274 |
-
macro_f1 = f1_score(test_labels, preds, average="macro", labels=centroid_ids, zero_division=0)
|
| 275 |
-
|
| 276 |
-
active_ids = sorted(np.unique(np.concatenate([test_labels, preds])))
|
| 277 |
-
label_names = [CLASS_LABELS[i] for i in active_ids]
|
| 278 |
-
cm = confusion_matrix(test_labels, preds, labels=active_ids)
|
| 279 |
|
| 280 |
-
return {
|
| 281 |
-
"accuracy": accuracy,
|
| 282 |
-
"macro_f1": macro_f1,
|
| 283 |
-
"confusion": cm,
|
| 284 |
-
"label_names": label_names,
|
| 285 |
-
"train_size": len(train_idx),
|
| 286 |
-
"test_size": len(test_idx),
|
| 287 |
-
}
|
| 288 |
|
| 289 |
-
|
| 290 |
-
|
|
|
|
| 291 |
fig = go.Figure(
|
| 292 |
data=go.Heatmap(
|
| 293 |
z=confusion,
|
|
@@ -298,7 +295,7 @@ def plot_confusion_heatmap(confusion: np.ndarray, label_names: List[str]) -> go.
|
|
| 298 |
)
|
| 299 |
)
|
| 300 |
fig.update_layout(
|
| 301 |
-
title=
|
| 302 |
xaxis_title="Predicted",
|
| 303 |
yaxis_title="True",
|
| 304 |
xaxis=dict(tickangle=45),
|
|
@@ -307,70 +304,312 @@ def plot_confusion_heatmap(confusion: np.ndarray, label_names: List[str]) -> go.
|
|
| 307 |
|
| 308 |
|
| 309 |
def run_joint_evaluation(train_pct, seed, tech_filter, snr_filter, mod_filter, mob_filter):
|
| 310 |
-
if
|
| 311 |
fig = go.Figure()
|
| 312 |
fig.update_layout(title="MoE embeddings unavailable", xaxis=dict(visible=False), yaxis=dict(visible=False))
|
| 313 |
-
return fig, "MoE embeddings are not available
|
| 314 |
|
| 315 |
filtered = apply_filters(joint_eval_df, tech_filter, snr_filter, mod_filter, mob_filter)
|
| 316 |
if filtered.empty:
|
| 317 |
fig = go.Figure()
|
| 318 |
fig.update_layout(title="No samples after filtering", xaxis=dict(visible=False), yaxis=dict(visible=False))
|
| 319 |
-
return fig, "No samples match the selected filters."
|
| 320 |
|
| 321 |
if filtered["joint_label_id"].nunique() < 2:
|
| 322 |
fig = go.Figure()
|
| 323 |
fig.update_layout(title="Need at least two classes", xaxis=dict(visible=False), yaxis=dict(visible=False))
|
| 324 |
-
return fig, "Need at least two joint SNR/Doppler classes to evaluate."
|
|
|
|
|
|
|
| 325 |
|
| 326 |
try:
|
| 327 |
train_idx, test_idx = stratified_split(filtered, train_pct / 100.0, seed)
|
| 328 |
except ValueError as exc:
|
| 329 |
fig = go.Figure()
|
| 330 |
fig.update_layout(title="Unable to split dataset", xaxis=dict(visible=False), yaxis=dict(visible=False))
|
| 331 |
-
return fig, str(exc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
|
| 333 |
-
metrics = compute_centroid_metrics(filtered, train_idx, test_idx)
|
| 334 |
-
fig = plot_confusion_heatmap(metrics["confusion"], metrics["label_names"])
|
| 335 |
status = (
|
| 336 |
-
f"
|
| 337 |
-
f"Test
|
| 338 |
-
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
)
|
| 341 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
|
| 343 |
|
| 344 |
mapping_info = load_joint_mapping()
|
| 345 |
df, has_moe_embeddings = load_data(mapping_info)
|
| 346 |
-
CLASS_LABELS
|
| 347 |
|
| 348 |
-
|
| 349 |
-
joint_eval_df =
|
| 350 |
-
joint_eval_df = joint_eval_df[joint_eval_df["moe_embedding"].notna()]
|
| 351 |
|
| 352 |
tech_choices = sorted(df["tech"].unique())
|
| 353 |
snr_choices = sorted(df["snr"].unique())
|
| 354 |
mod_choices = sorted(df["mod"].unique())
|
| 355 |
mob_choices = sorted(df["mob"].unique())
|
| 356 |
|
| 357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
|
| 359 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
|
| 361 |
gr.Markdown(
|
| 362 |
"""
|
| 363 |
-
Compare **LWM embeddings** vs **Raw Spectrograms** for visualization, then evaluate **MoE embeddings**
|
| 364 |
-
with a lightweight prototype classifier for joint SNR/Doppler recognition.
|
| 365 |
"""
|
| 366 |
)
|
| 367 |
|
| 368 |
with gr.Tabs():
|
| 369 |
-
with gr.Tab("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
with gr.Row():
|
| 371 |
with gr.Column(scale=1, min_width=300):
|
| 372 |
gr.Markdown("### Filters")
|
| 373 |
-
tech_filter = gr.CheckboxGroup(choices=tech_choices, value=
|
| 374 |
snr_filter = gr.Dropdown(
|
| 375 |
choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)"
|
| 376 |
)
|
|
@@ -387,14 +626,17 @@ with gr.Blocks(title="LWM-Spectro Demo") as demo:
|
|
| 387 |
value="LWM Embedding",
|
| 388 |
label="Representation",
|
| 389 |
)
|
| 390 |
-
color_by = gr.Dropdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
|
| 392 |
with gr.Accordion("Advanced t-SNE Settings", open=False):
|
| 393 |
perplexity = gr.Slider(minimum=5, maximum=50, value=30, step=1, label="Perplexity")
|
| 394 |
n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations")
|
| 395 |
|
| 396 |
btn = gr.Button("Update Plot", variant="primary")
|
| 397 |
-
status = gr.Textbox(label="Status", interactive=False)
|
| 398 |
|
| 399 |
with gr.Column(scale=3):
|
| 400 |
plot = gr.Plot(label="t-SNE Visualization")
|
|
@@ -402,19 +644,43 @@ with gr.Blocks(title="LWM-Spectro Demo") as demo:
|
|
| 402 |
btn.click(
|
| 403 |
plot_tsne,
|
| 404 |
inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
|
| 405 |
-
outputs=[plot
|
| 406 |
)
|
| 407 |
|
| 408 |
demo.load(
|
| 409 |
plot_tsne,
|
| 410 |
inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
|
| 411 |
-
outputs=[plot
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
)
|
| 413 |
|
| 414 |
-
with gr.Tab("
|
| 415 |
if evaluation_disabled:
|
| 416 |
gr.Markdown(
|
| 417 |
-
"⚠️ MoE embeddings are
|
| 418 |
)
|
| 419 |
|
| 420 |
with gr.Row():
|
|
@@ -422,7 +688,7 @@ with gr.Blocks(title="LWM-Spectro Demo") as demo:
|
|
| 422 |
gr.Markdown("### Evaluation Filters")
|
| 423 |
eval_tech_filter = gr.CheckboxGroup(
|
| 424 |
choices=tech_choices,
|
| 425 |
-
value=
|
| 426 |
label="Technology",
|
| 427 |
interactive=not evaluation_disabled,
|
| 428 |
)
|
|
@@ -468,13 +734,15 @@ with gr.Blocks(title="LWM-Spectro Demo") as demo:
|
|
| 468 |
eval_btn = gr.Button("Run evaluation", variant="primary", interactive=not evaluation_disabled)
|
| 469 |
|
| 470 |
with gr.Column(scale=3):
|
| 471 |
-
|
| 472 |
-
|
|
|
|
|
|
|
| 473 |
|
| 474 |
eval_btn.click(
|
| 475 |
run_joint_evaluation,
|
| 476 |
inputs=[train_pct, seed, eval_tech_filter, eval_snr_filter, eval_mod_filter, eval_mob_filter],
|
| 477 |
-
outputs=[eval_plot, eval_status],
|
| 478 |
)
|
| 479 |
|
| 480 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
+
from typing import Dict, List, Tuple, Optional
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
| 9 |
import numpy as np
|
| 10 |
import pandas as pd
|
| 11 |
import plotly.express as px
|
| 12 |
import plotly.graph_objects as go
|
| 13 |
import torch
|
| 14 |
+
from huggingface_hub import hf_hub_download
|
| 15 |
from sklearn.decomposition import PCA
|
| 16 |
from sklearn.manifold import TSNE
|
| 17 |
+
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
|
| 18 |
+
from sklearn.neighbors import KNeighborsClassifier
|
| 19 |
+
from sklearn.preprocessing import StandardScaler
|
| 20 |
|
|
|
|
| 21 |
APP_DIR = Path(__file__).resolve().parent
|
| 22 |
DEMO_DATA_PATH = APP_DIR / "demo_data.pt"
|
| 23 |
MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt"
|
| 24 |
+
HUB_REPO_ID = "wi-lab/lwm-spectro"
|
| 25 |
+
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_HUB_TOKEN")
|
| 26 |
+
|
| 27 |
+
# Fixed ordering for the 14 joint SNR/Doppler labels
|
| 28 |
+
JOINT_LABELS = [
|
| 29 |
+
("SNR-5dB", "pedestrian"),
|
| 30 |
+
("SNR-5dB", "vehicular"),
|
| 31 |
+
("SNR0dB", "pedestrian"),
|
| 32 |
+
("SNR0dB", "vehicular"),
|
| 33 |
+
("SNR5dB", "pedestrian"),
|
| 34 |
+
("SNR5dB", "vehicular"),
|
| 35 |
+
("SNR10dB", "pedestrian"),
|
| 36 |
+
("SNR10dB", "vehicular"),
|
| 37 |
+
("SNR15dB", "pedestrian"),
|
| 38 |
+
("SNR15dB", "vehicular"),
|
| 39 |
+
("SNR20dB", "pedestrian"),
|
| 40 |
+
("SNR20dB", "vehicular"),
|
| 41 |
+
("SNR25dB", "pedestrian"),
|
| 42 |
+
("SNR25dB", "vehicular"),
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_joint_mapping() -> Dict[str, object]:
|
| 47 |
+
label_names = [f"{snr} | {mob}" for snr, mob in JOINT_LABELS]
|
| 48 |
+
pair_to_name = {pair: name for pair, name in zip(JOINT_LABELS, label_names)}
|
| 49 |
name_to_id = {name: idx for idx, name in enumerate(label_names)}
|
| 50 |
+
pair_to_id = {pair: idx for idx, pair in enumerate(JOINT_LABELS)}
|
| 51 |
return {
|
| 52 |
+
"pairs": JOINT_LABELS,
|
| 53 |
"label_names": label_names,
|
| 54 |
"pair_to_name": pair_to_name,
|
| 55 |
"name_to_id": name_to_id,
|
|
|
|
| 57 |
}
|
| 58 |
|
| 59 |
|
| 60 |
+
def _safe_load_tensor(path: Path):
|
| 61 |
+
# Torch 2.6 defaults to weights_only=True, which breaks our saved dicts.
|
| 62 |
+
return torch.load(path, weights_only=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
|
|
|
| 64 |
|
| 65 |
+
def _ensure_local_file(local_path: Path, hub_filename: str) -> Optional[Path]:
|
| 66 |
+
"""Ensure a file exists locally; try Hub download if missing."""
|
| 67 |
+
if local_path.exists():
|
| 68 |
+
return local_path
|
| 69 |
+
try:
|
| 70 |
+
cached = hf_hub_download(repo_id=HUB_REPO_ID, filename=hub_filename, token=HF_TOKEN)
|
| 71 |
+
cached_path = Path(cached)
|
| 72 |
+
shutil.copyfile(cached_path, local_path)
|
| 73 |
+
print(f"[INFO] Downloaded {hub_filename} from Hub to {local_path}")
|
| 74 |
+
return local_path
|
| 75 |
+
except Exception as exc:
|
| 76 |
+
print(f"[WARN] Could not download {hub_filename} from Hub ({exc}); continuing without it.")
|
| 77 |
+
return None
|
| 78 |
|
| 79 |
|
| 80 |
+
def load_augmented_samples() -> Tuple[List[Dict[str, object]], bool]:
|
| 81 |
+
_ensure_local_file(MOE_DATA_PATH, "demo_data_moe.pt")
|
| 82 |
+
_ensure_local_file(DEMO_DATA_PATH, "demo_data.pt")
|
| 83 |
if MOE_DATA_PATH.exists():
|
| 84 |
+
print(f"[INFO] Loading MoE-augmented dataset from {MOE_DATA_PATH}")
|
| 85 |
+
return _safe_load_tensor(MOE_DATA_PATH), True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
if not DEMO_DATA_PATH.exists():
|
| 87 |
raise FileNotFoundError(f"Dataset not found at {DEMO_DATA_PATH}")
|
| 88 |
+
print(f"[WARN] {MOE_DATA_PATH} missing; falling back to base data only")
|
| 89 |
+
return _safe_load_tensor(DEMO_DATA_PATH), False
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
+
def load_data(mapping: Dict[str, object]):
|
| 93 |
+
data, has_moe = load_augmented_samples()
|
| 94 |
+
pair_to_name = mapping["pair_to_name"]
|
| 95 |
+
pair_to_id = mapping["pair_to_id"]
|
| 96 |
+
|
| 97 |
records = []
|
| 98 |
for i, sample in enumerate(data):
|
| 99 |
embedding = sample["embedding"]
|
|
|
|
| 118 |
joint_label = pair_to_name.get(pair)
|
| 119 |
joint_label_id = pair_to_id.get(pair)
|
| 120 |
|
| 121 |
+
tsne_x = sample.get("tsne_x")
|
| 122 |
+
tsne_y = sample.get("tsne_y")
|
| 123 |
+
tsne_raw_x = sample.get("tsne_raw_x")
|
| 124 |
+
tsne_raw_y = sample.get("tsne_raw_y")
|
| 125 |
+
|
| 126 |
records.append(
|
| 127 |
{
|
| 128 |
+
"index": i,
|
| 129 |
"tech": sample["tech"],
|
| 130 |
"snr": sample["snr"],
|
| 131 |
"mod": sample["mod"],
|
|
|
|
| 135 |
"spectrogram": flat_spec,
|
| 136 |
"joint_label": joint_label,
|
| 137 |
"joint_label_id": joint_label_id,
|
| 138 |
+
"tsne_x": tsne_x,
|
| 139 |
+
"tsne_y": tsne_y,
|
| 140 |
+
"tsne_raw_x": tsne_raw_x,
|
| 141 |
+
"tsne_raw_y": tsne_raw_y,
|
| 142 |
}
|
| 143 |
)
|
| 144 |
|
| 145 |
df = pd.DataFrame(records)
|
| 146 |
+
print(f"[INFO] Loaded {len(df)} samples (MoE embeddings: {has_moe})")
|
| 147 |
return df, has_moe
|
| 148 |
|
| 149 |
|
|
|
|
| 166 |
return filtered
|
| 167 |
|
| 168 |
|
| 169 |
+
def plot_tsne(tech_filter, snr_filter, mod_filter, mob_filter, representation, color_label, perplexity, n_iter):
|
| 170 |
filtered_df = apply_filters(df, tech_filter, snr_filter, mod_filter, mob_filter)
|
| 171 |
if len(filtered_df) < 5:
|
| 172 |
+
return None
|
| 173 |
+
|
| 174 |
+
color_column = COLOR_OPTIONS.get(color_label, "snr")
|
| 175 |
|
| 176 |
+
tsne_cols = ("tsne_x", "tsne_y") if representation == "LWM Embedding" else ("tsne_raw_x", "tsne_raw_y")
|
| 177 |
+
has_cached = all(col in filtered_df.columns for col in tsne_cols)
|
| 178 |
+
if has_cached:
|
| 179 |
+
valid = filtered_df[tsne_cols[0]].notna().all() and filtered_df[tsne_cols[1]].notna().all()
|
| 180 |
else:
|
| 181 |
+
valid = False
|
| 182 |
+
|
| 183 |
+
if valid:
|
| 184 |
+
filtered_df = filtered_df.copy()
|
| 185 |
+
filtered_df["x"] = filtered_df[tsne_cols[0]]
|
| 186 |
+
filtered_df["y"] = filtered_df[tsne_cols[1]]
|
| 187 |
+
else:
|
| 188 |
+
sampled_df = filtered_df
|
| 189 |
+
if len(sampled_df) > 1200:
|
| 190 |
+
sampled_df = sampled_df.sample(n=1200, random_state=42)
|
| 191 |
+
sampled_df = sampled_df.copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
+
if representation == "LWM Embedding":
|
| 194 |
+
features = np.stack(sampled_df["embedding"].values)
|
| 195 |
+
else:
|
| 196 |
+
features = np.stack(sampled_df["spectrogram"].values)
|
| 197 |
+
if features.shape[1] > 50:
|
| 198 |
+
pca = PCA(n_components=50, random_state=42)
|
| 199 |
+
features = pca.fit_transform(features)
|
| 200 |
+
|
| 201 |
+
eff_perplexity = min(perplexity, len(sampled_df) - 1)
|
| 202 |
+
eff_perplexity = max(5, eff_perplexity)
|
| 203 |
+
tsne = TSNE(
|
| 204 |
+
n_components=2,
|
| 205 |
+
perplexity=eff_perplexity,
|
| 206 |
+
n_iter=n_iter,
|
| 207 |
+
random_state=42,
|
| 208 |
+
init="pca",
|
| 209 |
+
learning_rate="auto",
|
| 210 |
+
)
|
| 211 |
+
try:
|
| 212 |
+
projections = tsne.fit_transform(features)
|
| 213 |
+
except Exception as exc:
|
| 214 |
+
pca = PCA(n_components=2, random_state=42)
|
| 215 |
+
projections = pca.fit_transform(features)
|
| 216 |
+
sampled_df["x"] = projections[:, 0]
|
| 217 |
+
sampled_df["y"] = projections[:, 1]
|
| 218 |
+
filtered_df = sampled_df
|
| 219 |
fig = px.scatter(
|
| 220 |
filtered_df,
|
| 221 |
x="x",
|
| 222 |
y="y",
|
| 223 |
+
color=color_column,
|
| 224 |
hover_data=["tech", "snr", "mod", "mob"],
|
| 225 |
title=f"t-SNE of {representation} ({len(filtered_df)} samples)",
|
| 226 |
template="plotly_white",
|
| 227 |
)
|
| 228 |
+
height = 680 if color_label == "SNR" else 640
|
| 229 |
+
fig.update_layout(
|
| 230 |
+
legend_title_text=color_label,
|
| 231 |
+
width=640,
|
| 232 |
+
height=height,
|
| 233 |
+
)
|
| 234 |
+
fig.update_yaxes(scaleanchor="x", scaleratio=1)
|
| 235 |
+
return fig
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def build_raw_feature_matrix(samples: pd.Series, max_components: int = 256) -> np.ndarray:
|
| 239 |
+
raw_flat = []
|
| 240 |
+
for spec in samples:
|
| 241 |
+
arr = np.asarray(spec, dtype=np.float32)
|
| 242 |
+
raw_flat.append(arr.reshape(-1))
|
| 243 |
+
matrix = np.stack(raw_flat)
|
| 244 |
+
matrix = np.nan_to_num(matrix, copy=False)
|
| 245 |
+
scaler = StandardScaler()
|
| 246 |
+
matrix = scaler.fit_transform(matrix)
|
| 247 |
+
if max_components and matrix.shape[1] > max_components:
|
| 248 |
+
projector = PCA(n_components=max_components, random_state=42)
|
| 249 |
+
matrix = projector.fit_transform(matrix)
|
| 250 |
+
return matrix
|
| 251 |
|
| 252 |
|
| 253 |
def stratified_split(filtered_df: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
|
| 254 |
rng = np.random.default_rng(int(seed))
|
| 255 |
+
train_indices = []
|
| 256 |
+
test_indices = []
|
| 257 |
|
| 258 |
for label_id, group in filtered_df.groupby("joint_label_id"):
|
| 259 |
indices = group.index.to_numpy()
|
|
|
|
| 269 |
return np.array(train_indices), np.array(test_indices)
|
| 270 |
|
| 271 |
|
| 272 |
+
def select_knn_k(train_labels: np.ndarray, max_k: int = 9) -> int:
|
| 273 |
+
if train_labels.size == 0:
|
| 274 |
+
return 1
|
| 275 |
+
class_counts = pd.Series(train_labels).value_counts()
|
| 276 |
+
min_class = int(class_counts.min())
|
| 277 |
+
heuristic = int(np.sqrt(train_labels.size))
|
| 278 |
+
candidate = max(1, min(max_k, heuristic))
|
| 279 |
+
k = max(1, min(candidate, min_class))
|
| 280 |
+
if k % 2 == 0 and k > 1:
|
| 281 |
+
k -= 1
|
| 282 |
+
return k
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
+
def plot_confusion_heatmap(
|
| 286 |
+
confusion: np.ndarray, label_names: List[str], title: str = "Prototype Classifier Confusion Matrix"
|
| 287 |
+
) -> go.Figure:
|
| 288 |
fig = go.Figure(
|
| 289 |
data=go.Heatmap(
|
| 290 |
z=confusion,
|
|
|
|
| 295 |
)
|
| 296 |
)
|
| 297 |
fig.update_layout(
|
| 298 |
+
title=title,
|
| 299 |
xaxis_title="Predicted",
|
| 300 |
yaxis_title="True",
|
| 301 |
xaxis=dict(tickangle=45),
|
|
|
|
| 304 |
|
| 305 |
|
| 306 |
def run_joint_evaluation(train_pct, seed, tech_filter, snr_filter, mod_filter, mob_filter):
|
| 307 |
+
if evaluation_disabled:
|
| 308 |
fig = go.Figure()
|
| 309 |
fig.update_layout(title="MoE embeddings unavailable", xaxis=dict(visible=False), yaxis=dict(visible=False))
|
| 310 |
+
return fig, fig, "MoE embeddings are not available in this Space build."
|
| 311 |
|
| 312 |
filtered = apply_filters(joint_eval_df, tech_filter, snr_filter, mod_filter, mob_filter)
|
| 313 |
if filtered.empty:
|
| 314 |
fig = go.Figure()
|
| 315 |
fig.update_layout(title="No samples after filtering", xaxis=dict(visible=False), yaxis=dict(visible=False))
|
| 316 |
+
return fig, fig, "No samples match the selected filters."
|
| 317 |
|
| 318 |
if filtered["joint_label_id"].nunique() < 2:
|
| 319 |
fig = go.Figure()
|
| 320 |
fig.update_layout(title="Need at least two classes", xaxis=dict(visible=False), yaxis=dict(visible=False))
|
| 321 |
+
return fig, fig, "Need at least two joint SNR/Doppler classes to evaluate."
|
| 322 |
+
|
| 323 |
+
filtered = filtered.reset_index(drop=True)
|
| 324 |
|
| 325 |
try:
|
| 326 |
train_idx, test_idx = stratified_split(filtered, train_pct / 100.0, seed)
|
| 327 |
except ValueError as exc:
|
| 328 |
fig = go.Figure()
|
| 329 |
fig.update_layout(title="Unable to split dataset", xaxis=dict(visible=False), yaxis=dict(visible=False))
|
| 330 |
+
return fig, fig, str(exc)
|
| 331 |
+
|
| 332 |
+
labels = filtered["joint_label_id"].to_numpy(dtype=int)
|
| 333 |
+
moe_features = np.stack(filtered["moe_embedding"].values)
|
| 334 |
+
raw_features = build_raw_feature_matrix(filtered["spectrogram"], max_components=256)
|
| 335 |
+
|
| 336 |
+
train_labels = labels[train_idx]
|
| 337 |
+
knn_k = select_knn_k(train_labels)
|
| 338 |
+
|
| 339 |
+
moe_metrics = compute_knn_metrics(moe_features, labels, train_idx, test_idx, knn_k, label_lookup=CLASS_LABELS)
|
| 340 |
+
raw_metrics = compute_knn_metrics(raw_features, labels, train_idx, test_idx, knn_k, label_lookup=CLASS_LABELS)
|
| 341 |
+
|
| 342 |
+
moe_fig = plot_confusion_heatmap(
|
| 343 |
+
moe_metrics["confusion"], moe_metrics["label_names"], title=f"MoE Embedding Confusion (k={moe_metrics['k']})"
|
| 344 |
+
)
|
| 345 |
+
raw_fig = plot_confusion_heatmap(
|
| 346 |
+
raw_metrics["confusion"], raw_metrics["label_names"], title=f"Raw Spectrogram Confusion (k={raw_metrics['k']})"
|
| 347 |
+
)
|
| 348 |
|
|
|
|
|
|
|
| 349 |
status = (
|
| 350 |
+
f"### Joint SNR/Doppler Metrics\n"
|
| 351 |
+
f"**Train/Test Samples:** {len(train_idx)} / {len(test_idx)} | **Train %:** {train_pct}% | **Seed:** {seed} | **k-NN k:** {knn_k}\n\n"
|
| 352 |
+
"| Representation | Accuracy | Macro F1 |\n"
|
| 353 |
+
"| --- | --- | --- |\n"
|
| 354 |
+
f"| **MoE Embedding** | {moe_metrics['accuracy'] * 100:.2f}% | {moe_metrics['macro_f1']:.3f} |\n"
|
| 355 |
+
f"| **Raw Spectrogram** | {raw_metrics['accuracy'] * 100:.2f}% | {raw_metrics['macro_f1']:.3f} |"
|
| 356 |
+
)
|
| 357 |
+
return moe_fig, raw_fig, status
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def stratified_split_mod(df_subset: pd.DataFrame, train_ratio: float, seed: int) -> Tuple[np.ndarray, np.ndarray]:
|
| 361 |
+
rng = np.random.default_rng(int(seed))
|
| 362 |
+
train_idx = []
|
| 363 |
+
test_idx = []
|
| 364 |
+
for _, group in df_subset.groupby("mod"):
|
| 365 |
+
indices = group.index.to_numpy()
|
| 366 |
+
if indices.size < 2:
|
| 367 |
+
raise ValueError("Each modulation needs at least 2 samples.")
|
| 368 |
+
rng.shuffle(indices)
|
| 369 |
+
split = int(round(len(indices) * train_ratio))
|
| 370 |
+
split = max(1, min(len(indices) - 1, split))
|
| 371 |
+
train_idx.extend(indices[:split])
|
| 372 |
+
test_idx.extend(indices[split:])
|
| 373 |
+
return np.array(train_idx), np.array(test_idx)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def compute_knn_metrics(
    features: np.ndarray,
    labels: np.ndarray,
    train_idx: np.ndarray,
    test_idx: np.ndarray,
    knn_k: int,
    label_lookup: List[str] | None = None,
) -> Dict[str, object]:
    """Fit a k-NN classifier on the train split and score the test split.

    The requested ``knn_k`` is clamped to [1, n_train] and, when > 1,
    forced odd to reduce vote ties.  ``label_lookup`` optionally maps
    integer labels to display names for the confusion-matrix axes.

    Returns:
        dict with keys "accuracy", "macro_f1", "confusion", "label_names",
        and "k" (the k actually used).
    """
    X_train, y_train = features[train_idx], labels[train_idx]
    X_test, y_test = features[test_idx], labels[test_idx]

    k = max(1, min(int(knn_k), len(y_train)))
    if k > 1 and k % 2 == 0:
        k -= 1

    clf = KNeighborsClassifier(n_neighbors=k, metric="euclidean")
    clf.fit(X_train, y_train)
    predictions = clf.predict(X_test)

    accuracy = accuracy_score(y_test, predictions)
    # Restrict F1/confusion to labels that actually occur in this split.
    active = np.unique(np.concatenate([y_train, y_test, predictions]))
    macro_f1 = f1_score(y_test, predictions, labels=active, average="macro", zero_division=0)

    if label_lookup is None:
        names = [str(lbl) for lbl in active]
    else:
        names = [label_lookup[int(lbl)] for lbl in active]

    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "confusion": confusion_matrix(y_test, predictions, labels=active),
        "label_names": names,
        "k": k,
    }
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def evaluate_modulation(tech: str, train_pct: int, seed: int):
    """Compare LWM embeddings vs raw spectrograms for one technology's modulation task.

    Returns:
        (embedding confusion figure, raw confusion figure, markdown summary).
    """

    def _placeholder(title: str) -> go.Figure:
        # Blank figure reused by every early-exit path.
        fig = go.Figure()
        fig.update_layout(title=title, xaxis=dict(visible=False), yaxis=dict(visible=False))
        return fig

    if not tech:
        fig = _placeholder("Select a technology to evaluate.")
        return fig, fig, "No technology selected."

    subset = df[df["tech"] == tech].copy().reset_index(drop=True)
    if subset.empty or subset["mod"].nunique() < 2:
        fig = _placeholder("Need at least two modulation classes for this technology.")
        return fig, fig, "Not enough modulation classes."

    try:
        train_idx, test_idx = stratified_split_mod(subset, train_pct / 100.0, seed)
    except ValueError as exc:
        fig = _placeholder(str(exc))
        return fig, fig, str(exc)

    labels = subset["mod"].astype(str).to_numpy()
    emb_features = np.stack(subset["embedding"].values)
    raw_features = build_raw_feature_matrix(subset["spectrogram"], max_components=256)

    train_labels = labels[train_idx]
    if pd.Series(train_labels).value_counts().empty:
        fig = _placeholder("No modulation classes found.")
        return fig, fig, "No modulation classes found."

    knn_k = select_knn_k(train_labels)

    emb_metrics = compute_knn_metrics(emb_features, labels, train_idx, test_idx, knn_k)
    raw_metrics = compute_knn_metrics(raw_features, labels, train_idx, test_idx, knn_k)

    emb_fig = plot_confusion_heatmap(emb_metrics["confusion"], emb_metrics["label_names"], title="Embedding Confusion")
    raw_fig = plot_confusion_heatmap(raw_metrics["confusion"], raw_metrics["label_names"], title="Raw Confusion")

    summary = (
        f"### {tech} Modulation Metrics\n"
        f"**Train/Test Samples:** {len(train_idx)} / {len(test_idx)} | **Classifier:** k-NN (k = {emb_metrics['k']})\n\n"
        "| Representation | Accuracy | Macro F1 |\n"
        "| --- | --- | --- |\n"
        f"| **LWM Embedding** | {emb_metrics['accuracy'] * 100:.2f}% | {emb_metrics['macro_f1']:.3f} |\n"
        f"| **Raw Spectrogram** | {raw_metrics['accuracy'] * 100:.2f}% | {raw_metrics['macro_f1']:.3f} |"
    )
    return emb_fig, raw_fig, summary
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def _reshape_spectrogram(spec: np.ndarray) -> np.ndarray:
|
| 470 |
+
arr = np.asarray(spec)
|
| 471 |
+
if arr.ndim == 1:
|
| 472 |
+
side = int(round(arr.size ** 0.5))
|
| 473 |
+
if side * side == arr.size:
|
| 474 |
+
arr = arr.reshape(side, side)
|
| 475 |
+
else:
|
| 476 |
+
arr = arr.reshape(-1, side)
|
| 477 |
+
elif arr.ndim == 3:
|
| 478 |
+
arr = arr.squeeze()
|
| 479 |
+
return arr
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def _spectrogram_to_image(spec: np.ndarray, title: str) -> np.ndarray:
    """Render a 2-D spectrogram to an RGB image via matplotlib.

    The matrix is NaN-cleaned and min-max normalized (unless constant),
    drawn with the "turbo" colormap plus a colorbar, and returned as an
    (H, W, 3) uint8 array.
    """
    data = spec.astype(np.float32)
    if np.isnan(data).any():
        data = np.nan_to_num(data)
    lo, hi = data.min(), data.max()
    if hi - lo > 0:  # skip normalization for constant images (avoids /0)
        data = (data - lo) / (hi - lo)

    fig, ax = plt.subplots(figsize=(3, 3))
    mesh = ax.imshow(data, cmap="turbo", aspect="auto", origin="lower")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title, fontsize=8)
    colorbar = fig.colorbar(mesh, ax=ax, fraction=0.046, pad=0.04)
    colorbar.ax.tick_params(labelsize=6)
    fig.tight_layout(pad=0.5)

    # Rasterize the figure and pull the RGB pixels off the Agg canvas.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    width, height = canvas.get_width_height()
    rgba = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(height, width, 4)
    rgb = rgba[..., :3].copy()
    plt.close(fig)
    return rgb
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def render_spectrogram_gallery(tech, snr, mod, mob, sample_count, seed):
    """Sample spectrograms matching the filters and render gallery items.

    Returns:
        (list of (image, caption) tuples, status string).
    """
    # Each filter is either a single-element list or None (meaning "all").
    selections = [[value] if value else None for value in (tech, snr, mod, mob)]
    filtered = apply_filters(df, *selections)
    if filtered.empty:
        return [], "No spectrograms match the selected filters."

    wanted = max(1, int(sample_count))
    rng = np.random.default_rng(int(seed))
    if len(filtered) > wanted:
        chosen = rng.choice(filtered.index.to_numpy(), size=wanted, replace=False)
        subset = filtered.loc[chosen]
    else:
        subset = filtered

    items = []
    for _, row in subset.iterrows():
        caption = f"{row['tech']} | {row['mod']} | {row['snr']} | {row['mob']}"
        image = _spectrogram_to_image(_reshape_spectrogram(row["spectrogram"]), caption)
        items.append((image, caption))

    return items, f"Showing {len(subset)} spectrograms (seed={seed})."
|
| 533 |
|
| 534 |
|
| 535 |
mapping_info = load_joint_mapping()
|
| 536 |
df, has_moe_embeddings = load_data(mapping_info)
|
| 537 |
+
CLASS_LABELS = mapping_info["label_names"]
|
| 538 |
|
| 539 |
+
has_moe_column = df["moe_embedding"].apply(lambda x: x is not None)
|
| 540 |
+
joint_eval_df = df[has_moe_column & df["joint_label_id"].notna()]
|
|
|
|
| 541 |
|
| 542 |
tech_choices = sorted(df["tech"].unique())
|
| 543 |
snr_choices = sorted(df["snr"].unique())
|
| 544 |
mod_choices = sorted(df["mod"].unique())
|
| 545 |
mob_choices = sorted(df["mob"].unique())
|
| 546 |
|
| 547 |
+
TECH_TO_MODS: Dict[str, List[str]] = {
|
| 548 |
+
tech: sorted(df.loc[df["tech"] == tech, "mod"].unique().tolist()) for tech in tech_choices
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
COLOR_OPTIONS: Dict[str, str] = {
|
| 552 |
+
"SNR": "snr",
|
| 553 |
+
"Modulation": "mod",
|
| 554 |
+
"Mobility": "mob",
|
| 555 |
+
}
|
| 556 |
+
|
| 557 |
+
default_tech = tech_choices[:1] if tech_choices else []
|
| 558 |
+
initial_spec_mod_choices = TECH_TO_MODS.get(default_tech[0], mod_choices) if default_tech else mod_choices
|
| 559 |
|
| 560 |
+
evaluation_disabled = (not has_moe_embeddings) or joint_eval_df.empty
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def update_modulation_choices(selected_tech: Optional[str]):
    """Refresh the modulation dropdown to the mods available for the selected tech.

    Falls back to the full modulation list when no tech is selected or the
    tech is unknown.  Uses ``gr.update``: ``gr.Dropdown.update`` was removed
    in Gradio 4+, so the old call crashes under the declared SDK (6.0.1).
    """
    choices = mod_choices
    if selected_tech:
        choices = TECH_TO_MODS.get(selected_tech, mod_choices)
    return gr.update(choices=choices, value=None)
|
| 568 |
+
|
| 569 |
+
with gr.Blocks(title="LWM-Spectro Lab") as demo:
|
| 570 |
gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
|
| 571 |
gr.Markdown(
|
| 572 |
"""
|
| 573 |
+
Compare **LWM embeddings** vs **Raw Spectrograms** for visualization, then evaluate **precomputed MoE embeddings**
|
| 574 |
+
with a lightweight k-NN prototype classifier for joint SNR/Doppler recognition.
|
| 575 |
"""
|
| 576 |
)
|
| 577 |
|
| 578 |
with gr.Tabs():
|
| 579 |
+
with gr.Tab("Spectrograms"):
|
| 580 |
+
gr.Markdown("Visualize raw 128×128 spectrograms with optional filters.")
|
| 581 |
+
with gr.Row():
|
| 582 |
+
with gr.Column(scale=1, min_width=320):
|
| 583 |
+
spec_tech = gr.Dropdown(
|
| 584 |
+
choices=tech_choices,
|
| 585 |
+
value=default_tech[0] if default_tech else None,
|
| 586 |
+
label="Technology",
|
| 587 |
+
)
|
| 588 |
+
spec_snr = gr.Dropdown(choices=snr_choices, value=None, label="SNR (optional)")
|
| 589 |
+
spec_mod = gr.Dropdown(choices=initial_spec_mod_choices, value=None, label="Modulation (optional)")
|
| 590 |
+
spec_mob = gr.Dropdown(choices=mob_choices, value=None, label="Mobility (optional)")
|
| 591 |
+
spec_count = gr.Slider(minimum=1, maximum=12, step=1, value=6, label="Samples to show")
|
| 592 |
+
spec_seed = gr.Slider(minimum=0, maximum=9999, step=1, value=0, label="Random seed")
|
| 593 |
+
spec_btn = gr.Button("Show spectrograms", variant="primary")
|
| 594 |
+
with gr.Column(scale=3):
|
| 595 |
+
gallery = gr.Gallery(
|
| 596 |
+
label="Spectrogram Samples",
|
| 597 |
+
columns=[3],
|
| 598 |
+
rows=[3],
|
| 599 |
+
height=560,
|
| 600 |
+
preview=True,
|
| 601 |
+
)
|
| 602 |
+
gallery_status = gr.Textbox(label="Status", interactive=False)
|
| 603 |
+
spec_inputs = [spec_tech, spec_snr, spec_mod, spec_mob, spec_count, spec_seed]
|
| 604 |
+
spec_btn.click(render_spectrogram_gallery, inputs=spec_inputs, outputs=[gallery, gallery_status])
|
| 605 |
+
demo.load(render_spectrogram_gallery, inputs=spec_inputs, outputs=[gallery, gallery_status])
|
| 606 |
+
spec_tech.change(update_modulation_choices, inputs=spec_tech, outputs=spec_mod)
|
| 607 |
+
|
| 608 |
+
with gr.Tab("t-SNE Analysis"):
|
| 609 |
with gr.Row():
|
| 610 |
with gr.Column(scale=1, min_width=300):
|
| 611 |
gr.Markdown("### Filters")
|
| 612 |
+
tech_filter = gr.CheckboxGroup(choices=tech_choices, value=default_tech, label="Technology")
|
| 613 |
snr_filter = gr.Dropdown(
|
| 614 |
choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)"
|
| 615 |
)
|
|
|
|
| 626 |
value="LWM Embedding",
|
| 627 |
label="Representation",
|
| 628 |
)
|
| 629 |
+
color_by = gr.Dropdown(
|
| 630 |
+
choices=list(COLOR_OPTIONS.keys()),
|
| 631 |
+
value="SNR",
|
| 632 |
+
label="Color By",
|
| 633 |
+
)
|
| 634 |
|
| 635 |
with gr.Accordion("Advanced t-SNE Settings", open=False):
|
| 636 |
perplexity = gr.Slider(minimum=5, maximum=50, value=30, step=1, label="Perplexity")
|
| 637 |
n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations")
|
| 638 |
|
| 639 |
btn = gr.Button("Update Plot", variant="primary")
|
|
|
|
| 640 |
|
| 641 |
with gr.Column(scale=3):
|
| 642 |
plot = gr.Plot(label="t-SNE Visualization")
|
|
|
|
| 644 |
btn.click(
|
| 645 |
plot_tsne,
|
| 646 |
inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
|
| 647 |
+
outputs=[plot],
|
| 648 |
)
|
| 649 |
|
| 650 |
demo.load(
|
| 651 |
plot_tsne,
|
| 652 |
inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter],
|
| 653 |
+
outputs=[plot],
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
with gr.Tab("Modulation Classification"):
|
| 657 |
+
gr.Markdown("Compare LWM embeddings vs raw spectrograms for per-technology modulation classification.")
|
| 658 |
+
with gr.Row():
|
| 659 |
+
with gr.Column(scale=1, min_width=320):
|
| 660 |
+
mod_tech = gr.Dropdown(
|
| 661 |
+
choices=tech_choices,
|
| 662 |
+
value=default_tech[0] if default_tech else None,
|
| 663 |
+
label="Technology",
|
| 664 |
+
)
|
| 665 |
+
mod_train = gr.Slider(minimum=50, maximum=90, step=5, value=70, label="Training Percentage (%)")
|
| 666 |
+
mod_seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
|
| 667 |
+
gr.Markdown("k-NN uses an adaptive k based on the number of modulation classes and available training samples.")
|
| 668 |
+
mod_btn = gr.Button("Run modulation evaluation", variant="primary")
|
| 669 |
+
with gr.Column(scale=3):
|
| 670 |
+
with gr.Row():
|
| 671 |
+
emb_plot = gr.Plot(label="Embedding Confusion Matrix")
|
| 672 |
+
raw_plot = gr.Plot(label="Raw Confusion Matrix")
|
| 673 |
+
mod_summary = gr.Markdown(value="Select a technology and run the evaluation to view metrics.")
|
| 674 |
+
mod_btn.click(
|
| 675 |
+
evaluate_modulation,
|
| 676 |
+
inputs=[mod_tech, mod_train, mod_seed],
|
| 677 |
+
outputs=[emb_plot, raw_plot, mod_summary],
|
| 678 |
)
|
| 679 |
|
| 680 |
+
with gr.Tab("Joint SNR/Doppler Evaluation"):
|
| 681 |
if evaluation_disabled:
|
| 682 |
gr.Markdown(
|
| 683 |
+
"⚠️ Precomputed MoE embeddings are not bundled in this Space build. Upload a dataset locally to run evaluations."
|
| 684 |
)
|
| 685 |
|
| 686 |
with gr.Row():
|
|
|
|
| 688 |
gr.Markdown("### Evaluation Filters")
|
| 689 |
eval_tech_filter = gr.CheckboxGroup(
|
| 690 |
choices=tech_choices,
|
| 691 |
+
value=default_tech,
|
| 692 |
label="Technology",
|
| 693 |
interactive=not evaluation_disabled,
|
| 694 |
)
|
|
|
|
| 734 |
eval_btn = gr.Button("Run evaluation", variant="primary", interactive=not evaluation_disabled)
|
| 735 |
|
| 736 |
with gr.Column(scale=3):
|
| 737 |
+
with gr.Row():
|
| 738 |
+
eval_plot = gr.Plot(label="MoE Prototype Confusion")
|
| 739 |
+
eval_plot_raw = gr.Plot(label="Raw Prototype Confusion")
|
| 740 |
+
eval_status = gr.Markdown(value="Run an evaluation to compare MoE vs raw baselines.")
|
| 741 |
|
| 742 |
eval_btn.click(
|
| 743 |
run_joint_evaluation,
|
| 744 |
inputs=[train_pct, seed, eval_tech_filter, eval_snr_filter, eval_mod_filter, eval_mob_filter],
|
| 745 |
+
outputs=[eval_plot, eval_plot_raw, eval_status],
|
| 746 |
)
|
| 747 |
|
| 748 |
if __name__ == "__main__":
|
pretraining/README.md
DELETED
|
@@ -1,44 +0,0 @@
|
|
| 1 |
-
# 🔬 Pretraining Scripts
|
| 2 |
-
|
| 3 |
-
This folder contains scripts for **Large Wireless Model (LWM)** pre-training.
|
| 4 |
-
|
| 5 |
-
## 📁 File Descriptions
|
| 6 |
-
|
| 7 |
-
### `train_lwm_spectro.py`
|
| 8 |
-
- **Purpose**: Pre-train LWM model with spectrogram data
|
| 9 |
-
- **Features**:
|
| 10 |
-
- Self-supervised learning through masked patch prediction
|
| 11 |
-
- Multi-size spectrogram support (32x32, 128x128)
|
| 12 |
-
- MSE loss-based reconstruction
|
| 13 |
-
- Real-time training monitoring and result storage
|
| 14 |
-
|
| 15 |
-
### `pretrained_model.py`
|
| 16 |
-
- **Purpose**: Define structure of pre-trained LWM model
|
| 17 |
-
- **Features**: LWM model architecture implementation
|
| 18 |
-
|
| 19 |
-
## 🚀 Usage
|
| 20 |
-
|
| 21 |
-
### Basic Training Execution
|
| 22 |
-
```bash
|
| 23 |
-
cd pretraining
|
| 24 |
-
python train_lwm_spectro.py
|
| 25 |
-
```
|
| 26 |
-
|
| 27 |
-
### GPU Memory Optimization
|
| 28 |
-
```bash
|
| 29 |
-
cd pretraining
|
| 30 |
-
python train_lwm_spectro.py # GPU 메모리에 맞춰 batch_size 조정
|
| 31 |
-
```
|
| 32 |
-
|
| 33 |
-
### Check Results
|
| 34 |
-
Training results are automatically saved in `models/` folder:
|
| 35 |
-
- `*_checkpoint.pth`: Model checkpoint
|
| 36 |
-
- `*_training_history.json`: Training history
|
| 37 |
-
- `*_training_curves.png`: Training curve graphs
|
| 38 |
-
|
| 39 |
-
## 📊 Research Perspective
|
| 40 |
-
|
| 41 |
-
These scripts are used to study **LWM's representation learning capabilities**:
|
| 42 |
-
- Extract meaningful features from spectrograms
|
| 43 |
-
- Generalized representation learning through unsupervised learning
|
| 44 |
-
- Validate transfer learning effectiveness in downstream tasks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pretraining/__init__.py
DELETED
|
File without changes
|
pretraining/pretrained_model.py
DELETED
|
@@ -1,180 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
import numpy as np
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
class LayerNormalization(nn.Module):
    """Layer norm over the last dimension with learnable scale and bias."""

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along its last axis, then apply scale/shift."""
        mu = x.mean(dim=-1, keepdim=True)
        sigma = x.std(dim=-1, keepdim=True)
        normalized = (x - mu) / (sigma + self.eps)
        return self.alpha * normalized + self.bias
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class Embedding(nn.Module):
    """Linear token projection plus learned positional embedding and layer norm."""

    def __init__(self, element_length: int, d_model: int, max_len: int | None = None) -> None:
        super().__init__()
        self.element_length = element_length
        self.d_model = d_model
        # Default positional table size matches the historical checkpoint layout.
        self.max_len = 1025 if max_len is None else max_len

        self.proj = nn.Linear(element_length, d_model)
        self.pos_embed = nn.Embedding(self.max_len, d_model)
        self.norm = LayerNormalization(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (batch, seq, element_length), add positions, normalize.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"Sequence length {seq_len} exceeds max_len {self.max_len}.")

        positions = torch.arange(seq_len, dtype=torch.long, device=x.device)
        tok_emb = self.proj(x.float())
        return self.norm(tok_emb + self.pos_embed(positions))
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V."""

    def __init__(self, d_k: int) -> None:
        super().__init__()
        self.d_k = d_k

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (context, attention weights)."""
        logits = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        weights = F.softmax(logits, dim=-1)
        return torch.matmul(weights, V), weights
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a residual connection and output dropout."""

    def __init__(self, d_model: int, n_heads: int, dropout: float) -> None:
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads}).")

        self.d_k = d_model // n_heads
        self.d_v = d_model // n_heads
        self.n_heads = n_heads

        self.W_Q = nn.Linear(d_model, self.d_k * n_heads)
        self.W_K = nn.Linear(d_model, self.d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.linear = nn.Linear(n_heads * self.d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_k)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (residual + attended output, attention weights)."""
        residual = Q
        bsz = Q.size(0)

        def split_heads(proj: torch.Tensor, head_dim: int) -> torch.Tensor:
            # (batch, seq, heads*dim) -> (batch, heads, seq, dim)
            return proj.view(bsz, -1, self.n_heads, head_dim).transpose(1, 2)

        q_s = split_heads(self.W_Q(Q), self.d_k)
        k_s = split_heads(self.W_K(K), self.d_k)
        v_s = split_heads(self.W_V(V), self.d_v)

        context, attn = self.scaled_dot_attn(q_s, k_s, v_s)
        merged = context.transpose(1, 2).contiguous().view(bsz, -1, self.n_heads * self.d_v)
        return residual + self.dropout(self.linear(merged)), attn
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
class PoswiseFeedForwardNet(nn.Module):
|
| 94 |
-
"""Position-wise feed-forward network."""
|
| 95 |
-
|
| 96 |
-
def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
|
| 97 |
-
super().__init__()
|
| 98 |
-
self.fc1 = nn.Linear(d_model, d_ff)
|
| 99 |
-
self.fc2 = nn.Linear(d_ff, d_model)
|
| 100 |
-
self.dropout = nn.Dropout(dropout)
|
| 101 |
-
|
| 102 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 103 |
-
return self.fc2(self.dropout(F.relu(self.fc1(x))))
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
class EncoderLayer(nn.Module):
    """Transformer encoder block: self-attention then FFN, each followed by LayerNorm."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff, dropout)
        self.norm1 = LayerNormalization(d_model)
        self.norm2 = LayerNormalization(d_model)

    def forward(self, enc_inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (block output, self-attention weights)."""
        attended, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
        attended = self.norm1(attended)
        # Residual around the feed-forward sublayer.
        return self.norm2(attended + self.pos_ffn(attended)), attn
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
class LWM(nn.Module):
    """Large Wireless Model: a Transformer encoder with a masked-patch decoder head."""

    def __init__(
        self,
        element_length: int = 32,
        d_model: int = 128,
        n_layers: int = 12,
        max_len: int | None = None,
        n_heads: int = 8,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.element_length = element_length
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_len = 1025 if max_len is None else max_len
        self.n_heads = n_heads
        self.dropout = dropout

        self.embedding = Embedding(element_length, d_model, self.max_len)
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, n_heads, d_model * 4, dropout) for _ in range(n_layers)
        )
        self.linear = nn.Linear(d_model, d_model)
        self.norm = LayerNormalization(d_model)

        # Decoder projects hidden states back to the embedding's input width.
        _, n_dim = self.embedding.proj.weight.size()
        self.decoder = nn.Linear(d_model, n_dim, bias=False)
        self.decoder_bias = nn.Parameter(torch.zeros(n_dim))

    def forward(
        self,
        input_ids: torch.Tensor,
        masked_pos: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        """Encode ``input_ids``.

        When ``masked_pos`` is given, gather the hidden states at those
        positions and also return reconstruction logits for masked-patch
        prediction: (logits, encoder output).  Otherwise return only the
        encoder output.
        """
        hidden = self.embedding(input_ids)
        for layer in self.layers:
            hidden, attn = layer(hidden)

        if masked_pos is None:
            return hidden

        gather_idx = masked_pos.long()[:, :, None].expand(-1, -1, hidden.size(-1))
        h_masked = torch.gather(hidden, 1, gather_idx)
        h_masked = self.norm(F.relu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias
        return logits_lm, hidden
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
def lwm(*args, **kwargs) -> LWM:
    """Factory kept for backward compatibility with older imports."""
    model = LWM(*args, **kwargs)
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pretraining/train_lwm_spectro.py
DELETED
|
@@ -1,741 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
# =============================================================================
|
| 3 |
-
# train_lwm_spectro.py - LWM Pretraining with Complex-Valued Spectrogram Support
|
| 4 |
-
# Modified from train_lwm_spectro_no_contrast.py to handle complex spectrograms
|
| 5 |
-
# by separating real and imaginary parts and flattening them (similar to train_lwm.py)
|
| 6 |
-
# =============================================================================
|
| 7 |
-
|
| 8 |
-
# =============================================================================
|
| 9 |
-
# 1. IMPORTS AND WARNINGS SETUP
|
| 10 |
-
# - Load necessary PyTorch modules, utilities, and suppress UserWarnings
|
| 11 |
-
# =============================================================================
|
| 12 |
-
import sys
|
| 13 |
-
import os
|
| 14 |
-
import argparse
|
| 15 |
-
# Add project root to path (Windows compatible)
|
| 16 |
-
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 17 |
-
sys.path.insert(0, project_root)
|
| 18 |
-
import torch
|
| 19 |
-
import torch.nn as nn
|
| 20 |
-
import torch.nn.functional as F
|
| 21 |
-
from torch.utils.data import DataLoader, random_split, TensorDataset
|
| 22 |
-
import torch.optim as optim
|
| 23 |
-
from utils import (generate_spectrograms_and_labels, tokenizer_train,
|
| 24 |
-
create_train_dataloader, count_parameters, train_lwm)
|
| 25 |
-
import numpy as np
|
| 26 |
-
import pretrained_model # Assuming this contains the LWM model definition
|
| 27 |
-
from torch.optim.lr_scheduler import LambdaLR
|
| 28 |
-
from torch.optim import AdamW
|
| 29 |
-
import warnings
|
| 30 |
-
import platform
|
| 31 |
-
import re
|
| 32 |
-
from tqdm import tqdm
|
| 33 |
-
from datetime import datetime
|
| 34 |
-
import concurrent.futures
|
| 35 |
-
import multiprocessing
|
| 36 |
-
from collections import Counter
|
| 37 |
-
from functools import lru_cache
|
| 38 |
-
import json
|
| 39 |
-
|
| 40 |
-
SNR_PATTERN = re.compile(r"SNR(-?\d+)dB")
|
| 41 |
-
DOPPLER_MAP = {"static": 0, "pedestrian": 1, "vehicular": 2}
|
| 42 |
-
DOPPLER_INV = {v: k for k, v in DOPPLER_MAP.items()}
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _parse_snr_and_doppler(path: str) -> tuple[float, int]:
    """Extract (SNR in dB, doppler class id) from a dataset path.

    The last "SNR<value>dB" token in the path wins; the doppler id comes
    from the first path component found in ``DOPPLER_MAP``.  Defaults are
    0.0 dB and id 0 when nothing matches.
    """
    snr_db = 0.0
    matches = SNR_PATTERN.findall(path)
    if matches:
        try:
            snr_db = float(matches[-1])
        except ValueError:
            snr_db = 0.0

    doppler_id = 0
    for part in os.path.normpath(path).split(os.sep):
        if part in DOPPLER_MAP:
            doppler_id = DOPPLER_MAP[part]
            break

    return snr_db, doppler_id
|
| 64 |
-
|
| 65 |
-
warnings.filterwarnings("ignore", category=UserWarning)
|
| 66 |
-
|
| 67 |
-
# Use simple progress display instead of tqdm on Windows
|
| 68 |
-
USE_TQDM = platform.system() != 'Windows'
|
| 69 |
-
|
| 70 |
-
# CPU 코어 수 계산 (메모리 사용량 고려하여 보수적으로 설정)
|
| 71 |
-
total_cores = multiprocessing.cpu_count()
|
| 72 |
-
if total_cores >= 16:
|
| 73 |
-
MAX_WORKERS = min(8, total_cores // 2) # 고성능 서버의 경우 8코어로 제한
|
| 74 |
-
else:
|
| 75 |
-
MAX_WORKERS = max(2, total_cores // 2) # 일반 시스템의 경우 절반 사용
|
| 76 |
-
print(f"🚀 Using {MAX_WORKERS}/{total_cores} CPU cores for parallel processing")
|
| 77 |
-
|
| 78 |
-
PRINT_CONVERSION_STATS = os.environ.get("LWM_PRINT_CONVERSION_STATS", "").strip().lower() in {"1", "true", "yes"}
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def convert_complex_to_interleaved(spectrograms):
    """
    Convert complex-valued spectrograms to real-imaginary interleaved format.

    Real and imaginary parts are interleaved along the last axis (real at
    even columns, imaginary at odd columns), doubling its length.  A leading
    singleton channel axis is dropped first; real-valued 3-D input is
    returned unchanged.

    Args:
        spectrograms (np.ndarray): shape (n_samples, n_rows, n_cols) or
            (n_samples, 1, n_rows, n_cols); complex or real dtype.

    Returns:
        np.ndarray: float32 array of shape (n_samples, n_rows, n_cols * 2)
        for complex input, or the real-valued input as-is.

    Raises:
        ValueError: if real-valued input is not 3-D after channel removal.
    """
    if spectrograms.ndim == 4:
        # Drop the singleton channel axis: (N, 1, R, C) -> (N, R, C).
        spectrograms = spectrograms[:, 0, :, :]

    if np.iscomplexobj(spectrograms):
        n_samples, n_rows, n_cols = spectrograms.shape
        flat_real = spectrograms.real
        flat_imag = spectrograms.imag

        interleaved = np.empty((n_samples, n_rows, n_cols * 2), dtype=np.float32)
        interleaved[:, :, 0::2] = flat_real  # even columns: real parts
        interleaved[:, :, 1::2] = flat_imag  # odd columns: imaginary parts

        if PRINT_CONVERSION_STATS:
            print(f" ℹ️ Converted complex spectrograms: {spectrograms.shape} -> {interleaved.shape}")
            print(f" Real part range: [{flat_real.min():.2e}, {flat_real.max():.2e}]")
            print(f" Imag part range: [{flat_imag.min():.2e}, {flat_imag.max():.2e}]")

        return interleaved

    if spectrograms.ndim == 3:
        if PRINT_CONVERSION_STATS:
            print(f" ℹ️ Data is already real-valued: {spectrograms.shape}")
        return spectrograms
    raise ValueError(f"Unexpected spectrogram shape: {spectrograms.shape}")
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def process_single_scenario(scenario_info):
|
| 133 |
-
"""단일 시나리오를 처리하는 함수 (멀티프로세싱용)"""
|
| 134 |
-
scenario_name, spectrogram_path = scenario_info
|
| 135 |
-
|
| 136 |
-
try:
|
| 137 |
-
# 메모리 효율성을 위해 필요한 데이터만 로드
|
| 138 |
-
scenario_spectrograms, scenario_labels = generate_spectrograms_and_labels(
|
| 139 |
-
scenario_name=scenario_name,
|
| 140 |
-
spectrogram_path=spectrogram_path,
|
| 141 |
-
cache_path=None, # 메모리 문제로 캐시 비활성화
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
# Validate load
|
| 145 |
-
if scenario_spectrograms is None or (hasattr(scenario_spectrograms, 'size') and scenario_spectrograms.size == 0):
|
| 146 |
-
print(f" ⚠️ No data loaded from: {spectrogram_path}")
|
| 147 |
-
return None
|
| 148 |
-
|
| 149 |
-
# Convert complex spectrograms to interleaved real-imaginary format
|
| 150 |
-
scenario_spectrograms = convert_complex_to_interleaved(scenario_spectrograms)
|
| 151 |
-
|
| 152 |
-
snr_db, doppler_id = _parse_snr_and_doppler(spectrogram_path)
|
| 153 |
-
|
| 154 |
-
# 데이터 분할 (인덱스만 계산)
|
| 155 |
-
total_samples = len(scenario_spectrograms)
|
| 156 |
-
train_size = int(0.8 * total_samples)
|
| 157 |
-
val_size = total_samples - train_size
|
| 158 |
-
|
| 159 |
-
# 메모리 절약을 위해 numpy array로 유지 (필요할 때만 tensor로 변환)
|
| 160 |
-
train_data = np.array(scenario_spectrograms[:train_size], dtype=np.float32)
|
| 161 |
-
val_data = np.array(scenario_spectrograms[train_size:], dtype=np.float32)
|
| 162 |
-
|
| 163 |
-
snr_array = np.full(total_samples, snr_db, dtype=np.float32)
|
| 164 |
-
doppler_array = np.full(total_samples, doppler_id, dtype=np.int64)
|
| 165 |
-
train_meta = {
|
| 166 |
-
'snr_db': snr_array[:train_size],
|
| 167 |
-
'doppler_id': doppler_array[:train_size],
|
| 168 |
-
}
|
| 169 |
-
val_meta = {
|
| 170 |
-
'snr_db': snr_array[train_size:],
|
| 171 |
-
'doppler_id': doppler_array[train_size:],
|
| 172 |
-
}
|
| 173 |
-
|
| 174 |
-
# 불필요한 데이터 즉시 삭제
|
| 175 |
-
del scenario_spectrograms
|
| 176 |
-
|
| 177 |
-
return {
|
| 178 |
-
'scenario': scenario_name,
|
| 179 |
-
'train_data': train_data,
|
| 180 |
-
'val_data': val_data,
|
| 181 |
-
'train_meta': train_meta,
|
| 182 |
-
'val_meta': val_meta,
|
| 183 |
-
'train_size': len(train_data),
|
| 184 |
-
'val_size': len(val_data)
|
| 185 |
-
}
|
| 186 |
-
except Exception as e:
|
| 187 |
-
print(f"❌ Error processing scenario {scenario_name}: {e}")
|
| 188 |
-
import traceback
|
| 189 |
-
traceback.print_exc()
|
| 190 |
-
return None
|
| 191 |
-
|
| 192 |
-
# GPU Memory Monitor import (for Lambda) - Removed
|
| 193 |
-
|
| 194 |
-
# =============================================================================
|
| 195 |
-
# 2. SCENARIO LIST DEFINITION
|
| 196 |
-
# - Define the list of scenario names to iterate over for data generation
|
| 197 |
-
# =============================================================================
|
| 198 |
-
|
| 199 |
-
# Supported communications; can be limited via CLI
|
| 200 |
-
SUPPORTED_COMM_TYPES = {"LTE", "WiFi", "5G"}
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
def _parse_standard_args():
|
| 204 |
-
parser = argparse.ArgumentParser(add_help=False)
|
| 205 |
-
parser.add_argument('--standards', nargs='+', choices=SUPPORTED_COMM_TYPES,
|
| 206 |
-
help='Specify one or more communication types to include (default: all).')
|
| 207 |
-
for comm in SUPPORTED_COMM_TYPES:
|
| 208 |
-
parser.add_argument(f'--{comm}', dest=f'flag_{comm}', action='store_true',
|
| 209 |
-
help=f'Include only {comm} data (can be combined).')
|
| 210 |
-
parser.add_argument('--city', '--cities', dest='cities', nargs='+',
|
| 211 |
-
help='Limit scenarios to one or more city prefixes (e.g., "0" or "city_0").')
|
| 212 |
-
parser.add_argument(
|
| 213 |
-
'--normalization',
|
| 214 |
-
choices=('per_sample', 'dataset'),
|
| 215 |
-
default='per_sample',
|
| 216 |
-
help='Normalization mode applied during tokenization (default: %(default)s).'
|
| 217 |
-
)
|
| 218 |
-
parser.add_argument('--help', action='help')
|
| 219 |
-
|
| 220 |
-
args, remaining = parser.parse_known_args()
|
| 221 |
-
|
| 222 |
-
enabled = set(SUPPORTED_COMM_TYPES)
|
| 223 |
-
if args.standards:
|
| 224 |
-
enabled = set(args.standards)
|
| 225 |
-
else:
|
| 226 |
-
flagged = {comm for comm in SUPPORTED_COMM_TYPES if getattr(args, f'flag_{comm}', False)}
|
| 227 |
-
if flagged:
|
| 228 |
-
enabled = flagged
|
| 229 |
-
|
| 230 |
-
selected_cities: list[str] | None = None
|
| 231 |
-
if args.cities:
|
| 232 |
-
selected_cities = []
|
| 233 |
-
for city_token in args.cities:
|
| 234 |
-
token = str(city_token).strip()
|
| 235 |
-
if not token:
|
| 236 |
-
continue
|
| 237 |
-
if token.startswith('city_'):
|
| 238 |
-
selected_cities.append(token)
|
| 239 |
-
else:
|
| 240 |
-
selected_cities.append(f'city_{token}')
|
| 241 |
-
if not selected_cities:
|
| 242 |
-
selected_cities = None
|
| 243 |
-
|
| 244 |
-
# Return remaining args to allow downstream parsing if needed
|
| 245 |
-
sys.argv = [sys.argv[0]] + remaining
|
| 246 |
-
return enabled, selected_cities, args.normalization
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
ENABLED_COMM_TYPES, ENABLED_CITY_PREFIXES, NORMALIZATION_MODE = _parse_standard_args()
|
| 250 |
-
MAX_SCENARIOS = int(os.environ.get("LWM_MAX_SCENARIOS", "0")) or None
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
def _extract_scenario_token(file_path):
|
| 254 |
-
"""Derive the base scenario token (without city) from the file path."""
|
| 255 |
-
normalized_path = os.path.normpath(file_path)
|
| 256 |
-
parts = normalized_path.split(os.sep)
|
| 257 |
-
|
| 258 |
-
scenario_parts = []
|
| 259 |
-
for i, part in enumerate(parts):
|
| 260 |
-
if part in SUPPORTED_COMM_TYPES:
|
| 261 |
-
trailing = parts[i:i + 5]
|
| 262 |
-
if trailing:
|
| 263 |
-
scenario_parts = trailing[:5]
|
| 264 |
-
break
|
| 265 |
-
|
| 266 |
-
if not scenario_parts:
|
| 267 |
-
# Fallback for datasets where the communication type is only captured in the filename
|
| 268 |
-
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
| 269 |
-
if base_name.startswith('spectrogram_'):
|
| 270 |
-
tokens = base_name.split('_')[1:] # drop 'spectrogram'
|
| 271 |
-
if tokens and tokens[0] in SUPPORTED_COMM_TYPES:
|
| 272 |
-
scenario_parts = tokens[:5] if len(tokens) >= 5 else tokens
|
| 273 |
-
|
| 274 |
-
return '_'.join(scenario_parts) if scenario_parts else None
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
@lru_cache(maxsize=1)
|
| 278 |
-
def _collect_scenario_file_info():
|
| 279 |
-
import glob
|
| 280 |
-
|
| 281 |
-
scenario_entries = []
|
| 282 |
-
|
| 283 |
-
# New MATLAB receiver pipeline output
|
| 284 |
-
new_base = os.path.join('ls_data', 'MATLAB', 'receiver_pipeline')
|
| 285 |
-
if os.path.isdir(new_base):
|
| 286 |
-
patterns = [os.path.join(new_base, '*', '**', 'spectrogram_*.mat')]
|
| 287 |
-
for pattern in patterns:
|
| 288 |
-
for file_path in sorted(glob.glob(pattern, recursive=True)):
|
| 289 |
-
norm = os.path.normpath(file_path)
|
| 290 |
-
parts = norm.split(os.sep)
|
| 291 |
-
# Determine a grouping token similar to city_name; use the standard folder name
|
| 292 |
-
try:
|
| 293 |
-
idx = parts.index('receiver_pipeline')
|
| 294 |
-
city_name = parts[idx + 1] if idx + 1 < len(parts) else 'receiver_pipeline'
|
| 295 |
-
except ValueError:
|
| 296 |
-
city_name = 'receiver_pipeline'
|
| 297 |
-
|
| 298 |
-
base_token = _extract_scenario_token(file_path)
|
| 299 |
-
if not base_token:
|
| 300 |
-
continue
|
| 301 |
-
comm_type = base_token.split('_', 1)[0]
|
| 302 |
-
if comm_type not in ENABLED_COMM_TYPES:
|
| 303 |
-
continue
|
| 304 |
-
scenario_id = f"{city_name}::{base_token}"
|
| 305 |
-
scenario_entries.append((scenario_id, file_path, city_name, base_token))
|
| 306 |
-
|
| 307 |
-
# Legacy repo layouts under spectrograms/city_*
|
| 308 |
-
import glob as _glob
|
| 309 |
-
for city_dir in sorted(_glob.glob(os.path.join('spectrograms', 'city_*'))):
|
| 310 |
-
if not os.path.isdir(city_dir):
|
| 311 |
-
continue
|
| 312 |
-
city_name = os.path.basename(city_dir)
|
| 313 |
-
if ENABLED_CITY_PREFIXES:
|
| 314 |
-
if not any(city_name.startswith(prefix) for prefix in ENABLED_CITY_PREFIXES):
|
| 315 |
-
continue
|
| 316 |
-
# Look for complex spectrogram outputs; support both nested and flat layouts
|
| 317 |
-
candidate_patterns = [
|
| 318 |
-
os.path.join(city_dir, '**', 'complex_raw', '**', 'spectrogram_*.mat'),
|
| 319 |
-
os.path.join(city_dir, '**', 'spectrogram_*.mat'),
|
| 320 |
-
]
|
| 321 |
-
city_files = []
|
| 322 |
-
seen_paths = set()
|
| 323 |
-
for pattern in candidate_patterns:
|
| 324 |
-
for file_path in sorted(_glob.glob(pattern, recursive=True)):
|
| 325 |
-
if not file_path.lower().endswith('.mat'):
|
| 326 |
-
continue
|
| 327 |
-
if file_path in seen_paths:
|
| 328 |
-
continue
|
| 329 |
-
seen_paths.add(file_path)
|
| 330 |
-
city_files.append(file_path)
|
| 331 |
-
|
| 332 |
-
# Fallback: 512FFT pattern (기존 호환성)
|
| 333 |
-
if not city_files:
|
| 334 |
-
pattern = os.path.join(city_dir, '**', '512FFT', '**', 'spectrograms', '*.pkl')
|
| 335 |
-
city_files = sorted(_glob.glob(pattern, recursive=True))
|
| 336 |
-
|
| 337 |
-
for file_path in city_files:
|
| 338 |
-
base_token = _extract_scenario_token(file_path)
|
| 339 |
-
if not base_token:
|
| 340 |
-
continue
|
| 341 |
-
comm_type = base_token.split('_', 1)[0]
|
| 342 |
-
if comm_type not in ENABLED_COMM_TYPES:
|
| 343 |
-
continue
|
| 344 |
-
scenario_id = f"{city_name}::{base_token}"
|
| 345 |
-
scenario_entries.append((scenario_id, file_path, city_name, base_token))
|
| 346 |
-
|
| 347 |
-
if MAX_SCENARIOS:
|
| 348 |
-
scenario_entries = scenario_entries[:MAX_SCENARIOS]
|
| 349 |
-
|
| 350 |
-
return scenario_entries
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
def scenarios_list():
|
| 354 |
-
scenario_entries = _collect_scenario_file_info()
|
| 355 |
-
|
| 356 |
-
if not scenario_entries:
|
| 357 |
-
print("⚠️ No spectrogram files found for pretraining.")
|
| 358 |
-
return np.array([])
|
| 359 |
-
|
| 360 |
-
print(f"Enabled communication types: {sorted(ENABLED_COMM_TYPES)}")
|
| 361 |
-
if ENABLED_CITY_PREFIXES:
|
| 362 |
-
print(f"Selected city prefixes: {sorted(ENABLED_CITY_PREFIXES)}")
|
| 363 |
-
city_counts = Counter(entry[2] for entry in scenario_entries)
|
| 364 |
-
print("Using scenarios from the following city datasets:")
|
| 365 |
-
for city_name, count in city_counts.items():
|
| 366 |
-
print(f" - {city_name}: {count} files")
|
| 367 |
-
|
| 368 |
-
print(f"Total scenarios selected: {len(scenario_entries)}")
|
| 369 |
-
return np.array([entry[0] for entry in scenario_entries])
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
# =============================================================================
|
| 373 |
-
# 3. SCENARIO PROPERTIES MAPPING
|
| 374 |
-
# - Map each scenario name to its corresponding properties
|
| 375 |
-
# =============================================================================
|
| 376 |
-
|
| 377 |
-
def scenario_prop():
|
| 378 |
-
scenario_entries = _collect_scenario_file_info()
|
| 379 |
-
|
| 380 |
-
row_column_users = {}
|
| 381 |
-
for scenario_id, file_path, city_name, _ in scenario_entries:
|
| 382 |
-
row_column_users[scenario_id] = {
|
| 383 |
-
'spectrogram_path': file_path,
|
| 384 |
-
'cache_path': os.path.join('spectrograms', city_name, 'spectrogram_cache_128x128.pkl')
|
| 385 |
-
}
|
| 386 |
-
|
| 387 |
-
return row_column_users
|
| 388 |
-
|
| 389 |
-
# =============================================================================
|
| 390 |
-
# 4. TRAINING PARAMETERS AND HYPERPARAMETERS
|
| 391 |
-
# - Set training epochs, batch sizes, learning rates, model dimensions, etc.
|
| 392 |
-
# =============================================================================
|
| 393 |
-
|
| 394 |
-
EPOCHS = 20 # Increased for better convergence
|
| 395 |
-
# Optimized batch size for A100 GPU (40GB)
|
| 396 |
-
BATCH_SIZE = 16
|
| 397 |
-
VAL_BATCH_SIZE = 16
|
| 398 |
-
WARMUP_EPOCHS = 5
|
| 399 |
-
BASE_LR = 5e-4
|
| 400 |
-
MIN_LR = 1e-8
|
| 401 |
-
# Updated for 128x128 complex spectrograms with real-imaginary interleaving
|
| 402 |
-
N_ROWS = 4
|
| 403 |
-
N_COLUMNS = 4
|
| 404 |
-
ELEMENT_LENGTH = N_ROWS * N_COLUMNS * 2 # Complex spectrograms: 2x for real+imaginary interleaving
|
| 405 |
-
D_MODEL = 128
|
| 406 |
-
MAX_LEN = 1025 # (128/4) * (128/4) + 1 = 32 * 32 + 1 = 1024 + 1 for [CLS] token
|
| 407 |
-
# Interleaving keeps the same number of spatial patches (32x32) while doubling patch width
|
| 408 |
-
# so each token covers 4x4 complex bins (real+imag) and sequence length stays at 1025.
|
| 409 |
-
N_LAYERS = 12
|
| 410 |
-
device_idx = 0
|
| 411 |
-
WEIGHT_DECAY = 0.05
|
| 412 |
-
BETA1 = 0.9
|
| 413 |
-
BETA2 = 0.999
|
| 414 |
-
MASK_PERCENT = 0.6
|
| 415 |
-
N_HEADS = 8
|
| 416 |
-
DROPOUT = 0.1
|
| 417 |
-
|
| 418 |
-
print(f"📊 Model configuration for complex spectrograms:")
|
| 419 |
-
print(f" Patch size: {N_ROWS}x{N_COLUMNS}")
|
| 420 |
-
print(f" Element length: {ELEMENT_LENGTH} (includes real+imag interleaving)")
|
| 421 |
-
print(f" Max sequence length: {MAX_LEN}")
|
| 422 |
-
|
| 423 |
-
# =============================================================================
|
| 424 |
-
# 5. DATA GENERATION LOOP
|
| 425 |
-
# - Iterate over scenarios to generate spectrogram samples and labels
|
| 426 |
-
# =============================================================================
|
| 427 |
-
|
| 428 |
-
scenarios = scenarios_list()
|
| 429 |
-
scenario_properties = scenario_prop()
|
| 430 |
-
|
| 431 |
-
# Collect all training and validation data separately
|
| 432 |
-
train_spectrogram_chunks = []
|
| 433 |
-
val_spectrogram_chunks = []
|
| 434 |
-
train_label_chunks = []
|
| 435 |
-
val_label_chunks = []
|
| 436 |
-
train_meta_chunks = []
|
| 437 |
-
val_meta_chunks = []
|
| 438 |
-
|
| 439 |
-
print(f"📂 Loading {len(scenarios)} scenarios...")
|
| 440 |
-
|
| 441 |
-
# TEMP: Modified to not use cache
|
| 442 |
-
print("⚠️ TEMPORARY FIX: Skipping cache to avoid memory issues")
|
| 443 |
-
cache_path = None # Disable cache usage
|
| 444 |
-
|
| 445 |
-
# 단일 프로세스 시나리오 처리 (멀티프로세싱 비활성화)
|
| 446 |
-
scenario_info_list = []
|
| 447 |
-
missing_props = []
|
| 448 |
-
for scenario in scenarios:
|
| 449 |
-
props = scenario_properties.get(scenario)
|
| 450 |
-
if props is None:
|
| 451 |
-
missing_props.append(scenario)
|
| 452 |
-
continue
|
| 453 |
-
scenario_info_list.append((scenario, props["spectrogram_path"]))
|
| 454 |
-
|
| 455 |
-
if missing_props:
|
| 456 |
-
print("⚠️ Missing metadata for the following scenarios; skipping:")
|
| 457 |
-
for scen in missing_props:
|
| 458 |
-
print(f" - {scen}")
|
| 459 |
-
|
| 460 |
-
print(f"📂 Loading {len(scenario_info_list)} scenarios using single process...")
|
| 461 |
-
|
| 462 |
-
# 단일 프로세스로 처리
|
| 463 |
-
successful_scenarios = 0
|
| 464 |
-
scenario_results = []
|
| 465 |
-
|
| 466 |
-
for scenario_info in tqdm(scenario_info_list, desc="Processing scenarios", unit="scenario"):
|
| 467 |
-
scenario_name = scenario_info[0]
|
| 468 |
-
try:
|
| 469 |
-
result = process_single_scenario(scenario_info)
|
| 470 |
-
if result is not None:
|
| 471 |
-
# 데이터 수집 (시나리오 단위로 누적)
|
| 472 |
-
train_spectrogram_chunks.append(result['train_data'])
|
| 473 |
-
val_spectrogram_chunks.append(result['val_data'])
|
| 474 |
-
train_label_chunks.append(np.zeros(result['train_size'], dtype=np.int64))
|
| 475 |
-
val_label_chunks.append(np.zeros(result['val_size'], dtype=np.int64))
|
| 476 |
-
train_meta_chunks.append(result['train_meta'])
|
| 477 |
-
val_meta_chunks.append(result['val_meta'])
|
| 478 |
-
successful_scenarios += 1
|
| 479 |
-
except Exception as e:
|
| 480 |
-
print(f"❌ Scenario {scenario_name} processing failed: {e}")
|
| 481 |
-
|
| 482 |
-
print(f"✅ Processing completed! Successful scenarios: {successful_scenarios}/{len(scenario_info_list)}")
|
| 483 |
-
|
| 484 |
-
if not train_spectrogram_chunks or not val_spectrogram_chunks:
|
| 485 |
-
raise ValueError("No spectrogram data collected; check scenario configuration.")
|
| 486 |
-
|
| 487 |
-
print("🔄 Collating spectrogram arrays...")
|
| 488 |
-
train_spectrograms = np.concatenate(train_spectrogram_chunks, axis=0).astype(np.float32, copy=False)
|
| 489 |
-
val_spectrograms = np.concatenate(val_spectrogram_chunks, axis=0).astype(np.float32, copy=False)
|
| 490 |
-
train_labels = np.concatenate(train_label_chunks, axis=0)
|
| 491 |
-
val_labels = np.concatenate(val_label_chunks, axis=0)
|
| 492 |
-
|
| 493 |
-
def _concat_metadata_dicts(dict_list):
|
| 494 |
-
if not dict_list:
|
| 495 |
-
return {}
|
| 496 |
-
keys = dict_list[0].keys()
|
| 497 |
-
return {k: np.concatenate([d[k] for d in dict_list], axis=0) for k in keys}
|
| 498 |
-
|
| 499 |
-
train_metadata = _concat_metadata_dicts(train_meta_chunks)
|
| 500 |
-
val_metadata = _concat_metadata_dicts(val_meta_chunks)
|
| 501 |
-
|
| 502 |
-
del train_spectrogram_chunks, val_spectrogram_chunks, train_label_chunks, val_label_chunks
|
| 503 |
-
del train_meta_chunks, val_meta_chunks
|
| 504 |
-
|
| 505 |
-
print(f"Training spectrograms shape: {train_spectrograms.shape}")
|
| 506 |
-
print(f"Validation spectrograms shape: {val_spectrograms.shape}")
|
| 507 |
-
print(f"Memory usage: {train_spectrograms.nbytes + val_spectrograms.nbytes + train_labels.nbytes + val_labels.nbytes:,} bytes")
|
| 508 |
-
|
| 509 |
-
train_mean = float(train_spectrograms.mean())
|
| 510 |
-
train_std = float(train_spectrograms.std())
|
| 511 |
-
if abs(train_std) < 1e-6:
|
| 512 |
-
print("⚠️ Training std near zero, using epsilon for stability")
|
| 513 |
-
train_std = 1e-6
|
| 514 |
-
dataset_normalization = {'mean': train_mean, 'std': train_std, 'normalization': NORMALIZATION_MODE}
|
| 515 |
-
print(f"Dataset normalization stats -> mean: {train_mean:.4f}, std: {train_std:.4f}")
|
| 516 |
-
|
| 517 |
-
# =============================================================================
|
| 518 |
-
# 6. DATA TOKENIZATION
|
| 519 |
-
# - Tokenize spectrogram matrices into input sequences with masking for pretraining
|
| 520 |
-
# =============================================================================
|
| 521 |
-
|
| 522 |
-
# Tokenize training data
|
| 523 |
-
print("🔄 Starting tokenization of training data...")
|
| 524 |
-
preprocessed_train = tokenizer_train(
|
| 525 |
-
train_spectrograms,
|
| 526 |
-
max_len=MAX_LEN,
|
| 527 |
-
masking_percent=MASK_PERCENT,
|
| 528 |
-
mask=True,
|
| 529 |
-
seed=42,
|
| 530 |
-
metadata=train_metadata,
|
| 531 |
-
dataset_stats=dataset_normalization,
|
| 532 |
-
normalization=NORMALIZATION_MODE,
|
| 533 |
-
interleaved=True,
|
| 534 |
-
)
|
| 535 |
-
print("✅ Training data tokenization completed!")
|
| 536 |
-
|
| 537 |
-
# Tokenize validation data (with masking for pretraining evaluation)
|
| 538 |
-
print("🔄 Starting tokenization of validation data...")
|
| 539 |
-
preprocessed_val = tokenizer_train(
|
| 540 |
-
val_spectrograms,
|
| 541 |
-
max_len=MAX_LEN,
|
| 542 |
-
masking_percent=MASK_PERCENT,
|
| 543 |
-
mask=True, # Apply masking for pretraining evaluation
|
| 544 |
-
seed=42,
|
| 545 |
-
metadata=val_metadata,
|
| 546 |
-
dataset_stats=dataset_normalization,
|
| 547 |
-
normalization=NORMALIZATION_MODE,
|
| 548 |
-
interleaved=True,
|
| 549 |
-
)
|
| 550 |
-
print("✅ Validation data tokenization completed!")
|
| 551 |
-
|
| 552 |
-
# =============================================================================
|
| 553 |
-
# 7. TRAIN/VALIDATION DATA SETUP
|
| 554 |
-
# - Use pre-split training and validation data
|
| 555 |
-
# =============================================================================
|
| 556 |
-
|
| 557 |
-
SEED = 42
|
| 558 |
-
torch.manual_seed(SEED)
|
| 559 |
-
np.random.seed(SEED)
|
| 560 |
-
|
| 561 |
-
# Use pre-split data
|
| 562 |
-
train_data = preprocessed_train
|
| 563 |
-
val_data = preprocessed_val
|
| 564 |
-
|
| 565 |
-
# =============================================================================
|
| 566 |
-
# 8. DATALOADER CREATION
|
| 567 |
-
# - Build PyTorch DataLoader objects for batched training and validation
|
| 568 |
-
# =============================================================================
|
| 569 |
-
|
| 570 |
-
# Handle different data formats
|
| 571 |
-
print("🔧 Creating data loaders...")
|
| 572 |
-
|
| 573 |
-
if isinstance(train_data, dict):
|
| 574 |
-
print(f" Training data format: dict with {len(train_data)} sequence lengths")
|
| 575 |
-
# Training data with masking
|
| 576 |
-
train_loaders = create_train_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True)
|
| 577 |
-
else:
|
| 578 |
-
print(f" Training data format: tensor with shape {train_data.shape}")
|
| 579 |
-
# Training data without masking (fallback)
|
| 580 |
-
train_dataset = TensorDataset(train_data)
|
| 581 |
-
train_loaders = {'seq_0': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)}
|
| 582 |
-
|
| 583 |
-
if isinstance(val_data, dict):
|
| 584 |
-
print(f" Validation data format: dict with {len(val_data)} sequence lengths")
|
| 585 |
-
# Validation data with masking
|
| 586 |
-
val_loaders = create_train_dataloader(val_data, batch_size=VAL_BATCH_SIZE, shuffle=False)
|
| 587 |
-
else:
|
| 588 |
-
print(f" Validation data format: tensor with shape {val_data.shape}")
|
| 589 |
-
# Validation data without masking
|
| 590 |
-
val_dataset = TensorDataset(val_data)
|
| 591 |
-
val_loaders = {'seq_0': DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)}
|
| 592 |
-
|
| 593 |
-
print("✅ Data loaders created successfully!")
|
| 594 |
-
|
| 595 |
-
# =============================================================================
|
| 596 |
-
# 9. MODEL INITIALIZATION
|
| 597 |
-
# - Instantiate the LWM transformer model and optionally load pre-trained weights
|
| 598 |
-
# - Wrap with DataParallel for multi-GPU support
|
| 599 |
-
# =============================================================================
|
| 600 |
-
|
| 601 |
-
# Device selection with MPS support for Mac
|
| 602 |
-
print("🔧 Setting up device and GPU configuration...")
|
| 603 |
-
|
| 604 |
-
if torch.cuda.is_available():
|
| 605 |
-
device_count = torch.cuda.device_count()
|
| 606 |
-
print(f" CUDA available: {device_count} GPU(s) detected")
|
| 607 |
-
|
| 608 |
-
device = torch.device("cuda:0")
|
| 609 |
-
|
| 610 |
-
# On Windows, use only available GPUs
|
| 611 |
-
gpu_ids = list(range(device_count)) # 0, 1, 2... auto-detect
|
| 612 |
-
print(f" Using CUDA GPUs: {gpu_ids}")
|
| 613 |
-
|
| 614 |
-
# GPU memory status
|
| 615 |
-
for i in gpu_ids:
|
| 616 |
-
try:
|
| 617 |
-
mem_total = torch.cuda.get_device_properties(i).total_memory / 1024**3
|
| 618 |
-
mem_allocated = torch.cuda.memory_allocated(i) / 1024**3
|
| 619 |
-
print(f" GPU {i}: Total: {mem_total:.1f}GB, Allocated: {mem_allocated:.1f}GB")
|
| 620 |
-
except Exception as e:
|
| 621 |
-
print(f" GPU {i}: Error getting memory info - {e}")
|
| 622 |
-
|
| 623 |
-
elif torch.backends.mps.is_available():
|
| 624 |
-
device = torch.device("mps")
|
| 625 |
-
gpu_ids = [] # MPS doesn't support DataParallel
|
| 626 |
-
print(" Using MPS (Apple Silicon GPU)")
|
| 627 |
-
else:
|
| 628 |
-
device = torch.device("cpu")
|
| 629 |
-
gpu_ids = []
|
| 630 |
-
print(" Using CPU")
|
| 631 |
-
|
| 632 |
-
print(f" Final device: {device}")
|
| 633 |
-
print(f" GPU IDs for DataParallel: {gpu_ids}")
|
| 634 |
-
|
| 635 |
-
print("🤖 Initializing LWM model...")
|
| 636 |
-
print(f" Model parameters: element_length={ELEMENT_LENGTH}, d_model={D_MODEL}, n_layers={N_LAYERS}, max_len={MAX_LEN}, n_heads={N_HEADS}")
|
| 637 |
-
|
| 638 |
-
try:
|
| 639 |
-
model = pretrained_model.lwm(
|
| 640 |
-
element_length=ELEMENT_LENGTH, # Complex spectrograms with real-imag interleaving
|
| 641 |
-
d_model=D_MODEL,
|
| 642 |
-
n_layers=N_LAYERS,
|
| 643 |
-
max_len=MAX_LEN,
|
| 644 |
-
n_heads=N_HEADS,
|
| 645 |
-
dropout=DROPOUT
|
| 646 |
-
)
|
| 647 |
-
print(" ✅ Model created successfully")
|
| 648 |
-
|
| 649 |
-
print(f" Moving model to device: {device}")
|
| 650 |
-
# MPS only supports float32, so set dtype
|
| 651 |
-
if 'mps' in str(device):
|
| 652 |
-
model = model.to(device).float()
|
| 653 |
-
print(" ✅ Model moved to MPS device (float32)")
|
| 654 |
-
else:
|
| 655 |
-
model = model.to(device)
|
| 656 |
-
print(" ✅ Model moved to device successfully")
|
| 657 |
-
|
| 658 |
-
except Exception as e:
|
| 659 |
-
print(f" ❌ Model initialization failed: {e}")
|
| 660 |
-
import traceback
|
| 661 |
-
traceback.print_exc()
|
| 662 |
-
exit(1)
|
| 663 |
-
|
| 664 |
-
# Optional: Load pre-trained model
|
| 665 |
-
load_model = False
|
| 666 |
-
if load_model:
|
| 667 |
-
model.load_state_dict(torch.load("models/model_checkpoint.pth", map_location=device))
|
| 668 |
-
print("Pre-trained model loaded successfully.")
|
| 669 |
-
|
| 670 |
-
# Use DataParallel for multi-GPU support (skip for MPS)
|
| 671 |
-
if gpu_ids:
|
| 672 |
-
model = nn.DataParallel(model, device_ids=gpu_ids)
|
| 673 |
-
print(f"Model loaded successfully on GPU {device.index}")
|
| 674 |
-
else:
|
| 675 |
-
print(f"Model loaded successfully on {device}")
|
| 676 |
-
n_parameters = count_parameters(model)
|
| 677 |
-
print(f"Number of trainable parameters: {n_parameters:,}")
|
| 678 |
-
|
| 679 |
-
# =============================================================================
|
| 680 |
-
# 10. OPTIMIZER AND LEARNING RATE SCHEDULER
|
| 681 |
-
# - Configure AdamW optimizer and a cosine-with-warmup LR schedule based on total steps
|
| 682 |
-
# =============================================================================
|
| 683 |
-
|
| 684 |
-
TOTAL_STEPS = sum(len(loader) for loader in train_loaders.values()) * EPOCHS
|
| 685 |
-
WARMUP_STEPS = sum(len(loader) for loader in train_loaders.values()) * WARMUP_EPOCHS
|
| 686 |
-
|
| 687 |
-
optimizer = AdamW(
|
| 688 |
-
model.parameters(),
|
| 689 |
-
lr=BASE_LR,
|
| 690 |
-
betas=(BETA1, BETA2),
|
| 691 |
-
weight_decay=WEIGHT_DECAY
|
| 692 |
-
)
|
| 693 |
-
|
| 694 |
-
def lr_lambda(current_step):
|
| 695 |
-
if current_step < WARMUP_STEPS:
|
| 696 |
-
return current_step / WARMUP_STEPS
|
| 697 |
-
else:
|
| 698 |
-
scaled_progress = (current_step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
|
| 699 |
-
cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
|
| 700 |
-
return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR
|
| 701 |
-
|
| 702 |
-
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
|
| 703 |
-
|
| 704 |
-
# =============================================================================
|
| 705 |
-
# 11. PRE-TRAINING LOOP
|
| 706 |
-
# - Call the train_lwm utility to run the pre-training epochs, logging metrics and saving models
|
| 707 |
-
# =============================================================================
|
| 708 |
-
|
| 709 |
-
# Create timestamp-based save directory
|
| 710 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 711 |
-
save_dir = f"models/{timestamp}_complex"
|
| 712 |
-
print(f"📁 Models and logs will be saved to: {save_dir}")
|
| 713 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 714 |
-
|
| 715 |
-
stats_path = os.path.join(save_dir, "dataset_stats.json")
|
| 716 |
-
with open(stats_path, 'w') as f:
|
| 717 |
-
json.dump(dataset_normalization, f, indent=2)
|
| 718 |
-
print(f"📝 Saved dataset stats to {stats_path}")
|
| 719 |
-
|
| 720 |
-
comm_selection = sorted(ENABLED_COMM_TYPES) if ENABLED_COMM_TYPES else []
|
| 721 |
-
if comm_selection:
|
| 722 |
-
comm_suffix = "_" + "-".join(comm_selection)
|
| 723 |
-
else:
|
| 724 |
-
comm_suffix = ""
|
| 725 |
-
if comm_selection:
|
| 726 |
-
print(f"[INFO] Communication standards for this run: {', '.join(comm_selection)}")
|
| 727 |
-
|
| 728 |
-
if __name__ == "__main__":
|
| 729 |
-
pretrained_model_output = train_lwm(
|
| 730 |
-
model,
|
| 731 |
-
train_loaders,
|
| 732 |
-
val_loaders,
|
| 733 |
-
optimizer,
|
| 734 |
-
scheduler,
|
| 735 |
-
EPOCHS,
|
| 736 |
-
device=device,
|
| 737 |
-
save_dir=save_dir,
|
| 738 |
-
log_file="training_log.csv",
|
| 739 |
-
checkpoint_suffix=comm_suffix + "_complex",
|
| 740 |
-
)
|
| 741 |
-
print("🎉 Training completed successfully!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pretraining/train_lwm_spectro_contrastive.py
DELETED
|
@@ -1,1450 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
# =============================================================================
|
| 3 |
-
# train_lwm_spectro_contrastive.py - LWM Pretraining with Contrastive Learning
|
| 4 |
-
# Extended from train_lwm_spectro.py to add modulation/mobility contrastive learning
|
| 5 |
-
#
|
| 6 |
-
# Key additions:
|
| 7 |
-
# - Contrastive learning module with projection head
|
| 8 |
-
# - Multi-task loss: MLM + Contrastive (modulation + mobility)
|
| 9 |
-
# - Hard negative mining
|
| 10 |
-
# - Supervised contrastive loss (SupCon)
|
| 11 |
-
# =============================================================================
|
| 12 |
-
|
| 13 |
-
# =============================================================================
|
| 14 |
-
# 1. IMPORTS AND WARNINGS SETUP
|
| 15 |
-
# =============================================================================
|
| 16 |
-
import sys
|
| 17 |
-
import os
|
| 18 |
-
import argparse
|
| 19 |
-
import math
|
| 20 |
-
# Add project root to path (Windows compatible)
|
| 21 |
-
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 22 |
-
sys.path.insert(0, project_root)
|
| 23 |
-
import torch
|
| 24 |
-
import torch.nn as nn
|
| 25 |
-
import torch.nn.functional as F
|
| 26 |
-
from torch.utils.data import DataLoader, random_split, TensorDataset, Dataset
|
| 27 |
-
import torch.optim as optim
|
| 28 |
-
from utils import (generate_spectrograms_and_labels, tokenizer_train,
|
| 29 |
-
create_train_dataloader, count_parameters)
|
| 30 |
-
import numpy as np
|
| 31 |
-
import pretrained_model # Assuming this contains the LWM model definition
|
| 32 |
-
from torch.optim.lr_scheduler import LambdaLR
|
| 33 |
-
from torch.optim import AdamW
|
| 34 |
-
import warnings
|
| 35 |
-
import platform
|
| 36 |
-
import re
|
| 37 |
-
from tqdm import tqdm
|
| 38 |
-
from datetime import datetime
|
| 39 |
-
import concurrent.futures
|
| 40 |
-
import multiprocessing
|
| 41 |
-
from collections import Counter
|
| 42 |
-
from functools import lru_cache
|
| 43 |
-
import json
|
| 44 |
-
from typing import Dict, Tuple, List, Optional
|
| 45 |
-
|
| 46 |
-
SNR_PATTERN = re.compile(r"SNR(-?\d+)dB")
|
| 47 |
-
DOPPLER_MAP = {"static": 0, "pedestrian": 1, "vehicular": 2}
|
| 48 |
-
DOPPLER_INV = {v: k for k, v in DOPPLER_MAP.items()}
|
| 49 |
-
|
| 50 |
-
# Dynamic modulation mapping - will be built from actual data
|
| 51 |
-
MODULATION_MAP = {} # Will be populated: {"BPSK": 0, "QPSK": 1, ...}
|
| 52 |
-
MODULATION_INV = {} # Will be populated: {0: "BPSK", 1: "QPSK", ...}
|
| 53 |
-
|
| 54 |
-
# Standard-to-modulation mapping (for reference only - not used in code)
|
| 55 |
-
# Note: Actual modulations are dynamically discovered from file paths
|
| 56 |
-
# These match the MCS definitions in MATLAB/receiver_pipeline/getMCSDefinitions.m
|
| 57 |
-
STANDARD_MODULATIONS = {
|
| 58 |
-
"WiFi": [
|
| 59 |
-
"BPSK", "QPSK", "16QAM", "64QAM"
|
| 60 |
-
# From getMCSDefinitions.m WiFi MCS table:
|
| 61 |
-
# - MCS 0: BPSK rate1-2
|
| 62 |
-
# - MCS 1-2: QPSK rate1-2, rate3-4
|
| 63 |
-
# - MCS 3-4: 16QAM rate1-2, rate3-4
|
| 64 |
-
# - MCS 5-7: 64QAM rate2-3, rate3-4, rate5-6
|
| 65 |
-
# Note: Your MATLAB pipeline uses 802.11a/g MCS (no 256QAM/1024QAM)
|
| 66 |
-
],
|
| 67 |
-
"LTE": [
|
| 68 |
-
"QPSK", "16QAM", "64QAM"
|
| 69 |
-
# From getMCSDefinitions.m LTE MCS table:
|
| 70 |
-
# - MCS 0-2: QPSK rate1-3, rate1-2, rate3-4
|
| 71 |
-
# - MCS 3-4: 16QAM rate1-2, rate3-4
|
| 72 |
-
# - MCS 5-6: 64QAM rate2-3, rate3-4
|
| 73 |
-
# Note: Your MATLAB pipeline does NOT include 256QAM
|
| 74 |
-
],
|
| 75 |
-
"5G": [
|
| 76 |
-
"QPSK", "16QAM", "64QAM", "256QAM"
|
| 77 |
-
# From getMCSDefinitions.m 5G MCS table:
|
| 78 |
-
# - MCS 0-1: QPSK rate1-3, rate1-2
|
| 79 |
-
# - MCS 2-3: 16QAM rate1-2, rate3-4
|
| 80 |
-
# - MCS 4-5: 64QAM rate2-3, rate3-4
|
| 81 |
-
# - MCS 6: 256QAM rate3-4
|
| 82 |
-
],
|
| 83 |
-
}
|
| 84 |
-
|
| 85 |
-
# Important: This mapping is for documentation only
|
| 86 |
-
# The actual modulations used in your dataset may differ
|
| 87 |
-
# They will be automatically discovered from file paths
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def _parse_metadata(path: str) -> Dict[str, any]:
|
| 91 |
-
"""
|
| 92 |
-
Parse SNR, Doppler, and Modulation from file path.
|
| 93 |
-
Modulation is dynamically extracted and added to global MODULATION_MAP.
|
| 94 |
-
|
| 95 |
-
Returns:
|
| 96 |
-
dict with keys: snr_db, doppler_id, modulation_id, modulation_name
|
| 97 |
-
"""
|
| 98 |
-
global MODULATION_MAP, MODULATION_INV
|
| 99 |
-
|
| 100 |
-
snr_db = 0.0
|
| 101 |
-
doppler_id = 0
|
| 102 |
-
modulation_name = "Unknown"
|
| 103 |
-
|
| 104 |
-
# Parse SNR
|
| 105 |
-
matches = SNR_PATTERN.findall(path)
|
| 106 |
-
if matches:
|
| 107 |
-
try:
|
| 108 |
-
snr_db = float(matches[-1])
|
| 109 |
-
except ValueError:
|
| 110 |
-
snr_db = 0.0
|
| 111 |
-
|
| 112 |
-
# Parse Doppler
|
| 113 |
-
normalized_path = os.path.normpath(path)
|
| 114 |
-
parts = normalized_path.split(os.sep)
|
| 115 |
-
for part in parts:
|
| 116 |
-
if part in DOPPLER_MAP:
|
| 117 |
-
doppler_id = DOPPLER_MAP[part]
|
| 118 |
-
break
|
| 119 |
-
|
| 120 |
-
# Parse Modulation (dynamic - look for common modulation patterns)
|
| 121 |
-
# Patterns: BPSK, QPSK, 8PSK, 16QAM, 32QAM, 64QAM, 256QAM, 1024QAM, etc.
|
| 122 |
-
# Note: We ONLY use explicit modulation names in the path, not code rates
|
| 123 |
-
# since the same code rate can be used with different modulations
|
| 124 |
-
modulation_patterns = [
|
| 125 |
-
r"BPSK",
|
| 126 |
-
r"QPSK",
|
| 127 |
-
r"8PSK",
|
| 128 |
-
r"16QAM",
|
| 129 |
-
r"32QAM",
|
| 130 |
-
r"64QAM",
|
| 131 |
-
r"128QAM",
|
| 132 |
-
r"256QAM",
|
| 133 |
-
r"512QAM",
|
| 134 |
-
r"1024QAM",
|
| 135 |
-
]
|
| 136 |
-
|
| 137 |
-
for pattern in modulation_patterns:
|
| 138 |
-
if re.search(pattern, path, re.IGNORECASE):
|
| 139 |
-
modulation_name = pattern
|
| 140 |
-
break
|
| 141 |
-
|
| 142 |
-
# Add to global mapping if new
|
| 143 |
-
if modulation_name != "Unknown" and modulation_name not in MODULATION_MAP:
|
| 144 |
-
modulation_id = len(MODULATION_MAP)
|
| 145 |
-
MODULATION_MAP[modulation_name] = modulation_id
|
| 146 |
-
MODULATION_INV[modulation_id] = modulation_name
|
| 147 |
-
elif modulation_name in MODULATION_MAP:
|
| 148 |
-
modulation_id = MODULATION_MAP[modulation_name]
|
| 149 |
-
else:
|
| 150 |
-
modulation_id = -1 # Unknown
|
| 151 |
-
|
| 152 |
-
return {
|
| 153 |
-
'snr_db': snr_db,
|
| 154 |
-
'doppler_id': doppler_id,
|
| 155 |
-
'modulation_id': modulation_id,
|
| 156 |
-
'modulation_name': modulation_name
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
warnings.filterwarnings("ignore", category=UserWarning)
|
| 161 |
-
|
| 162 |
-
# Use simple progress display instead of tqdm on Windows
|
| 163 |
-
USE_TQDM = platform.system() != 'Windows'
|
| 164 |
-
|
| 165 |
-
# CPU 코어 수 계산 (메모리 사용량 고려하여 보수적으로 설정)
|
| 166 |
-
total_cores = multiprocessing.cpu_count()
|
| 167 |
-
if total_cores >= 16:
|
| 168 |
-
MAX_WORKERS = min(8, total_cores // 2)
|
| 169 |
-
else:
|
| 170 |
-
MAX_WORKERS = max(2, total_cores // 2)
|
| 171 |
-
print(f"🚀 Using {MAX_WORKERS}/{total_cores} CPU cores for parallel processing")
|
| 172 |
-
|
| 173 |
-
PRINT_CONVERSION_STATS = os.environ.get("LWM_PRINT_CONVERSION_STATS", "").strip().lower() in {"1", "true", "yes"}
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
# =============================================================================
|
| 177 |
-
# 2. CONTRASTIVE LEARNING COMPONENTS
|
| 178 |
-
# =============================================================================
|
| 179 |
-
|
| 180 |
-
class ProjectionHead(nn.Module):
|
| 181 |
-
"""
|
| 182 |
-
Projection head for contrastive learning (SimCLR-style).
|
| 183 |
-
Projects encoder output to a lower-dimensional space for contrastive loss.
|
| 184 |
-
"""
|
| 185 |
-
def __init__(self, d_model: int, projection_dim: int = 128):
|
| 186 |
-
super().__init__()
|
| 187 |
-
self.projection = nn.Sequential(
|
| 188 |
-
nn.Linear(d_model, d_model),
|
| 189 |
-
nn.ReLU(),
|
| 190 |
-
nn.Linear(d_model, projection_dim)
|
| 191 |
-
)
|
| 192 |
-
|
| 193 |
-
def forward(self, x):
|
| 194 |
-
"""
|
| 195 |
-
Args:
|
| 196 |
-
x: (batch, seq_len, d_model) - Encoder output
|
| 197 |
-
Returns:
|
| 198 |
-
z: (batch, projection_dim) - Projected embeddings
|
| 199 |
-
"""
|
| 200 |
-
# Global average pooling over sequence dimension
|
| 201 |
-
pooled = x.mean(dim=1) # (batch, d_model)
|
| 202 |
-
z = self.projection(pooled) # (batch, projection_dim)
|
| 203 |
-
z = F.normalize(z, dim=1) # L2 normalize
|
| 204 |
-
return z
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
class ContrastiveLWM(nn.Module):
|
| 208 |
-
"""
|
| 209 |
-
LWM model with contrastive learning projection heads.
|
| 210 |
-
"""
|
| 211 |
-
def __init__(self, lwm_encoder, projection_dim: int = 128, input_dim: int = 32):
|
| 212 |
-
super().__init__()
|
| 213 |
-
self.encoder = lwm_encoder
|
| 214 |
-
|
| 215 |
-
# MLM reconstruction head: project d_model back to input_dim
|
| 216 |
-
self.mlm_head = nn.Linear(lwm_encoder.d_model, input_dim)
|
| 217 |
-
|
| 218 |
-
# Separate projection heads for modulation and mobility
|
| 219 |
-
self.modulation_projection = ProjectionHead(lwm_encoder.d_model, projection_dim)
|
| 220 |
-
self.mobility_projection = ProjectionHead(lwm_encoder.d_model, projection_dim)
|
| 221 |
-
|
| 222 |
-
def forward(self, x, return_projections: bool = False):
|
| 223 |
-
"""
|
| 224 |
-
Args:
|
| 225 |
-
x: Input tokens
|
| 226 |
-
return_projections: If True, return contrastive projections and MLM predictions
|
| 227 |
-
|
| 228 |
-
Returns:
|
| 229 |
-
If return_projections:
|
| 230 |
-
mlm_predictions, z_modulation, z_mobility
|
| 231 |
-
Else:
|
| 232 |
-
mlm_predictions (for MLM task only)
|
| 233 |
-
"""
|
| 234 |
-
# Forward through encoder
|
| 235 |
-
encoder_out = self.encoder(x) # (batch, seq_len, d_model)
|
| 236 |
-
|
| 237 |
-
# MLM prediction head (always compute for reconstruction)
|
| 238 |
-
mlm_predictions = self.mlm_head(encoder_out) # (batch, seq_len, input_dim)
|
| 239 |
-
|
| 240 |
-
if return_projections:
|
| 241 |
-
z_mod = self.modulation_projection(encoder_out)
|
| 242 |
-
z_mob = self.mobility_projection(encoder_out)
|
| 243 |
-
return mlm_predictions, z_mod, z_mob
|
| 244 |
-
else:
|
| 245 |
-
return mlm_predictions
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
def supervised_contrastive_loss(
|
| 249 |
-
embeddings: torch.Tensor,
|
| 250 |
-
labels: torch.Tensor,
|
| 251 |
-
temperature: float = 0.07,
|
| 252 |
-
base_temperature: float = 0.07
|
| 253 |
-
) -> torch.Tensor:
|
| 254 |
-
"""
|
| 255 |
-
Supervised Contrastive Loss (SupCon) from Khosla et al. 2020.
|
| 256 |
-
|
| 257 |
-
Args:
|
| 258 |
-
embeddings: (batch, dim) - Normalized embeddings
|
| 259 |
-
labels: (batch,) - Class labels
|
| 260 |
-
temperature: Temperature scaling
|
| 261 |
-
base_temperature: Base temperature for normalization
|
| 262 |
-
|
| 263 |
-
Returns:
|
| 264 |
-
loss: Scalar SupCon loss
|
| 265 |
-
"""
|
| 266 |
-
batch_size = embeddings.size(0)
|
| 267 |
-
|
| 268 |
-
# Compute similarity matrix
|
| 269 |
-
sim_matrix = torch.matmul(embeddings, embeddings.T) / temperature # (batch, batch)
|
| 270 |
-
|
| 271 |
-
# Mask for positives (same label)
|
| 272 |
-
labels = labels.contiguous().view(-1, 1)
|
| 273 |
-
mask_pos = torch.eq(labels, labels.T).float().to(embeddings.device) # (batch, batch)
|
| 274 |
-
|
| 275 |
-
# Remove diagonal (self-similarity)
|
| 276 |
-
logits_mask = torch.scatter(
|
| 277 |
-
torch.ones_like(mask_pos),
|
| 278 |
-
1,
|
| 279 |
-
torch.arange(batch_size).view(-1, 1).to(embeddings.device),
|
| 280 |
-
0
|
| 281 |
-
)
|
| 282 |
-
mask_pos = mask_pos * logits_mask
|
| 283 |
-
|
| 284 |
-
# Compute log probabilities
|
| 285 |
-
exp_sim = torch.exp(sim_matrix) * logits_mask
|
| 286 |
-
log_prob = sim_matrix - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8)
|
| 287 |
-
|
| 288 |
-
# Mean over positives
|
| 289 |
-
mean_log_prob_pos = (mask_pos * log_prob).sum(dim=1) / (mask_pos.sum(dim=1) + 1e-8)
|
| 290 |
-
|
| 291 |
-
# Loss
|
| 292 |
-
loss = -(temperature / base_temperature) * mean_log_prob_pos
|
| 293 |
-
loss = loss.mean()
|
| 294 |
-
|
| 295 |
-
return loss
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
class ContrastiveDataset(Dataset):
|
| 299 |
-
"""
|
| 300 |
-
Dataset wrapper that provides contrastive learning triplets.
|
| 301 |
-
"""
|
| 302 |
-
def __init__(
|
| 303 |
-
self,
|
| 304 |
-
spectrograms: np.ndarray,
|
| 305 |
-
labels: np.ndarray,
|
| 306 |
-
metadata: Dict[str, np.ndarray],
|
| 307 |
-
indices_by_modulation: Dict[int, List[int]],
|
| 308 |
-
indices_by_mobility: Dict[int, List[int]]
|
| 309 |
-
):
|
| 310 |
-
self.spectrograms = spectrograms
|
| 311 |
-
self.labels = labels
|
| 312 |
-
self.metadata = metadata
|
| 313 |
-
self.indices_by_modulation = indices_by_modulation
|
| 314 |
-
self.indices_by_mobility = indices_by_mobility
|
| 315 |
-
|
| 316 |
-
def __len__(self):
|
| 317 |
-
return len(self.spectrograms)
|
| 318 |
-
|
| 319 |
-
def __getitem__(self, idx):
|
| 320 |
-
"""
|
| 321 |
-
Returns anchor sample with its metadata.
|
| 322 |
-
"""
|
| 323 |
-
spectrogram = self.spectrograms[idx]
|
| 324 |
-
label = self.labels[idx]
|
| 325 |
-
|
| 326 |
-
metadata = {
|
| 327 |
-
'snr_db': self.metadata['snr_db'][idx],
|
| 328 |
-
'doppler_id': self.metadata['doppler_id'][idx],
|
| 329 |
-
'modulation_id': self.metadata['modulation_id'][idx]
|
| 330 |
-
}
|
| 331 |
-
|
| 332 |
-
return spectrogram, label, metadata
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
# =============================================================================
|
| 336 |
-
# 3. DATA CONVERSION AND PREPROCESSING
|
| 337 |
-
# =============================================================================
|
| 338 |
-
|
| 339 |
-
def convert_complex_to_interleaved(spectrograms):
|
| 340 |
-
"""
|
| 341 |
-
Convert complex-valued spectrograms to real-imaginary interleaved format.
|
| 342 |
-
|
| 343 |
-
Args:
|
| 344 |
-
spectrograms (np.ndarray): Complex-valued array of shape (n_samples, n_rows, n_cols)
|
| 345 |
-
or (n_samples, 1, n_rows, n_cols)
|
| 346 |
-
|
| 347 |
-
Returns:
|
| 348 |
-
np.ndarray: Real-valued array with interleaved real/imag parts
|
| 349 |
-
Shape: (n_samples, n_rows, n_cols * 2)
|
| 350 |
-
"""
|
| 351 |
-
# Handle different input shapes
|
| 352 |
-
if spectrograms.ndim == 4:
|
| 353 |
-
spectrograms = spectrograms[:, 0, :, :]
|
| 354 |
-
|
| 355 |
-
# Check if data is complex
|
| 356 |
-
if np.iscomplexobj(spectrograms):
|
| 357 |
-
n_samples, n_rows, n_cols = spectrograms.shape
|
| 358 |
-
|
| 359 |
-
# Extract real and imaginary parts
|
| 360 |
-
flat_real = spectrograms.real
|
| 361 |
-
flat_imag = spectrograms.imag
|
| 362 |
-
|
| 363 |
-
# Interleave real and imaginary parts along the last axis
|
| 364 |
-
interleaved = np.empty((n_samples, n_rows, n_cols * 2), dtype=np.float32)
|
| 365 |
-
interleaved[:, :, 0::2] = flat_real # Even indices: real parts
|
| 366 |
-
interleaved[:, :, 1::2] = flat_imag # Odd indices: imaginary parts
|
| 367 |
-
|
| 368 |
-
if PRINT_CONVERSION_STATS:
|
| 369 |
-
print(f" ℹ️ Converted complex spectrograms: {spectrograms.shape} -> {interleaved.shape}")
|
| 370 |
-
print(f" Real part range: [{flat_real.min():.2e}, {flat_real.max():.2e}]")
|
| 371 |
-
print(f" Imag part range: [{flat_imag.min():.2e}, {flat_imag.max():.2e}]")
|
| 372 |
-
|
| 373 |
-
return interleaved
|
| 374 |
-
else:
|
| 375 |
-
# Already real-valued
|
| 376 |
-
if spectrograms.ndim == 3:
|
| 377 |
-
if PRINT_CONVERSION_STATS:
|
| 378 |
-
print(f" ℹ️ Data is already real-valued: {spectrograms.shape}")
|
| 379 |
-
return spectrograms
|
| 380 |
-
else:
|
| 381 |
-
raise ValueError(f"Unexpected spectrogram shape: {spectrograms.shape}")
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
def process_single_scenario(scenario_info):
|
| 385 |
-
"""단일 시나리오를 처리하는 함수 (멀티프로세싱용)"""
|
| 386 |
-
scenario_name, spectrogram_path = scenario_info
|
| 387 |
-
|
| 388 |
-
try:
|
| 389 |
-
# Parse metadata from path
|
| 390 |
-
path_metadata = _parse_metadata(spectrogram_path)
|
| 391 |
-
|
| 392 |
-
# 메모리 효율성을 위해 필요한 데이터만 로드
|
| 393 |
-
scenario_spectrograms, scenario_labels = generate_spectrograms_and_labels(
|
| 394 |
-
scenario_name=scenario_name,
|
| 395 |
-
spectrogram_path=spectrogram_path,
|
| 396 |
-
cache_path=None, # 메모리 문제로 캐시 비활성화
|
| 397 |
-
)
|
| 398 |
-
|
| 399 |
-
# Validate load
|
| 400 |
-
if scenario_spectrograms is None or (hasattr(scenario_spectrograms, 'size') and scenario_spectrograms.size == 0):
|
| 401 |
-
print(f" ⚠️ No data loaded from: {spectrogram_path}")
|
| 402 |
-
return None
|
| 403 |
-
|
| 404 |
-
# Convert complex spectrograms to interleaved real-imaginary format
|
| 405 |
-
scenario_spectrograms = convert_complex_to_interleaved(scenario_spectrograms)
|
| 406 |
-
|
| 407 |
-
# 데이터 분할 (인덱스만 계산)
|
| 408 |
-
total_samples = len(scenario_spectrograms)
|
| 409 |
-
train_size = int(0.8 * total_samples)
|
| 410 |
-
val_size = total_samples - train_size
|
| 411 |
-
|
| 412 |
-
# 메모리 절약을 위해 numpy array로 유지
|
| 413 |
-
train_data = np.array(scenario_spectrograms[:train_size], dtype=np.float32)
|
| 414 |
-
val_data = np.array(scenario_spectrograms[train_size:], dtype=np.float32)
|
| 415 |
-
|
| 416 |
-
# Metadata arrays
|
| 417 |
-
snr_array = np.full(total_samples, path_metadata['snr_db'], dtype=np.float32)
|
| 418 |
-
doppler_array = np.full(total_samples, path_metadata['doppler_id'], dtype=np.int64)
|
| 419 |
-
modulation_array = np.full(total_samples, path_metadata['modulation_id'], dtype=np.int64)
|
| 420 |
-
|
| 421 |
-
train_meta = {
|
| 422 |
-
'snr_db': snr_array[:train_size],
|
| 423 |
-
'doppler_id': doppler_array[:train_size],
|
| 424 |
-
'modulation_id': modulation_array[:train_size],
|
| 425 |
-
}
|
| 426 |
-
val_meta = {
|
| 427 |
-
'snr_db': snr_array[train_size:],
|
| 428 |
-
'doppler_id': doppler_array[train_size:],
|
| 429 |
-
'modulation_id': modulation_array[train_size:],
|
| 430 |
-
}
|
| 431 |
-
|
| 432 |
-
# 불필요한 데이터 즉시 삭제
|
| 433 |
-
del scenario_spectrograms
|
| 434 |
-
|
| 435 |
-
return {
|
| 436 |
-
'scenario': scenario_name,
|
| 437 |
-
'train_data': train_data,
|
| 438 |
-
'val_data': val_data,
|
| 439 |
-
'train_meta': train_meta,
|
| 440 |
-
'val_meta': val_meta,
|
| 441 |
-
'train_size': len(train_data),
|
| 442 |
-
'val_size': len(val_data)
|
| 443 |
-
}
|
| 444 |
-
except Exception as e:
|
| 445 |
-
print(f"❌ Error processing scenario {scenario_name}: {e}")
|
| 446 |
-
import traceback
|
| 447 |
-
traceback.print_exc()
|
| 448 |
-
return None
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
# =============================================================================
|
| 452 |
-
# 4. SCENARIO LIST AND PROPERTIES (Same as original)
|
| 453 |
-
# =============================================================================
|
| 454 |
-
|
| 455 |
-
SUPPORTED_COMM_TYPES = {"LTE", "WiFi", "5G"}
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
def _parse_standard_args():
|
| 459 |
-
parser = argparse.ArgumentParser(add_help=False)
|
| 460 |
-
parser.add_argument('--standards', nargs='+', choices=SUPPORTED_COMM_TYPES,
|
| 461 |
-
help='Specify one or more communication types to include (default: all).')
|
| 462 |
-
for comm in SUPPORTED_COMM_TYPES:
|
| 463 |
-
parser.add_argument(f'--{comm}', dest=f'flag_{comm}', action='store_true',
|
| 464 |
-
help=f'Include only {comm} data (can be combined).')
|
| 465 |
-
parser.add_argument('--city', '--cities', dest='cities', nargs='+',
|
| 466 |
-
help='Limit scenarios to one or more city prefixes (e.g., "0" or "city_0").')
|
| 467 |
-
parser.add_argument(
|
| 468 |
-
'--normalization',
|
| 469 |
-
choices=('per_sample', 'dataset'),
|
| 470 |
-
default='per_sample',
|
| 471 |
-
help='Normalization mode applied during tokenization (default: %(default)s).'
|
| 472 |
-
)
|
| 473 |
-
parser.add_argument('--help', action='help')
|
| 474 |
-
|
| 475 |
-
args, remaining = parser.parse_known_args()
|
| 476 |
-
|
| 477 |
-
enabled = set(SUPPORTED_COMM_TYPES)
|
| 478 |
-
if args.standards:
|
| 479 |
-
enabled = set(args.standards)
|
| 480 |
-
else:
|
| 481 |
-
flagged = {comm for comm in SUPPORTED_COMM_TYPES if getattr(args, f'flag_{comm}', False)}
|
| 482 |
-
if flagged:
|
| 483 |
-
enabled = flagged
|
| 484 |
-
|
| 485 |
-
selected_cities: list[str] | None = None
|
| 486 |
-
if args.cities:
|
| 487 |
-
selected_cities = []
|
| 488 |
-
for city_token in args.cities:
|
| 489 |
-
token = str(city_token).strip()
|
| 490 |
-
if not token:
|
| 491 |
-
continue
|
| 492 |
-
if token.startswith('city_'):
|
| 493 |
-
selected_cities.append(token)
|
| 494 |
-
else:
|
| 495 |
-
selected_cities.append(f'city_{token}')
|
| 496 |
-
if not selected_cities:
|
| 497 |
-
selected_cities = None
|
| 498 |
-
|
| 499 |
-
sys.argv = [sys.argv[0]] + remaining
|
| 500 |
-
return enabled, selected_cities, args.normalization
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
ENABLED_COMM_TYPES, ENABLED_CITY_PREFIXES, NORMALIZATION_MODE = _parse_standard_args()
|
| 504 |
-
MAX_SCENARIOS = int(os.environ.get("LWM_MAX_SCENARIOS", "0")) or None
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
def _extract_scenario_token(file_path):
|
| 508 |
-
"""Derive the base scenario token (without city) from the file path."""
|
| 509 |
-
normalized_path = os.path.normpath(file_path)
|
| 510 |
-
parts = normalized_path.split(os.sep)
|
| 511 |
-
|
| 512 |
-
scenario_parts = []
|
| 513 |
-
for i, part in enumerate(parts):
|
| 514 |
-
if part in SUPPORTED_COMM_TYPES:
|
| 515 |
-
trailing = parts[i:i + 5]
|
| 516 |
-
if trailing:
|
| 517 |
-
scenario_parts = trailing[:5]
|
| 518 |
-
break
|
| 519 |
-
|
| 520 |
-
if not scenario_parts:
|
| 521 |
-
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
| 522 |
-
if base_name.startswith('spectrogram_'):
|
| 523 |
-
tokens = base_name.split('_')[1:]
|
| 524 |
-
if tokens and tokens[0] in SUPPORTED_COMM_TYPES:
|
| 525 |
-
scenario_parts = tokens[:5] if len(tokens) >= 5 else tokens
|
| 526 |
-
|
| 527 |
-
return '_'.join(scenario_parts) if scenario_parts else None
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
@lru_cache(maxsize=1)
|
| 531 |
-
def _collect_scenario_file_info():
|
| 532 |
-
import glob
|
| 533 |
-
|
| 534 |
-
scenario_entries = []
|
| 535 |
-
|
| 536 |
-
# New MATLAB receiver pipeline output
|
| 537 |
-
new_base = os.path.join('ls_data', 'MATLAB', 'receiver_pipeline')
|
| 538 |
-
if os.path.isdir(new_base):
|
| 539 |
-
patterns = [os.path.join(new_base, '*', '**', 'spectrogram_*.mat')]
|
| 540 |
-
for pattern in patterns:
|
| 541 |
-
for file_path in sorted(glob.glob(pattern, recursive=True)):
|
| 542 |
-
norm = os.path.normpath(file_path)
|
| 543 |
-
parts = norm.split(os.sep)
|
| 544 |
-
try:
|
| 545 |
-
idx = parts.index('receiver_pipeline')
|
| 546 |
-
city_name = parts[idx + 1] if idx + 1 < len(parts) else 'receiver_pipeline'
|
| 547 |
-
except ValueError:
|
| 548 |
-
city_name = 'receiver_pipeline'
|
| 549 |
-
|
| 550 |
-
base_token = _extract_scenario_token(file_path)
|
| 551 |
-
if not base_token:
|
| 552 |
-
continue
|
| 553 |
-
comm_type = base_token.split('_', 1)[0]
|
| 554 |
-
if comm_type not in ENABLED_COMM_TYPES:
|
| 555 |
-
continue
|
| 556 |
-
scenario_id = f"{city_name}::{base_token}"
|
| 557 |
-
scenario_entries.append((scenario_id, file_path, city_name, base_token))
|
| 558 |
-
|
| 559 |
-
# Legacy repo layouts under spectrograms/city_*
|
| 560 |
-
import glob as _glob
|
| 561 |
-
for city_dir in sorted(_glob.glob(os.path.join('spectrograms', 'city_*'))):
|
| 562 |
-
if not os.path.isdir(city_dir):
|
| 563 |
-
continue
|
| 564 |
-
city_name = os.path.basename(city_dir)
|
| 565 |
-
if ENABLED_CITY_PREFIXES:
|
| 566 |
-
if not any(city_name.startswith(prefix) for prefix in ENABLED_CITY_PREFIXES):
|
| 567 |
-
continue
|
| 568 |
-
candidate_patterns = [
|
| 569 |
-
os.path.join(city_dir, '**', 'complex_raw', '**', 'spectrogram_*.mat'),
|
| 570 |
-
os.path.join(city_dir, '**', 'spectrogram_*.mat'),
|
| 571 |
-
]
|
| 572 |
-
city_files = []
|
| 573 |
-
seen_paths = set()
|
| 574 |
-
for pattern in candidate_patterns:
|
| 575 |
-
for file_path in sorted(_glob.glob(pattern, recursive=True)):
|
| 576 |
-
if not file_path.lower().endswith('.mat'):
|
| 577 |
-
continue
|
| 578 |
-
if file_path in seen_paths:
|
| 579 |
-
continue
|
| 580 |
-
seen_paths.add(file_path)
|
| 581 |
-
city_files.append(file_path)
|
| 582 |
-
|
| 583 |
-
if not city_files:
|
| 584 |
-
pattern = os.path.join(city_dir, '**', '512FFT', '**', 'spectrograms', '*.pkl')
|
| 585 |
-
city_files = sorted(_glob.glob(pattern, recursive=True))
|
| 586 |
-
|
| 587 |
-
for file_path in city_files:
|
| 588 |
-
base_token = _extract_scenario_token(file_path)
|
| 589 |
-
if not base_token:
|
| 590 |
-
continue
|
| 591 |
-
comm_type = base_token.split('_', 1)[0]
|
| 592 |
-
if comm_type not in ENABLED_COMM_TYPES:
|
| 593 |
-
continue
|
| 594 |
-
scenario_id = f"{city_name}::{base_token}"
|
| 595 |
-
scenario_entries.append((scenario_id, file_path, city_name, base_token))
|
| 596 |
-
|
| 597 |
-
if MAX_SCENARIOS:
|
| 598 |
-
scenario_entries = scenario_entries[:MAX_SCENARIOS]
|
| 599 |
-
|
| 600 |
-
return scenario_entries
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
def scenarios_list():
|
| 604 |
-
scenario_entries = _collect_scenario_file_info()
|
| 605 |
-
|
| 606 |
-
if not scenario_entries:
|
| 607 |
-
print("⚠️ No spectrogram files found for pretraining.")
|
| 608 |
-
return np.array([])
|
| 609 |
-
|
| 610 |
-
print(f"Enabled communication types: {sorted(ENABLED_COMM_TYPES)}")
|
| 611 |
-
if ENABLED_CITY_PREFIXES:
|
| 612 |
-
print(f"Selected city prefixes: {sorted(ENABLED_CITY_PREFIXES)}")
|
| 613 |
-
city_counts = Counter(entry[2] for entry in scenario_entries)
|
| 614 |
-
print("Using scenarios from the following city datasets:")
|
| 615 |
-
for city_name, count in city_counts.items():
|
| 616 |
-
print(f" - {city_name}: {count} files")
|
| 617 |
-
|
| 618 |
-
print(f"Total scenarios selected: {len(scenario_entries)}")
|
| 619 |
-
return np.array([entry[0] for entry in scenario_entries])
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
def scenario_prop():
|
| 623 |
-
scenario_entries = _collect_scenario_file_info()
|
| 624 |
-
|
| 625 |
-
row_column_users = {}
|
| 626 |
-
for scenario_id, file_path, city_name, _ in scenario_entries:
|
| 627 |
-
row_column_users[scenario_id] = {
|
| 628 |
-
'spectrogram_path': file_path,
|
| 629 |
-
'cache_path': os.path.join('spectrograms', city_name, 'spectrogram_cache_128x128.pkl')
|
| 630 |
-
}
|
| 631 |
-
|
| 632 |
-
return row_column_users
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
# =============================================================================
|
| 636 |
-
# 5. TRAINING PARAMETERS AND HYPERPARAMETERS
|
| 637 |
-
# =============================================================================
|
| 638 |
-
|
| 639 |
-
EPOCHS = 20
|
| 640 |
-
BATCH_SIZE = 64
|
| 641 |
-
VAL_BATCH_SIZE = 64
|
| 642 |
-
WARMUP_EPOCHS = 5
|
| 643 |
-
BASE_LR = 5e-4
|
| 644 |
-
MIN_LR = 1e-5 # Base LR의 1/50 (was 1e-8, too small for effective learning)
|
| 645 |
-
|
| 646 |
-
# Gradient accumulation for larger effective batch size
|
| 647 |
-
ACCUMULATION_STEPS = 4 # Effective batch size = 64 × 4 = 256
|
| 648 |
-
|
| 649 |
-
# Model parameters
|
| 650 |
-
N_ROWS = 4
|
| 651 |
-
N_COLUMNS = 4
|
| 652 |
-
ELEMENT_LENGTH = N_ROWS * N_COLUMNS * 2 # Complex spectrograms
|
| 653 |
-
D_MODEL = 128
|
| 654 |
-
MAX_LEN = 1025
|
| 655 |
-
N_LAYERS = 12
|
| 656 |
-
device_idx = 0
|
| 657 |
-
WEIGHT_DECAY = 0.05
|
| 658 |
-
BETA1 = 0.9
|
| 659 |
-
BETA2 = 0.999
|
| 660 |
-
MASK_PERCENT = 0.6
|
| 661 |
-
N_HEADS = 8
|
| 662 |
-
DROPOUT = 0.1
|
| 663 |
-
|
| 664 |
-
# Contrastive learning parameters
|
| 665 |
-
PROJECTION_DIM = 128
|
| 666 |
-
CONTRASTIVE_TEMPERATURE = 0.07
|
| 667 |
-
CONTRASTIVE_WEIGHT_MODULATION = 50.0 # Increased from 0.5 to match MLM loss scale
|
| 668 |
-
CONTRASTIVE_WEIGHT_MOBILITY = 30.0 # Increased from 0.3 to match MLM loss scale
|
| 669 |
-
MLM_WEIGHT = 1.0
|
| 670 |
-
|
| 671 |
-
print(f"📊 Model configuration for complex spectrograms with contrastive learning:")
|
| 672 |
-
print(f" Patch size: {N_ROWS}x{N_COLUMNS}")
|
| 673 |
-
print(f" Element length: {ELEMENT_LENGTH} (includes real+imag interleaving)")
|
| 674 |
-
print(f" Max sequence length: {MAX_LEN}")
|
| 675 |
-
print(f" Batch size: {BATCH_SIZE} (physical), {BATCH_SIZE * ACCUMULATION_STEPS} (effective)")
|
| 676 |
-
print(f" Gradient accumulation steps: {ACCUMULATION_STEPS}")
|
| 677 |
-
print(f" Projection dim: {PROJECTION_DIM}")
|
| 678 |
-
print(f" Contrastive temperature: {CONTRASTIVE_TEMPERATURE}")
|
| 679 |
-
print(f" Loss weights - MLM: {MLM_WEIGHT}, Modulation: {CONTRASTIVE_WEIGHT_MODULATION}, Mobility: {CONTRASTIVE_WEIGHT_MOBILITY}")
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
# =============================================================================
|
| 683 |
-
# 6. DATA GENERATION AND LOADING
|
| 684 |
-
# =============================================================================
|
| 685 |
-
|
| 686 |
-
scenarios = scenarios_list()
|
| 687 |
-
scenario_properties = scenario_prop()
|
| 688 |
-
|
| 689 |
-
train_spectrogram_chunks = []
|
| 690 |
-
val_spectrogram_chunks = []
|
| 691 |
-
train_label_chunks = []
|
| 692 |
-
val_label_chunks = []
|
| 693 |
-
train_meta_chunks = []
|
| 694 |
-
val_meta_chunks = []
|
| 695 |
-
|
| 696 |
-
print(f"📂 Loading {len(scenarios)} scenarios...")
|
| 697 |
-
|
| 698 |
-
scenario_info_list = []
|
| 699 |
-
missing_props = []
|
| 700 |
-
for scenario in scenarios:
|
| 701 |
-
props = scenario_properties.get(scenario)
|
| 702 |
-
if props is None:
|
| 703 |
-
missing_props.append(scenario)
|
| 704 |
-
continue
|
| 705 |
-
scenario_info_list.append((scenario, props["spectrogram_path"]))
|
| 706 |
-
|
| 707 |
-
if missing_props:
|
| 708 |
-
print("⚠️ Missing metadata for the following scenarios; skipping:")
|
| 709 |
-
for scen in missing_props:
|
| 710 |
-
print(f" - {scen}")
|
| 711 |
-
|
| 712 |
-
print(f"📂 Loading {len(scenario_info_list)} scenarios using {MAX_WORKERS} workers...")
|
| 713 |
-
|
| 714 |
-
successful_scenarios = 0
|
| 715 |
-
|
| 716 |
-
# Parallel processing with progress bar
|
| 717 |
-
from multiprocessing import Pool
|
| 718 |
-
with Pool(processes=MAX_WORKERS) as pool:
|
| 719 |
-
results = list(tqdm(
|
| 720 |
-
pool.imap(process_single_scenario, scenario_info_list),
|
| 721 |
-
total=len(scenario_info_list),
|
| 722 |
-
desc="Processing scenarios",
|
| 723 |
-
unit="scenario"
|
| 724 |
-
))
|
| 725 |
-
|
| 726 |
-
for result in results:
|
| 727 |
-
if result is not None:
|
| 728 |
-
train_spectrogram_chunks.append(result['train_data'])
|
| 729 |
-
val_spectrogram_chunks.append(result['val_data'])
|
| 730 |
-
train_label_chunks.append(np.zeros(result['train_size'], dtype=np.int64))
|
| 731 |
-
val_label_chunks.append(np.zeros(result['val_size'], dtype=np.int64))
|
| 732 |
-
train_meta_chunks.append(result['train_meta'])
|
| 733 |
-
val_meta_chunks.append(result['val_meta'])
|
| 734 |
-
successful_scenarios += 1
|
| 735 |
-
|
| 736 |
-
print(f"✅ Processing completed! Successful scenarios: {successful_scenarios}/{len(scenario_info_list)}")
|
| 737 |
-
|
| 738 |
-
if not train_spectrogram_chunks or not val_spectrogram_chunks:
|
| 739 |
-
raise ValueError("No spectrogram data collected; check scenario configuration.")
|
| 740 |
-
|
| 741 |
-
print("🔄 Collating spectrogram arrays...")
|
| 742 |
-
train_spectrograms = np.concatenate(train_spectrogram_chunks, axis=0).astype(np.float32, copy=False)
|
| 743 |
-
val_spectrograms = np.concatenate(val_spectrogram_chunks, axis=0).astype(np.float32, copy=False)
|
| 744 |
-
train_labels = np.concatenate(train_label_chunks, axis=0)
|
| 745 |
-
val_labels = np.concatenate(val_label_chunks, axis=0)
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
def _concat_metadata_dicts(dict_list):
|
| 749 |
-
if not dict_list:
|
| 750 |
-
return {}
|
| 751 |
-
keys = dict_list[0].keys()
|
| 752 |
-
return {k: np.concatenate([d[k] for d in dict_list], axis=0) for k in keys}
# Collate per-scenario metadata dicts into single train/val metadata dicts.
train_metadata = _concat_metadata_dicts(train_meta_chunks)
val_metadata = _concat_metadata_dicts(val_meta_chunks)

# Free the per-scenario chunks now that the concatenated arrays exist.
del train_spectrogram_chunks, val_spectrogram_chunks, train_label_chunks, val_label_chunks
del train_meta_chunks, val_meta_chunks

print(f"Training spectrograms shape: {train_spectrograms.shape}")
print(f"Validation spectrograms shape: {val_spectrograms.shape}")
print(f"Memory usage: {train_spectrograms.nbytes + val_spectrograms.nbytes:,} bytes")

# Print metadata statistics
print(f"\n📊 Metadata statistics:")
print(f"   Discovered modulation schemes: {len(MODULATION_MAP)}")
for mod_name, mod_id in sorted(MODULATION_MAP.items(), key=lambda x: x[1]):
    count_train = np.sum(train_metadata['modulation_id'] == mod_id)
    count_val = np.sum(val_metadata['modulation_id'] == mod_id)
    print(f"   {mod_name} (ID={mod_id}): {count_train} train, {count_val} val samples")

print(f"\n   Modulation distribution (train):")
for mod_id in np.unique(train_metadata['modulation_id']):
    count = np.sum(train_metadata['modulation_id'] == mod_id)
    mod_name = MODULATION_INV.get(mod_id, f"Unknown({mod_id})")
    print(f"   {mod_name}: {count} samples ({100*count/len(train_metadata['modulation_id']):.1f}%)")

print(f"   Mobility distribution (train):")
for mob_id in np.unique(train_metadata['doppler_id']):
    count = np.sum(train_metadata['doppler_id'] == mob_id)
    mob_name = DOPPLER_INV.get(mob_id, f"Unknown({mob_id})")
    print(f"   {mob_name}: {count} samples ({100*count/len(train_metadata['doppler_id']):.1f}%)")

# Dataset-level normalization statistics, computed on the training split only
# (the validation split is normalized with the same stats via dataset_stats).
train_mean = float(train_spectrograms.mean())
train_std = float(train_spectrograms.std())
if abs(train_std) < 1e-6:
    # Degenerate (near-constant) data: avoid division by ~0 downstream.
    print("⚠️ Training std near zero, using epsilon for stability")
    train_std = 1e-6
dataset_normalization = {'mean': train_mean, 'std': train_std, 'normalization': NORMALIZATION_MODE}
print(f"Dataset normalization stats -> mean: {train_mean:.4f}, std: {train_std:.4f}")
-
# =============================================================================
|
| 795 |
-
# 7. BUILD INDEX FOR CONTRASTIVE SAMPLING
|
| 796 |
-
# =============================================================================
|
| 797 |
-
|
| 798 |
-
def build_class_indices(metadata: Dict[str, np.ndarray]) -> Tuple[Dict, Dict]:
|
| 799 |
-
"""
|
| 800 |
-
Build index mapping from modulation/mobility ID to sample indices.
|
| 801 |
-
"""
|
| 802 |
-
indices_by_modulation = {}
|
| 803 |
-
indices_by_mobility = {}
|
| 804 |
-
|
| 805 |
-
for idx in range(len(metadata['modulation_id'])):
|
| 806 |
-
mod_id = int(metadata['modulation_id'][idx])
|
| 807 |
-
mob_id = int(metadata['doppler_id'][idx])
|
| 808 |
-
|
| 809 |
-
if mod_id not in indices_by_modulation:
|
| 810 |
-
indices_by_modulation[mod_id] = []
|
| 811 |
-
indices_by_modulation[mod_id].append(idx)
|
| 812 |
-
|
| 813 |
-
if mob_id not in indices_by_mobility:
|
| 814 |
-
indices_by_mobility[mob_id] = []
|
| 815 |
-
indices_by_mobility[mob_id].append(idx)
|
| 816 |
-
|
| 817 |
-
return indices_by_modulation, indices_by_mobility
|
| 818 |
-
|
| 819 |
-
|
| 820 |
-
print("🔍 Building class indices for contrastive learning...")
|
| 821 |
-
train_indices_by_modulation, train_indices_by_mobility = build_class_indices(train_metadata)
|
| 822 |
-
val_indices_by_modulation, val_indices_by_mobility = build_class_indices(val_metadata)
|
| 823 |
-
print("✅ Class indices built successfully!")
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
# =============================================================================
|
| 827 |
-
# 8. DATA TOKENIZATION
|
| 828 |
-
# =============================================================================
|
| 829 |
-
|
| 830 |
-
print("🔄 Starting tokenization of training data...")
|
| 831 |
-
preprocessed_train = tokenizer_train(
|
| 832 |
-
train_spectrograms,
|
| 833 |
-
max_len=MAX_LEN,
|
| 834 |
-
masking_percent=MASK_PERCENT,
|
| 835 |
-
mask=True,
|
| 836 |
-
seed=42,
|
| 837 |
-
metadata=train_metadata,
|
| 838 |
-
dataset_stats=dataset_normalization,
|
| 839 |
-
normalization=NORMALIZATION_MODE,
|
| 840 |
-
interleaved=True,
|
| 841 |
-
)
|
| 842 |
-
print("✅ Training data tokenization completed!")
|
| 843 |
-
|
| 844 |
-
print("🔄 Starting tokenization of validation data...")
|
| 845 |
-
preprocessed_val = tokenizer_train(
|
| 846 |
-
val_spectrograms,
|
| 847 |
-
max_len=MAX_LEN,
|
| 848 |
-
masking_percent=MASK_PERCENT,
|
| 849 |
-
mask=True,
|
| 850 |
-
seed=42,
|
| 851 |
-
metadata=val_metadata,
|
| 852 |
-
dataset_stats=dataset_normalization,
|
| 853 |
-
normalization=NORMALIZATION_MODE,
|
| 854 |
-
interleaved=True,
|
| 855 |
-
)
|
| 856 |
-
print("✅ Validation data tokenization completed!")
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
# =============================================================================
# 9. TRAIN/VALIDATION DATA SETUP
# =============================================================================

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

train_data = preprocessed_train
val_data = preprocessed_val


# =============================================================================
# 10. DATALOADER CREATION
# =============================================================================

print("🔧 Creating data loaders...")

# The tokenizer may return either a dict keyed by sequence length or a single
# tensor; both cases are normalized into a {key: DataLoader} mapping so the
# training loop can iterate loaders uniformly.
if isinstance(train_data, dict):
    print(f"   Training data format: dict with {len(train_data)} sequence lengths")
    train_loaders = create_train_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True)
else:
    print(f"   Training data format: tensor with shape {train_data.shape}")
    train_dataset = TensorDataset(train_data)
    train_loaders = {'seq_0': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)}

# Validation loaders are never shuffled so metrics are reproducible.
if isinstance(val_data, dict):
    print(f"   Validation data format: dict with {len(val_data)} sequence lengths")
    val_loaders = create_train_dataloader(val_data, batch_size=VAL_BATCH_SIZE, shuffle=False)
else:
    print(f"   Validation data format: tensor with shape {val_data.shape}")
    val_dataset = TensorDataset(val_data)
    val_loaders = {'seq_0': DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)}

print("✅ Data loaders created successfully!")
# =============================================================================
# 11. MODEL INITIALIZATION
# =============================================================================

print("🔧 Setting up device and GPU configuration...")

# Device preference: CUDA (all visible GPUs via DataParallel) > MPS > CPU.
if torch.cuda.is_available():
    device_count = torch.cuda.device_count()
    print(f"   CUDA available: {device_count} GPU(s) detected")
    device = torch.device("cuda:0")
    gpu_ids = list(range(device_count))
    print(f"   Using CUDA GPUs: {gpu_ids}")

    # Best-effort memory report; failures here must not abort training.
    for i in gpu_ids:
        try:
            mem_total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            mem_allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"   GPU {i}: Total: {mem_total:.1f}GB, Allocated: {mem_allocated:.1f}GB")
        except Exception as e:
            print(f"   GPU {i}: Error getting memory info - {e}")

elif torch.backends.mps.is_available():
    device = torch.device("mps")
    gpu_ids = []
    print("   Using MPS (Apple Silicon GPU)")
else:
    device = torch.device("cpu")
    gpu_ids = []
    print("   Using CPU")

print(f"   Final device: {device}")
print(f"   GPU IDs for DataParallel: {gpu_ids}")

print("🤖 Initializing LWM model with contrastive learning...")
print(f"   Model parameters: element_length={ELEMENT_LENGTH}, d_model={D_MODEL}, n_layers={N_LAYERS}, max_len={MAX_LEN}, n_heads={N_HEADS}")

try:
    # Create base LWM encoder
    lwm_encoder = pretrained_model.lwm(
        element_length=ELEMENT_LENGTH,
        d_model=D_MODEL,
        n_layers=N_LAYERS,
        max_len=MAX_LEN,
        n_heads=N_HEADS,
        dropout=DROPOUT
    )

    # Wrap with contrastive learning module
    # MLM head must output patch dimension (ELEMENT_LENGTH), not full spectrogram width
    # Each token represents a 4×4×2 patch = 32 elements
    model = ContrastiveLWM(lwm_encoder, projection_dim=PROJECTION_DIM, input_dim=ELEMENT_LENGTH)
    print(f"   ✅ Model created with input_dim={ELEMENT_LENGTH} (patch dimension)")

    print(f"   Moving model to device: {device}")
    if 'mps' in str(device):
        # MPS requires float32 — explicit cast avoids dtype surprises.
        model = model.to(device).float()
        print("   ✅ Model moved to MPS device (float32)")
    else:
        model = model.to(device)
        print("   ✅ Model moved to device successfully")

except Exception as e:
    # A broken model setup is unrecoverable for this script — report and exit.
    print(f"   ❌ Model initialization failed: {e}")
    import traceback
    traceback.print_exc()
    exit(1)

# Use DataParallel for multi-GPU support
if gpu_ids:
    model = nn.DataParallel(model, device_ids=gpu_ids)
    print(f"Model loaded successfully on GPU {device.index}")
else:
    print(f"Model loaded successfully on {device}")

n_parameters = count_parameters(model)
print(f"Number of trainable parameters: {n_parameters:,}")
# =============================================================================
# 12. OPTIMIZER AND LEARNING RATE SCHEDULER
# =============================================================================

# Account for gradient accumulation: scheduler step is called once per ACCUMULATION_STEPS batches
# So actual optimizer steps = total_batches / ACCUMULATION_STEPS
total_batches_per_epoch = sum(len(loader) for loader in train_loaders.values())
actual_steps_per_epoch = math.ceil(total_batches_per_epoch / ACCUMULATION_STEPS)
TOTAL_STEPS = actual_steps_per_epoch * EPOCHS
WARMUP_STEPS = actual_steps_per_epoch * WARMUP_EPOCHS

print(f"📊 Learning rate schedule:")
print(f"   Total batches per epoch: {total_batches_per_epoch}")
print(f"   Accumulation steps: {ACCUMULATION_STEPS}")
print(f"   Actual optimizer steps per epoch: {actual_steps_per_epoch}")
print(f"   Total training steps: {TOTAL_STEPS}")
print(f"   Warmup steps: {WARMUP_STEPS}")

# AdamW with decoupled weight decay; BASE_LR is the peak LR the warmup ramps to.
optimizer = AdamW(
    model.parameters(),
    lr=BASE_LR,
    betas=(BETA1, BETA2),
    weight_decay=WEIGHT_DECAY
)
def lr_lambda(current_step):
    """Warmup-then-cosine learning-rate multiplier for LambdaLR.

    Returns a factor relative to BASE_LR: a linear ramp 0 -> 1 over the first
    WARMUP_STEPS optimizer steps, then a cosine decay from 1 down to
    MIN_LR / BASE_LR at TOTAL_STEPS.

    Args:
        current_step: 0-based optimizer step count supplied by LambdaLR.

    Returns:
        float multiplier applied to BASE_LR.
    """
    if current_step < WARMUP_STEPS:
        # Guard against WARMUP_STEPS == 0 (e.g. WARMUP_EPOCHS = 0), which
        # would otherwise divide by zero on the very first step.
        return current_step / max(WARMUP_STEPS, 1)
    else:
        # Clamp progress to [0, 1] so the multiplier never rises again if
        # stepping continues past TOTAL_STEPS; also guard a zero-length
        # decay window (TOTAL_STEPS == WARMUP_STEPS).
        scaled_progress = min(
            (current_step - WARMUP_STEPS) / max(TOTAL_STEPS - WARMUP_STEPS, 1),
            1.0,
        )
        cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
        return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR
# Wrap the warmup+cosine multiplier in a PyTorch scheduler; the training loop
# below steps it once per optimizer step (i.e. once per accumulation cycle).
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
# =============================================================================
# 13. TRAINING LOOP WITH CONTRASTIVE LEARNING
# =============================================================================

def train_epoch_contrastive(
    model,
    train_loaders,
    optimizer,
    scheduler,
    device,
    epoch,
    train_metadata
):
    """
    Train one epoch with MLM + Contrastive Learning with Gradient Accumulation.

    Args:
        model: ContrastiveLWM (possibly DataParallel-wrapped); called with
            return_projections=True to get (mlm_predictions, z_mod, z_mob).
        train_loaders: dict of {seq_key: DataLoader}; batches carry either
            8+ tensors (with metadata) or 3 (MLM only) — see unpacking below.
        optimizer, scheduler: stepped once every ACCUMULATION_STEPS batches.
        device: target torch device for all batch tensors.
        epoch: 0-based epoch index (used for the progress bar and one-time
            debug prints on the very first batch of epoch 0).
        train_metadata: accepted for interface symmetry; not read in this
            function — labels come from the batch tensors themselves.

    Returns:
        Dict with per-epoch averages: 'mlm_loss', 'contrastive_mod_loss',
        'contrastive_mob_loss', 'total_loss'.
    """
    model.train()
    total_mlm_loss = 0.0
    total_contrastive_mod_loss = 0.0
    total_contrastive_mob_loss = 0.0
    total_loss = 0.0
    total_batches = 0

    # Sum-reduction: per-sample sums are later averaged over the batch size.
    criterion = nn.MSELoss(reduction='sum')

    # Initialize gradient accumulation
    optimizer.zero_grad()
    accumulation_counter = 0

    for seq_key, loader in train_loaders.items():
        for batch_idx, batch in enumerate(tqdm(loader, desc=f"Epoch {epoch+1}", leave=False)):
            # Unpack batch - expect (input_ids, masked_tokens, masked_pos, snr_db, doppler_id, power_stats, snr_id, modulation_id)
            if len(batch) >= 8:
                input_ids = batch[0].to(device)
                masked_tokens = batch[1].to(device)
                masked_pos = batch[2].to(device)
                snr_db = batch[3].to(device)
                doppler_id = batch[4].to(device)
                power_stats = batch[5].to(device)
                snr_id = batch[6].to(device)
                modulation_id = batch[7].to(device)
                has_metadata = True
            elif len(batch) == 3:
                # MLM-only batch: no labels, so contrastive losses are skipped.
                input_ids = batch[0].to(device)
                masked_tokens = batch[1].to(device)
                masked_pos = batch[2].to(device)
                has_metadata = False
            else:
                input_ids = batch[0].to(device)
                has_metadata = False

            # Forward pass with projections
            mlm_predictions, z_mod, z_mob = model(input_ids, return_projections=True)

            # MLM Loss (reconstruction)
            if len(batch) >= 3 and masked_tokens.numel() > 0:
                batch_size = input_ids.size(0)
                mlm_loss = 0.0

                for i in range(batch_size):
                    # Get masked positions for this sample
                    sample_masked_pos = masked_pos[i]
                    sample_masked_tokens = masked_tokens[i]

                    # Skip if no masked positions
                    if sample_masked_pos.numel() == 0:
                        continue

                    # Get predictions at masked positions
                    predictions = mlm_predictions[i, sample_masked_pos, :]
                    targets = sample_masked_tokens

                    # Ensure shapes match
                    if predictions.size(0) != targets.size(0):
                        # Adjust if needed
                        min_len = min(predictions.size(0), targets.size(0))
                        predictions = predictions[:min_len]
                        targets = targets[:min_len]

                    # MSE loss
                    mlm_loss += criterion(predictions, targets)

                mlm_loss = mlm_loss / batch_size if batch_size > 0 else torch.tensor(0.0, device=device)
            else:
                mlm_loss = torch.zeros(1, device=device)

            # Contrastive losses (only if we have metadata)
            if has_metadata:
                # DEBUG: Print batch statistics
                if batch_idx == 0 and epoch == 0:  # Only first batch of first epoch
                    print(f"\n🔍 DEBUG - Batch analysis:")
                    print(f"   Batch size: {modulation_id.size(0)}")
                    print(f"   Modulation IDs: {modulation_id.cpu().numpy()}")
                    print(f"   Unique modulations: {torch.unique(modulation_id).cpu().numpy()}")
                    print(f"   Doppler IDs: {doppler_id.cpu().numpy()}")
                    print(f"   Unique doppler: {torch.unique(doppler_id).cpu().numpy()}")

                # Modulation contrastive loss
                # Filter out unknown modulations (-1)
                valid_mod_mask = modulation_id >= 0
                if valid_mod_mask.sum() > 1:  # Need at least 2 samples
                    z_mod_valid = z_mod[valid_mod_mask]
                    mod_labels_valid = modulation_id[valid_mod_mask]

                    # Check if we have positive pairs
                    unique_mods, counts = torch.unique(mod_labels_valid, return_counts=True)
                    has_positive_pairs = (counts > 1).any()

                    if has_positive_pairs:
                        contrastive_mod_loss = supervised_contrastive_loss(
                            z_mod_valid,
                            mod_labels_valid,
                            temperature=CONTRASTIVE_TEMPERATURE
                        )
                        if batch_idx == 0 and epoch == 0:
                            print(f"   Modulation contrastive loss: {contrastive_mod_loss.item():.4f}")
                    else:
                        contrastive_mod_loss = torch.zeros(1, device=device)
                        if batch_idx == 0 and epoch == 0:
                            print(f"   No positive pairs for modulation - loss set to 0")
                else:
                    contrastive_mod_loss = torch.zeros(1, device=device)
                    if batch_idx == 0 and epoch == 0:
                        print(f"   Not enough valid modulation samples - loss set to 0")

                # Mobility contrastive loss
                z_mob_valid = z_mob
                mob_labels_valid = doppler_id
                if mob_labels_valid.numel() > 1:
                    unique_mobs, counts = torch.unique(mob_labels_valid, return_counts=True)
                    has_positive_pairs = (counts > 1).any()

                    if has_positive_pairs:
                        contrastive_mob_loss = supervised_contrastive_loss(
                            z_mob_valid,
                            mob_labels_valid,
                            temperature=CONTRASTIVE_TEMPERATURE
                        )
                        if batch_idx == 0 and epoch == 0:
                            print(f"   Mobility contrastive loss: {contrastive_mob_loss.item():.4f}")
                    else:
                        contrastive_mob_loss = torch.zeros(1, device=device)
                        if batch_idx == 0 and epoch == 0:
                            print(f"   No positive pairs for mobility - loss set to 0")
                else:
                    contrastive_mob_loss = torch.zeros(1, device=device)
                    if batch_idx == 0 and epoch == 0:
                        print(f"   Not enough mobility samples - loss set to 0")
            else:
                contrastive_mod_loss = torch.zeros(1, device=device)
                contrastive_mob_loss = torch.zeros(1, device=device)

            # Combined loss
            loss = (
                MLM_WEIGHT * mlm_loss +
                CONTRASTIVE_WEIGHT_MODULATION * contrastive_mod_loss +
                CONTRASTIVE_WEIGHT_MOBILITY * contrastive_mob_loss
            )

            # Normalize loss by accumulation steps
            loss = loss / ACCUMULATION_STEPS

            # Backward pass (accumulate gradients)
            loss.backward()

            # Accumulate losses (denormalized for logging)
            total_mlm_loss += mlm_loss.item()
            total_contrastive_mod_loss += contrastive_mod_loss.item()
            total_contrastive_mob_loss += contrastive_mob_loss.item()
            total_loss += (loss.item() * ACCUMULATION_STEPS)  # Denormalize for logging
            total_batches += 1
            accumulation_counter += 1

            # Perform optimizer step every ACCUMULATION_STEPS
            if accumulation_counter % ACCUMULATION_STEPS == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                accumulation_counter = 0

    # Handle remaining gradients if total batches not divisible by ACCUMULATION_STEPS
    if accumulation_counter > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    # Average losses
    avg_mlm_loss = total_mlm_loss / total_batches if total_batches > 0 else 0
    avg_contrastive_mod_loss = total_contrastive_mod_loss / total_batches if total_batches > 0 else 0
    avg_contrastive_mob_loss = total_contrastive_mob_loss / total_batches if total_batches > 0 else 0
    avg_total_loss = total_loss / total_batches if total_batches > 0 else 0

    return {
        'mlm_loss': avg_mlm_loss,
        'contrastive_mod_loss': avg_contrastive_mod_loss,
        'contrastive_mob_loss': avg_contrastive_mob_loss,
        'total_loss': avg_total_loss
    }
def validate_epoch_contrastive(
    model,
    val_loaders,
    device,
    epoch
):
    """
    Validate one epoch with MLM + Contrastive Learning.

    Mirrors train_epoch_contrastive but runs under torch.no_grad() with no
    optimizer/scheduler steps; contrastive terms are computed whenever the
    batch carries metadata (8+ tensors).

    Args:
        model: ContrastiveLWM (possibly DataParallel-wrapped).
        val_loaders: dict of {seq_key: DataLoader} over the validation split.
        device: target torch device for batch tensors.
        epoch: epoch index; accepted for interface symmetry, not read here.

    Returns:
        Dict with per-epoch averages: 'mlm_loss', 'contrastive_mod_loss',
        'contrastive_mob_loss', 'total_loss'.
    """
    model.eval()
    total_mlm_loss = 0.0
    total_contrastive_mod_loss = 0.0
    total_contrastive_mob_loss = 0.0
    total_loss = 0.0
    total_batches = 0

    criterion = nn.MSELoss(reduction='sum')

    with torch.no_grad():
        for seq_key, loader in val_loaders.items():
            for batch_idx, batch in enumerate(loader):
                # Unpack batch
                if len(batch) >= 8:
                    input_ids = batch[0].to(device)
                    masked_tokens = batch[1].to(device)
                    masked_pos = batch[2].to(device)
                    snr_db = batch[3].to(device)
                    doppler_id = batch[4].to(device)
                    power_stats = batch[5].to(device)
                    snr_id = batch[6].to(device)
                    modulation_id = batch[7].to(device)
                    has_metadata = True
                elif len(batch) == 3:
                    input_ids = batch[0].to(device)
                    masked_tokens = batch[1].to(device)
                    masked_pos = batch[2].to(device)
                    has_metadata = False
                else:
                    input_ids = batch[0].to(device)
                    has_metadata = False

                # Forward pass
                mlm_predictions, z_mod, z_mob = model(input_ids, return_projections=True)

                # MLM Loss
                if len(batch) >= 3 and masked_tokens.numel() > 0:
                    batch_size = input_ids.size(0)
                    mlm_loss = 0.0

                    for i in range(batch_size):
                        sample_masked_pos = masked_pos[i]
                        sample_masked_tokens = masked_tokens[i]

                        if sample_masked_pos.numel() == 0:
                            continue

                        predictions = mlm_predictions[i, sample_masked_pos, :]
                        targets = sample_masked_tokens

                        # Truncate to the common length if shapes disagree.
                        if predictions.size(0) != targets.size(0):
                            min_len = min(predictions.size(0), targets.size(0))
                            predictions = predictions[:min_len]
                            targets = targets[:min_len]

                        mlm_loss += criterion(predictions, targets)

                    mlm_loss = mlm_loss / batch_size if batch_size > 0 else torch.tensor(0.0, device=device)
                else:
                    mlm_loss = torch.zeros(1, device=device)

                # Contrastive losses
                if has_metadata:
                    # Exclude unknown modulations (-1); need >= 2 valid samples.
                    valid_mod_mask = modulation_id >= 0
                    if valid_mod_mask.sum() > 1:
                        z_mod_valid = z_mod[valid_mod_mask]
                        mod_labels_valid = modulation_id[valid_mod_mask]
                        contrastive_mod_loss = supervised_contrastive_loss(
                            z_mod_valid,
                            mod_labels_valid,
                            temperature=CONTRASTIVE_TEMPERATURE
                        )
                    else:
                        contrastive_mod_loss = torch.zeros(1, device=device)

                    if doppler_id.numel() > 1:
                        contrastive_mob_loss = supervised_contrastive_loss(
                            z_mob,
                            doppler_id,
                            temperature=CONTRASTIVE_TEMPERATURE
                        )
                    else:
                        contrastive_mob_loss = torch.zeros(1, device=device)
                else:
                    contrastive_mod_loss = torch.zeros(1, device=device)
                    contrastive_mob_loss = torch.zeros(1, device=device)

                # Same weighted combination as training, for comparable totals.
                loss = (
                    MLM_WEIGHT * mlm_loss +
                    CONTRASTIVE_WEIGHT_MODULATION * contrastive_mod_loss +
                    CONTRASTIVE_WEIGHT_MOBILITY * contrastive_mob_loss
                )

                total_mlm_loss += mlm_loss.item()
                total_contrastive_mod_loss += contrastive_mod_loss.item()
                total_contrastive_mob_loss += contrastive_mob_loss.item()
                total_loss += loss.item()
                total_batches += 1

    avg_mlm_loss = total_mlm_loss / total_batches if total_batches > 0 else 0
    avg_contrastive_mod_loss = total_contrastive_mod_loss / total_batches if total_batches > 0 else 0
    avg_contrastive_mob_loss = total_contrastive_mob_loss / total_batches if total_batches > 0 else 0
    avg_total_loss = total_loss / total_batches if total_batches > 0 else 0

    return {
        'mlm_loss': avg_mlm_loss,
        'contrastive_mod_loss': avg_contrastive_mod_loss,
        'contrastive_mob_loss': avg_contrastive_mob_loss,
        'total_loss': avg_total_loss
    }
# =============================================================================
# 14. MAIN TRAINING LOOP
# =============================================================================

# Per-run output directory, timestamped so runs never overwrite each other.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = f"models/{timestamp}_contrastive"
print(f"📁 Models and logs will be saved to: {save_dir}")
os.makedirs(save_dir, exist_ok=True)

# Persist normalization stats so inference can reproduce the preprocessing.
stats_path = os.path.join(save_dir, "dataset_stats.json")
with open(stats_path, 'w') as f:
    json.dump(dataset_normalization, f, indent=2)
print(f"📝 Saved dataset stats to {stats_path}")

# Save training configuration
config = {
    'epochs': EPOCHS,
    'batch_size': BATCH_SIZE,
    'effective_batch_size': BATCH_SIZE * ACCUMULATION_STEPS,
    'accumulation_steps': ACCUMULATION_STEPS,
    'learning_rate': BASE_LR,
    'element_length': ELEMENT_LENGTH,
    'd_model': D_MODEL,
    'n_layers': N_LAYERS,
    'n_heads': N_HEADS,
    'projection_dim': PROJECTION_DIM,
    'contrastive_temperature': CONTRASTIVE_TEMPERATURE,
    'mlm_weight': MLM_WEIGHT,
    'contrastive_weight_modulation': CONTRASTIVE_WEIGHT_MODULATION,
    'contrastive_weight_mobility': CONTRASTIVE_WEIGHT_MOBILITY,
    'modulation_map': MODULATION_MAP,
    'doppler_map': DOPPLER_MAP,
    'num_modulations': len(MODULATION_MAP),
}
config_path = os.path.join(save_dir, "config.json")
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)
print(f"📝 Saved training config to {config_path}")

# Training log: CSV header written once; one row appended per epoch below.
log_path = os.path.join(save_dir, "training_log.csv")
with open(log_path, 'w') as f:
    f.write("epoch,train_mlm_loss,train_contrastive_mod_loss,train_contrastive_mob_loss,train_total_loss,")
    f.write("val_mlm_loss,val_contrastive_mod_loss,val_contrastive_mob_loss,val_total_loss,learning_rate\n")

print("\n" + "="*80)
print("🚀 Starting training with contrastive learning!")
print("="*80 + "\n")

if __name__ == "__main__":
    best_val_loss = float('inf')

    for epoch in range(EPOCHS):
        print(f"\n{'='*80}")
        print(f"Epoch {epoch+1}/{EPOCHS}")
        print(f"{'='*80}")

        # Train
        train_metrics = train_epoch_contrastive(
            model, train_loaders, optimizer, scheduler, device, epoch, train_metadata
        )

        # Validate
        val_metrics = validate_epoch_contrastive(
            model, val_loaders, device, epoch
        )

        # Log metrics
        current_lr = optimizer.param_groups[0]['lr']
        print(f"\nEpoch {epoch+1} Results:")
        print(f"   Train - MLM: {train_metrics['mlm_loss']:.4f}, "
              f"ContrastMod: {train_metrics['contrastive_mod_loss']:.4f}, "
              f"ContrastMob: {train_metrics['contrastive_mob_loss']:.4f}, "
              f"Total: {train_metrics['total_loss']:.4f}")
        print(f"   Val - MLM: {val_metrics['mlm_loss']:.4f}, "
              f"ContrastMod: {val_metrics['contrastive_mod_loss']:.4f}, "
              f"ContrastMob: {val_metrics['contrastive_mob_loss']:.4f}, "
              f"Total: {val_metrics['total_loss']:.4f}")
        print(f"   Learning Rate: {current_lr:.6f}")

        # Save to log
        with open(log_path, 'a') as f:
            f.write(f"{epoch+1},{train_metrics['mlm_loss']:.6f},"
                    f"{train_metrics['contrastive_mod_loss']:.6f},"
                    f"{train_metrics['contrastive_mob_loss']:.6f},"
                    f"{train_metrics['total_loss']:.6f},"
                    f"{val_metrics['mlm_loss']:.6f},"
                    f"{val_metrics['contrastive_mod_loss']:.6f},"
                    f"{val_metrics['contrastive_mob_loss']:.6f},"
                    f"{val_metrics['total_loss']:.6f},"
                    f"{current_lr:.8f}\n")

        # Save best model
        # DataParallel wraps the real model in .module — unwrap before saving
        # so checkpoints load cleanly regardless of GPU count.
        if val_metrics['total_loss'] < best_val_loss:
            best_val_loss = val_metrics['total_loss']
            checkpoint_path = os.path.join(save_dir, "best_model_contrastive.pth")
            if isinstance(model, nn.DataParallel):
                torch.save(model.module.state_dict(), checkpoint_path)
            else:
                torch.save(model.state_dict(), checkpoint_path)
            print(f"   ✅ Saved best model to {checkpoint_path}")

        # Save periodic checkpoint
        if (epoch + 1) % 5 == 0:
            checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch{epoch+1}_contrastive.pth")
            if isinstance(model, nn.DataParallel):
                torch.save(model.module.state_dict(), checkpoint_path)
            else:
                torch.save(model.state_dict(), checkpoint_path)
            print(f"   💾 Saved checkpoint to {checkpoint_path}")

    print("\n" + "="*80)
    print("🎉 Training completed successfully!")
    print(f"📁 Models saved to: {save_dir}")
    print(f"📊 Training log: {log_path}")
    print("="*80 + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pretraining/train_lwm_spectro_no_contrast.py
DELETED
|
@@ -1,1136 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
# =============================================================================
|
| 3 |
-
# 1. IMPORTS AND WARNINGS SETUP
|
| 4 |
-
# - Load necessary PyTorch modules, utilities, and suppress UserWarnings
|
| 5 |
-
# =============================================================================
|
| 6 |
-
import sys
|
| 7 |
-
import os
|
| 8 |
-
import argparse
|
| 9 |
-
# Add project root to path (Windows compatible)
|
| 10 |
-
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 11 |
-
sys.path.insert(0, project_root)
|
| 12 |
-
import torch
|
| 13 |
-
import torch.nn as nn
|
| 14 |
-
import torch.nn.functional as F
|
| 15 |
-
from torch.utils.data import DataLoader, IterableDataset
|
| 16 |
-
import torch.distributed as dist
|
| 17 |
-
import torch.optim as optim
|
| 18 |
-
from utils import (generate_spectrograms_and_labels, tokenizer_train,
|
| 19 |
-
count_parameters, train_lwm)
|
| 20 |
-
import numpy as np
|
| 21 |
-
import pretrained_model # Assuming this contains the LWM model definition
|
| 22 |
-
from torch.optim.lr_scheduler import LambdaLR
|
| 23 |
-
from torch.optim import AdamW
|
| 24 |
-
import warnings
|
| 25 |
-
import platform
|
| 26 |
-
import re
|
| 27 |
-
from tqdm import tqdm
|
| 28 |
-
from datetime import datetime
|
| 29 |
-
import concurrent.futures
|
| 30 |
-
import multiprocessing
|
| 31 |
-
from collections import Counter
|
| 32 |
-
from functools import lru_cache
|
| 33 |
-
import json
|
| 34 |
-
import random
|
| 35 |
-
import math
|
| 36 |
-
from typing import Any, Dict, Optional, List, Tuple
|
| 37 |
-
import time
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
LOG_ALL_RANKS = False
|
| 41 |
-
|
| 42 |
-
SNR_PATTERN = re.compile(r"SNR(-?\d+)dB")
|
| 43 |
-
DOPPLER_MAP = {"static": 0, "pedestrian": 1, "vehicular": 2}
|
| 44 |
-
DOPPLER_INV = {v: k for k, v in DOPPLER_MAP.items()}
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def _is_hpu_available() -> bool:
|
| 48 |
-
hpu = getattr(torch, "hpu", None)
|
| 49 |
-
if hpu is None:
|
| 50 |
-
return False
|
| 51 |
-
is_available = getattr(hpu, "is_available", None)
|
| 52 |
-
available = False
|
| 53 |
-
if callable(is_available):
|
| 54 |
-
try:
|
| 55 |
-
available = bool(is_available())
|
| 56 |
-
except Exception:
|
| 57 |
-
available = False
|
| 58 |
-
if not available:
|
| 59 |
-
# Try initializing the Habana runtime lazily
|
| 60 |
-
try:
|
| 61 |
-
import habana_frameworks.torch.core as htcore # type: ignore
|
| 62 |
-
|
| 63 |
-
init_fn = getattr(htcore, "hpu_initialize", None)
|
| 64 |
-
if callable(init_fn):
|
| 65 |
-
init_fn()
|
| 66 |
-
else:
|
| 67 |
-
inference_init = getattr(htcore, "hpu_inference_initialize", None)
|
| 68 |
-
if callable(inference_init):
|
| 69 |
-
inference_init()
|
| 70 |
-
available = bool(is_available())
|
| 71 |
-
except Exception:
|
| 72 |
-
available = False
|
| 73 |
-
return available
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def _get_hpu_device_count() -> int:
|
| 77 |
-
hpu = getattr(torch, "hpu", None)
|
| 78 |
-
if hpu is None:
|
| 79 |
-
return 0
|
| 80 |
-
device_count_fn = getattr(hpu, "device_count", None)
|
| 81 |
-
if callable(device_count_fn):
|
| 82 |
-
try:
|
| 83 |
-
return int(device_count_fn())
|
| 84 |
-
except Exception:
|
| 85 |
-
return 0
|
| 86 |
-
return 1 if _is_hpu_available() else 0
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def _initialize_distributed(hpu_available: bool, backend_override: Optional[str] = None) -> Dict[str, Any]:
|
| 90 |
-
context: Dict[str, Any] = {
|
| 91 |
-
"is_distributed": False,
|
| 92 |
-
"backend": None,
|
| 93 |
-
"rank": 0,
|
| 94 |
-
"world_size": 1,
|
| 95 |
-
"local_rank": 0,
|
| 96 |
-
"is_primary": True,
|
| 97 |
-
}
|
| 98 |
-
if not dist.is_available():
|
| 99 |
-
return context
|
| 100 |
-
|
| 101 |
-
required_env = ("RANK", "WORLD_SIZE")
|
| 102 |
-
if not all(key in os.environ for key in required_env):
|
| 103 |
-
return context
|
| 104 |
-
|
| 105 |
-
if dist.is_initialized():
|
| 106 |
-
context["is_distributed"] = True
|
| 107 |
-
context["backend"] = dist.get_backend()
|
| 108 |
-
context["rank"] = dist.get_rank()
|
| 109 |
-
context["world_size"] = dist.get_world_size()
|
| 110 |
-
context["local_rank"] = int(os.environ.get("LOCAL_RANK", context["rank"]))
|
| 111 |
-
context["is_primary"] = context["rank"] == 0
|
| 112 |
-
return context
|
| 113 |
-
|
| 114 |
-
backend = backend_override or os.environ.get("LWM_DISTRIBUTED_BACKEND")
|
| 115 |
-
if not backend:
|
| 116 |
-
if hpu_available:
|
| 117 |
-
backend = "hccl"
|
| 118 |
-
elif torch.cuda.is_available():
|
| 119 |
-
backend = "nccl"
|
| 120 |
-
else:
|
| 121 |
-
backend = "gloo"
|
| 122 |
-
|
| 123 |
-
dist.init_process_group(backend=backend, init_method="env://")
|
| 124 |
-
|
| 125 |
-
context["is_distributed"] = True
|
| 126 |
-
context["backend"] = backend
|
| 127 |
-
context["rank"] = dist.get_rank()
|
| 128 |
-
context["world_size"] = dist.get_world_size()
|
| 129 |
-
context["local_rank"] = int(os.environ.get("LOCAL_RANK", context["rank"]))
|
| 130 |
-
context["is_primary"] = context["rank"] == 0
|
| 131 |
-
return context
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def _broadcast_object(obj: Any, src: int = 0) -> Any:
|
| 135 |
-
if not dist.is_available() or not dist.is_initialized():
|
| 136 |
-
return obj
|
| 137 |
-
object_list = [obj]
|
| 138 |
-
dist.broadcast_object_list(object_list, src=src)
|
| 139 |
-
return object_list[0]
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def _should_log(context: Dict[str, Any]) -> bool:
|
| 143 |
-
return LOG_ALL_RANKS or (not context.get("is_distributed")) or context.get("is_primary", True)
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def _barrier(context: Dict[str, Any]) -> None:
|
| 147 |
-
if context.get("is_distributed") and dist.is_available() and dist.is_initialized():
|
| 148 |
-
dist.barrier()
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
def _parse_snr_and_doppler(path: str) -> tuple[float, int]:
|
| 152 |
-
snr_db = 0.0
|
| 153 |
-
doppler_id = 0
|
| 154 |
-
|
| 155 |
-
matches = SNR_PATTERN.findall(path)
|
| 156 |
-
if matches:
|
| 157 |
-
try:
|
| 158 |
-
snr_db = float(matches[-1])
|
| 159 |
-
except ValueError:
|
| 160 |
-
snr_db = 0.0
|
| 161 |
-
|
| 162 |
-
normalized_path = os.path.normpath(path)
|
| 163 |
-
parts = normalized_path.split(os.sep)
|
| 164 |
-
for part in parts:
|
| 165 |
-
if part in DOPPLER_MAP:
|
| 166 |
-
doppler_id = DOPPLER_MAP[part]
|
| 167 |
-
break
|
| 168 |
-
|
| 169 |
-
return snr_db, doppler_id
|
| 170 |
-
|
| 171 |
-
def _parse_runtime_args():
|
| 172 |
-
parser = argparse.ArgumentParser(add_help=False)
|
| 173 |
-
parser.add_argument(
|
| 174 |
-
"--device",
|
| 175 |
-
default=os.environ.get("LWM_DEVICE", "auto"),
|
| 176 |
-
choices=("auto", "cpu", "cuda", "hpu", "mps"),
|
| 177 |
-
help="Select accelerator device (default: auto)."
|
| 178 |
-
)
|
| 179 |
-
parser.add_argument(
|
| 180 |
-
"--dist-backend",
|
| 181 |
-
dest="dist_backend",
|
| 182 |
-
default=os.environ.get("LWM_DIST_BACKEND"),
|
| 183 |
-
help="Override torch.distributed backend."
|
| 184 |
-
)
|
| 185 |
-
parser.add_argument(
|
| 186 |
-
"--log-all-ranks",
|
| 187 |
-
action="store_true",
|
| 188 |
-
help="If set, every rank prints logs instead of rank 0 only."
|
| 189 |
-
)
|
| 190 |
-
args, remaining = parser.parse_known_args()
|
| 191 |
-
sys.argv = [sys.argv[0]] + remaining
|
| 192 |
-
return args
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
warnings.filterwarnings("ignore", category=UserWarning)
|
| 196 |
-
|
| 197 |
-
RUNTIME_ARGS = _parse_runtime_args()
|
| 198 |
-
if getattr(RUNTIME_ARGS, "dist_backend", None) and RUNTIME_ARGS.dist_backend not in {"gloo", "nccl", "hccl"}:
|
| 199 |
-
raise ValueError(f"Unsupported dist backend override: {RUNTIME_ARGS.dist_backend}")
|
| 200 |
-
LOG_ALL_RANKS = bool(getattr(RUNTIME_ARGS, "log_all_ranks", False))
|
| 201 |
-
|
| 202 |
-
TRAIN_SPLIT_FRACTION = 0.8
|
| 203 |
-
VAL_SPLIT_FRACTION = 1.0 - TRAIN_SPLIT_FRACTION
|
| 204 |
-
DEFAULT_SAMPLES_PER_SCENARIO = int(os.environ.get("LWM_SAMPLES_PER_SCENARIO", "1000"))
|
| 205 |
-
|
| 206 |
-
# Use simple progress display instead of tqdm on Windows
|
| 207 |
-
USE_TQDM = platform.system() != 'Windows'
|
| 208 |
-
|
| 209 |
-
HPU_AVAILABLE = _is_hpu_available()
|
| 210 |
-
distributed_context = _initialize_distributed(HPU_AVAILABLE, backend_override=getattr(RUNTIME_ARGS, "dist_backend", None))
|
| 211 |
-
LOG_PRIMARY = _should_log(distributed_context)
|
| 212 |
-
HPU_DEBUG_LOG = os.environ.get("LWM_DEBUG_HPU_INIT", "").lower() in {"1", "true", "yes"}
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
def _debug_hpu(message: str) -> None:
|
| 216 |
-
if not HPU_DEBUG_LOG:
|
| 217 |
-
return
|
| 218 |
-
rank = distributed_context.get("rank", 0)
|
| 219 |
-
print(f"[HPU-DEBUG rank {rank}] {message}", flush=True)
|
| 220 |
-
|
| 221 |
-
if distributed_context["is_distributed"] and LOG_PRIMARY:
|
| 222 |
-
print(
|
| 223 |
-
f"🔗 Distributed initialized -> backend={distributed_context['backend']}, "
|
| 224 |
-
f"world_size={distributed_context['world_size']}, rank={distributed_context['rank']}"
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
-
# CPU 코어 수 계산 (메모리 사용량 고려하여 보수적으로 설정)
|
| 228 |
-
total_cores = multiprocessing.cpu_count()
|
| 229 |
-
if total_cores >= 16:
|
| 230 |
-
MAX_WORKERS = min(8, total_cores // 2) # 고성능 서버의 경우 8코어로 제한
|
| 231 |
-
else:
|
| 232 |
-
MAX_WORKERS = max(2, total_cores // 2) # 일반 시스템의 경우 절반 사용
|
| 233 |
-
if LOG_PRIMARY:
|
| 234 |
-
print(f"🚀 Using {MAX_WORKERS}/{total_cores} CPU cores for parallel processing")
|
| 235 |
-
|
| 236 |
-
def process_single_scenario(scenario_info):
|
| 237 |
-
"""단일 시나리오를 처리하는 함수 (멀티프로세싱용)"""
|
| 238 |
-
scenario_name, spectrogram_path = scenario_info
|
| 239 |
-
|
| 240 |
-
try:
|
| 241 |
-
# 메모리 효율성을 위해 필요한 데이터만 로드
|
| 242 |
-
scenario_spectrograms, scenario_labels = generate_spectrograms_and_labels(
|
| 243 |
-
scenario_name=scenario_name,
|
| 244 |
-
spectrogram_path=spectrogram_path,
|
| 245 |
-
cache_path=None, # 메모리 문제로 캐시 비활성화
|
| 246 |
-
)
|
| 247 |
-
|
| 248 |
-
snr_db, doppler_id = _parse_snr_and_doppler(spectrogram_path)
|
| 249 |
-
|
| 250 |
-
# 데이터 분할 (인덱스만 계산)
|
| 251 |
-
total_samples = len(scenario_spectrograms)
|
| 252 |
-
train_size = int(TRAIN_SPLIT_FRACTION * total_samples)
|
| 253 |
-
val_size = total_samples - train_size
|
| 254 |
-
|
| 255 |
-
# 메모리 절약을 위해 numpy array로 유지 (필요할 때만 tensor로 변환)
|
| 256 |
-
train_data = np.array(scenario_spectrograms[:train_size], dtype=np.float32)
|
| 257 |
-
val_data = np.array(scenario_spectrograms[train_size:], dtype=np.float32)
|
| 258 |
-
|
| 259 |
-
snr_array = np.full(total_samples, snr_db, dtype=np.float32)
|
| 260 |
-
doppler_array = np.full(total_samples, doppler_id, dtype=np.int64)
|
| 261 |
-
train_meta = {
|
| 262 |
-
'snr_db': snr_array[:train_size],
|
| 263 |
-
'doppler_id': doppler_array[:train_size],
|
| 264 |
-
}
|
| 265 |
-
val_meta = {
|
| 266 |
-
'snr_db': snr_array[train_size:],
|
| 267 |
-
'doppler_id': doppler_array[train_size:],
|
| 268 |
-
}
|
| 269 |
-
|
| 270 |
-
# 불필요한 데이터 즉시 삭제
|
| 271 |
-
del scenario_spectrograms
|
| 272 |
-
|
| 273 |
-
return {
|
| 274 |
-
'scenario': scenario_name,
|
| 275 |
-
'train_data': train_data,
|
| 276 |
-
'val_data': val_data,
|
| 277 |
-
'train_meta': train_meta,
|
| 278 |
-
'val_meta': val_meta,
|
| 279 |
-
'train_size': len(train_data),
|
| 280 |
-
'val_size': len(val_data)
|
| 281 |
-
}
|
| 282 |
-
except Exception as e:
|
| 283 |
-
context = globals().get("distributed_context", {})
|
| 284 |
-
if LOG_PRIMARY or not context.get("is_distributed", False):
|
| 285 |
-
print(f"❌ Error processing scenario {scenario_name}: {e}")
|
| 286 |
-
return None
|
| 287 |
-
|
| 288 |
-
# GPU Memory Monitor import (for Lambda) - Removed
|
| 289 |
-
|
| 290 |
-
class StreamingMaskedSpectrogramDataset(IterableDataset):
|
| 291 |
-
"""Stream spectrogram samples scenario-by-scenario to limit peak memory usage."""
|
| 292 |
-
|
| 293 |
-
def __init__(
|
| 294 |
-
self,
|
| 295 |
-
scenario_info_list,
|
| 296 |
-
split,
|
| 297 |
-
normalization_mode,
|
| 298 |
-
dataset_stats,
|
| 299 |
-
mask_percent,
|
| 300 |
-
max_len,
|
| 301 |
-
seed=42,
|
| 302 |
-
shuffle=True,
|
| 303 |
-
rank: int = 0,
|
| 304 |
-
world_size: int = 1,
|
| 305 |
-
):
|
| 306 |
-
super().__init__()
|
| 307 |
-
if split not in {"train", "val"}:
|
| 308 |
-
raise ValueError(f"Unsupported split '{split}'. Expected 'train' or 'val'.")
|
| 309 |
-
self.scenario_info_list = list(scenario_info_list)
|
| 310 |
-
self.split = split
|
| 311 |
-
self.normalization_mode = normalization_mode
|
| 312 |
-
self.dataset_stats = dataset_stats or {'mean': 0.0, 'std': 1.0, 'normalization': normalization_mode}
|
| 313 |
-
self.mask_percent = mask_percent
|
| 314 |
-
self.max_len = max_len
|
| 315 |
-
self.seed = seed
|
| 316 |
-
self.shuffle = shuffle
|
| 317 |
-
self._epoch = 0
|
| 318 |
-
self.num_samples = 0 # Populated after dataset summary
|
| 319 |
-
self.rank = rank
|
| 320 |
-
self.world_size = max(1, world_size)
|
| 321 |
-
|
| 322 |
-
def _format_sample(self, sample_dict):
|
| 323 |
-
input_ids = torch.from_numpy(sample_dict['input_ids']).float()
|
| 324 |
-
masked_tokens = torch.from_numpy(sample_dict['masked_tokens']).float()
|
| 325 |
-
masked_pos = torch.from_numpy(sample_dict['masked_pos']).long()
|
| 326 |
-
snr_db = torch.tensor(sample_dict.get('snr_db', 0.0), dtype=torch.float32)
|
| 327 |
-
doppler_id = torch.tensor(sample_dict.get('doppler_id', 0), dtype=torch.long)
|
| 328 |
-
power_stats = torch.tensor(sample_dict.get('power_stats', np.zeros(2, dtype=np.float32)), dtype=torch.float32)
|
| 329 |
-
snr_id = torch.tensor(sample_dict.get('snr_id', -1), dtype=torch.long)
|
| 330 |
-
modulation_id = torch.tensor(sample_dict.get('modulation_id', -1), dtype=torch.long)
|
| 331 |
-
return (
|
| 332 |
-
input_ids,
|
| 333 |
-
masked_tokens,
|
| 334 |
-
masked_pos,
|
| 335 |
-
snr_db,
|
| 336 |
-
doppler_id,
|
| 337 |
-
power_stats,
|
| 338 |
-
snr_id,
|
| 339 |
-
modulation_id,
|
| 340 |
-
)
|
| 341 |
-
|
| 342 |
-
def __iter__(self):
|
| 343 |
-
order = list(self.scenario_info_list)
|
| 344 |
-
if self.shuffle and order:
|
| 345 |
-
rng = random.Random(self.seed + self._epoch)
|
| 346 |
-
rng.shuffle(order)
|
| 347 |
-
epoch_seed = self.seed + self._epoch
|
| 348 |
-
self._epoch += 1
|
| 349 |
-
|
| 350 |
-
for idx, (scenario_name, spectrogram_path) in enumerate(order):
|
| 351 |
-
if self.world_size > 1 and (idx % self.world_size) != self.rank:
|
| 352 |
-
continue
|
| 353 |
-
result = process_single_scenario((scenario_name, spectrogram_path))
|
| 354 |
-
if result is None:
|
| 355 |
-
continue
|
| 356 |
-
|
| 357 |
-
data_key = 'train_data' if self.split == 'train' else 'val_data'
|
| 358 |
-
meta_key = 'train_meta' if self.split == 'train' else 'val_meta'
|
| 359 |
-
spectrograms = result.get(data_key)
|
| 360 |
-
metadata = result.get(meta_key)
|
| 361 |
-
|
| 362 |
-
if spectrograms is None or len(spectrograms) == 0:
|
| 363 |
-
continue
|
| 364 |
-
|
| 365 |
-
scenario_seed = (epoch_seed + idx) % (2**32)
|
| 366 |
-
tokenized = tokenizer_train(
|
| 367 |
-
spectrograms,
|
| 368 |
-
max_len=self.max_len,
|
| 369 |
-
masking_percent=self.mask_percent,
|
| 370 |
-
mask=True,
|
| 371 |
-
seed=scenario_seed,
|
| 372 |
-
metadata=metadata,
|
| 373 |
-
dataset_stats=self.dataset_stats,
|
| 374 |
-
normalization=self.normalization_mode,
|
| 375 |
-
show_progress=False,
|
| 376 |
-
)
|
| 377 |
-
|
| 378 |
-
for samples in tokenized.values():
|
| 379 |
-
for sample_dict in samples:
|
| 380 |
-
yield self._format_sample(sample_dict)
|
| 381 |
-
|
| 382 |
-
del tokenized, spectrograms, metadata, result
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
def summarize_scenarios(scenario_info_list, normalization_mode):
|
| 386 |
-
"""Calculate dataset-level normalization stats and sample counts without storing all data in memory."""
|
| 387 |
-
total_sum = 0.0
|
| 388 |
-
total_sq = 0.0
|
| 389 |
-
total_count = 0
|
| 390 |
-
train_samples = 0
|
| 391 |
-
val_samples = 0
|
| 392 |
-
|
| 393 |
-
iterator = scenario_info_list
|
| 394 |
-
if USE_TQDM and LOG_PRIMARY:
|
| 395 |
-
iterator = tqdm(scenario_info_list, desc="Summarizing scenarios", unit="scenario")
|
| 396 |
-
|
| 397 |
-
for scenario_name, spectrogram_path in iterator:
|
| 398 |
-
result = process_single_scenario((scenario_name, spectrogram_path))
|
| 399 |
-
if result is None:
|
| 400 |
-
continue
|
| 401 |
-
|
| 402 |
-
train_data = result.get('train_data')
|
| 403 |
-
val_data = result.get('val_data')
|
| 404 |
-
|
| 405 |
-
if isinstance(train_data, np.ndarray):
|
| 406 |
-
train_samples += train_data.shape[0]
|
| 407 |
-
if normalization_mode == "dataset" and train_data.size > 0:
|
| 408 |
-
arr64 = train_data.astype(np.float64, copy=False)
|
| 409 |
-
total_sum += arr64.sum()
|
| 410 |
-
total_sq += np.square(arr64).sum(dtype=np.float64)
|
| 411 |
-
total_count += arr64.size
|
| 412 |
-
|
| 413 |
-
if isinstance(val_data, np.ndarray):
|
| 414 |
-
val_samples += val_data.shape[0]
|
| 415 |
-
|
| 416 |
-
del result
|
| 417 |
-
|
| 418 |
-
if normalization_mode == "dataset":
|
| 419 |
-
if total_count == 0:
|
| 420 |
-
raise ValueError("Unable to compute dataset statistics: no training samples available.")
|
| 421 |
-
mean = float(total_sum / total_count)
|
| 422 |
-
variance = max(float(total_sq / total_count - mean ** 2), 1e-12)
|
| 423 |
-
std = float(np.sqrt(variance))
|
| 424 |
-
else:
|
| 425 |
-
mean = 0.0
|
| 426 |
-
std = 1.0
|
| 427 |
-
|
| 428 |
-
stats = {'mean': mean, 'std': std, 'normalization': normalization_mode}
|
| 429 |
-
return stats, train_samples, val_samples
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
# =============================================================================
|
| 433 |
-
# 2. SCENARIO LIST DEFINITION
|
| 434 |
-
# - Define the list of scenario names to iterate over for data generation
|
| 435 |
-
# =============================================================================
|
| 436 |
-
|
| 437 |
-
# Supported communications; can be limited via CLI
|
| 438 |
-
SUPPORTED_COMM_TYPES = {"LTE", "WiFi", "5G"}
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
def _parse_standard_args():
|
| 442 |
-
parser = argparse.ArgumentParser(add_help=False)
|
| 443 |
-
parser.add_argument('--standards', nargs='+', choices=SUPPORTED_COMM_TYPES,
|
| 444 |
-
help='Specify one or more communication types to include (default: all).')
|
| 445 |
-
for comm in SUPPORTED_COMM_TYPES:
|
| 446 |
-
parser.add_argument(f'--{comm}', dest=f'flag_{comm}', action='store_true',
|
| 447 |
-
help=f'Include only {comm} data (can be combined).')
|
| 448 |
-
parser.add_argument('--city', '--cities', dest='cities', nargs='+',
|
| 449 |
-
help='Limit scenarios to one or more city prefixes (e.g., "0" or "city_0").')
|
| 450 |
-
parser.add_argument(
|
| 451 |
-
'--normalization',
|
| 452 |
-
choices=('per_sample', 'dataset'),
|
| 453 |
-
default='per_sample',
|
| 454 |
-
help='Normalization mode applied during tokenization (default: %(default)s).'
|
| 455 |
-
)
|
| 456 |
-
parser.add_argument('--help', action='help')
|
| 457 |
-
|
| 458 |
-
args, remaining = parser.parse_known_args()
|
| 459 |
-
|
| 460 |
-
enabled = set(SUPPORTED_COMM_TYPES)
|
| 461 |
-
if args.standards:
|
| 462 |
-
enabled = set(args.standards)
|
| 463 |
-
else:
|
| 464 |
-
flagged = {comm for comm in SUPPORTED_COMM_TYPES if getattr(args, f'flag_{comm}', False)}
|
| 465 |
-
if flagged:
|
| 466 |
-
enabled = flagged
|
| 467 |
-
|
| 468 |
-
selected_cities: list[str] | None = None
|
| 469 |
-
if args.cities:
|
| 470 |
-
selected_cities = []
|
| 471 |
-
for city_token in args.cities:
|
| 472 |
-
token = str(city_token).strip()
|
| 473 |
-
if not token:
|
| 474 |
-
continue
|
| 475 |
-
if token.startswith('city_'):
|
| 476 |
-
selected_cities.append(token)
|
| 477 |
-
else:
|
| 478 |
-
selected_cities.append(f'city_{token}')
|
| 479 |
-
if not selected_cities:
|
| 480 |
-
selected_cities = None
|
| 481 |
-
|
| 482 |
-
# Return remaining args to allow downstream parsing if needed
|
| 483 |
-
sys.argv = [sys.argv[0]] + remaining
|
| 484 |
-
return enabled, selected_cities, args.normalization
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
ENABLED_COMM_TYPES, ENABLED_CITY_PREFIXES, NORMALIZATION_MODE = _parse_standard_args()
|
| 488 |
-
MAX_SCENARIOS = int(os.environ.get("LWM_MAX_SCENARIOS", "0")) or None
|
| 489 |
-
|
| 490 |
-
SCENARIO_ENTRIES: Optional[List[Tuple[str, str, str, str]]] = None
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
def _scenario_manifest_path() -> str:
|
| 494 |
-
"""Build cache file path based on selected comm types and city filters."""
|
| 495 |
-
comm_token = "-".join(sorted(ENABLED_COMM_TYPES)) if ENABLED_COMM_TYPES else "all"
|
| 496 |
-
city_token = "-".join(sorted(ENABLED_CITY_PREFIXES)) if ENABLED_CITY_PREFIXES else "all"
|
| 497 |
-
limit_token = MAX_SCENARIOS if MAX_SCENARIOS is not None else "all"
|
| 498 |
-
filename = f"_scenario_entries_{comm_token}_{city_token}_max{limit_token}.json"
|
| 499 |
-
return os.path.join("spectrograms", filename)
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
def _get_scenario_entries() -> List[Tuple[str, str, str, str]]:
|
| 503 |
-
"""Gather scenario metadata once on rank 0 and share via disk cache. Avoids long-lived collectives."""
|
| 504 |
-
global SCENARIO_ENTRIES
|
| 505 |
-
if SCENARIO_ENTRIES is not None:
|
| 506 |
-
return SCENARIO_ENTRIES
|
| 507 |
-
|
| 508 |
-
manifest_path = _scenario_manifest_path()
|
| 509 |
-
refresh_requested = os.environ.get("LWM_REFRESH_SCENARIOS", "").lower() in {"1", "true", "yes"}
|
| 510 |
-
|
| 511 |
-
def _load_manifest() -> Optional[List[Tuple[str, str, str, str]]]:
|
| 512 |
-
try:
|
| 513 |
-
with open(manifest_path, "r", encoding="utf-8") as f:
|
| 514 |
-
raw_entries = json.load(f)
|
| 515 |
-
except FileNotFoundError:
|
| 516 |
-
return None
|
| 517 |
-
except Exception as exc:
|
| 518 |
-
if LOG_PRIMARY:
|
| 519 |
-
print(f"⚠️ Unable to read scenario manifest {manifest_path}: {exc}", flush=True)
|
| 520 |
-
return None
|
| 521 |
-
|
| 522 |
-
entries: List[Tuple[str, str, str, str]] = []
|
| 523 |
-
for item in raw_entries:
|
| 524 |
-
if isinstance(item, dict):
|
| 525 |
-
entries.append(
|
| 526 |
-
(
|
| 527 |
-
item.get("scenario_id", ""),
|
| 528 |
-
item.get("file_path", ""),
|
| 529 |
-
item.get("city_name", ""),
|
| 530 |
-
item.get("base_token", ""),
|
| 531 |
-
)
|
| 532 |
-
)
|
| 533 |
-
elif isinstance(item, (list, tuple)) and len(item) == 4:
|
| 534 |
-
entries.append((str(item[0]), str(item[1]), str(item[2]), str(item[3])))
|
| 535 |
-
return entries if entries else None
|
| 536 |
-
|
| 537 |
-
def _save_manifest(entries_to_save: List[Tuple[str, str, str, str]]) -> None:
|
| 538 |
-
try:
|
| 539 |
-
os.makedirs(os.path.dirname(manifest_path), exist_ok=True)
|
| 540 |
-
tmp_path = f"{manifest_path}.tmp"
|
| 541 |
-
payload = [
|
| 542 |
-
{
|
| 543 |
-
"scenario_id": scenario_id,
|
| 544 |
-
"file_path": file_path,
|
| 545 |
-
"city_name": city_name,
|
| 546 |
-
"base_token": base_token,
|
| 547 |
-
}
|
| 548 |
-
for scenario_id, file_path, city_name, base_token in entries_to_save
|
| 549 |
-
]
|
| 550 |
-
with open(tmp_path, "w", encoding="utf-8") as f:
|
| 551 |
-
json.dump(payload, f)
|
| 552 |
-
os.replace(tmp_path, manifest_path)
|
| 553 |
-
if LOG_PRIMARY:
|
| 554 |
-
print(f"📊 [debug] Scenario manifest saved to {manifest_path}", flush=True)
|
| 555 |
-
except Exception as exc:
|
| 556 |
-
if LOG_PRIMARY:
|
| 557 |
-
print(f"⚠️ Failed to save scenario manifest {manifest_path}: {exc}", flush=True)
|
| 558 |
-
|
| 559 |
-
entries: Optional[List[Tuple[str, str, str, str]]] = None
|
| 560 |
-
if distributed_context["is_distributed"]:
|
| 561 |
-
entries = None if refresh_requested else _load_manifest()
|
| 562 |
-
if entries is None:
|
| 563 |
-
if distributed_context["is_primary"]:
|
| 564 |
-
if LOG_PRIMARY:
|
| 565 |
-
print("📊 [debug] Rank0 starting scenario discovery", flush=True)
|
| 566 |
-
entries = _collect_scenario_file_info()
|
| 567 |
-
if LOG_PRIMARY:
|
| 568 |
-
print(f"📊 [debug] Rank0 collected {len(entries)} scenario entries", flush=True)
|
| 569 |
-
_save_manifest(entries)
|
| 570 |
-
else:
|
| 571 |
-
deadline = time.time() + 300.0
|
| 572 |
-
while time.time() < deadline:
|
| 573 |
-
entries = _load_manifest()
|
| 574 |
-
if entries is not None:
|
| 575 |
-
break
|
| 576 |
-
time.sleep(1.0)
|
| 577 |
-
if entries is None:
|
| 578 |
-
raise RuntimeError(
|
| 579 |
-
f"Scenario manifest {manifest_path} not found after waiting. "
|
| 580 |
-
"Run with LWM_REFRESH_SCENARIOS=1 on a single rank to regenerate."
|
| 581 |
-
)
|
| 582 |
-
elif LOG_PRIMARY and distributed_context["is_primary"]:
|
| 583 |
-
print(f"📊 [debug] Rank0 loaded {len(entries)} scenario entries from manifest", flush=True)
|
| 584 |
-
else:
|
| 585 |
-
entries = None if refresh_requested else _load_manifest()
|
| 586 |
-
if entries is None:
|
| 587 |
-
if LOG_PRIMARY:
|
| 588 |
-
print("📊 [debug] Single-process scenario discovery", flush=True)
|
| 589 |
-
entries = _collect_scenario_file_info()
|
| 590 |
-
if LOG_PRIMARY:
|
| 591 |
-
print(f"📊 [debug] Collected {len(entries)} scenario entries (single process)", flush=True)
|
| 592 |
-
_save_manifest(entries)
|
| 593 |
-
elif LOG_PRIMARY:
|
| 594 |
-
print(f"📊 [debug] Loaded {len(entries)} scenario entries from manifest", flush=True)
|
| 595 |
-
|
| 596 |
-
if entries is None:
|
| 597 |
-
entries = []
|
| 598 |
-
SCENARIO_ENTRIES = entries
|
| 599 |
-
return entries
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
def _extract_scenario_token(file_path):
|
| 603 |
-
"""Derive the base scenario token (without city) from the file path."""
|
| 604 |
-
normalized_path = os.path.normpath(file_path)
|
| 605 |
-
parts = normalized_path.split(os.sep)
|
| 606 |
-
|
| 607 |
-
scenario_parts = []
|
| 608 |
-
for i, part in enumerate(parts):
|
| 609 |
-
if part in SUPPORTED_COMM_TYPES:
|
| 610 |
-
if i + 4 < len(parts):
|
| 611 |
-
scenario_parts = [part] + parts[i + 1:i + 5]
|
| 612 |
-
break
|
| 613 |
-
return '_'.join(scenario_parts) if scenario_parts else None
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
@lru_cache(maxsize=1)
|
| 617 |
-
def _collect_scenario_file_info():
|
| 618 |
-
import glob
|
| 619 |
-
|
| 620 |
-
if LOG_PRIMARY:
|
| 621 |
-
print("📊 [debug] _collect_scenario_file_info scanning directories...", flush=True)
|
| 622 |
-
city_dirs = []
|
| 623 |
-
for d in sorted(glob.glob(os.path.join('spectrograms', 'city_*'))):
|
| 624 |
-
if not os.path.isdir(d):
|
| 625 |
-
continue
|
| 626 |
-
city_dirs.append(d)
|
| 627 |
-
|
| 628 |
-
scenario_entries = []
|
| 629 |
-
for city_dir in city_dirs:
|
| 630 |
-
city_name = os.path.basename(city_dir)
|
| 631 |
-
if ENABLED_CITY_PREFIXES:
|
| 632 |
-
if not any(city_name.startswith(prefix) for prefix in ENABLED_CITY_PREFIXES):
|
| 633 |
-
continue
|
| 634 |
-
pattern = os.path.join(city_dir, '**', '512FFT', '**', 'spectrograms', '*.pkl')
|
| 635 |
-
city_files = sorted(glob.glob(pattern, recursive=True))
|
| 636 |
-
for file_path in city_files:
|
| 637 |
-
base_token = _extract_scenario_token(file_path)
|
| 638 |
-
if not base_token:
|
| 639 |
-
continue
|
| 640 |
-
scenario_id = f"{city_name}::{base_token}"
|
| 641 |
-
comm_type = base_token.split('_', 1)[0]
|
| 642 |
-
if comm_type not in ENABLED_COMM_TYPES:
|
| 643 |
-
continue
|
| 644 |
-
scenario_entries.append((scenario_id, file_path, city_name, base_token))
|
| 645 |
-
|
| 646 |
-
if MAX_SCENARIOS:
|
| 647 |
-
scenario_entries = scenario_entries[:MAX_SCENARIOS]
|
| 648 |
-
|
| 649 |
-
if LOG_PRIMARY:
|
| 650 |
-
print(f"📊 [debug] _collect_scenario_file_info found {len(scenario_entries)} entries", flush=True)
|
| 651 |
-
return scenario_entries
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
def scenarios_list():
|
| 655 |
-
scenario_entries = _get_scenario_entries()
|
| 656 |
-
|
| 657 |
-
if not scenario_entries:
|
| 658 |
-
if LOG_PRIMARY:
|
| 659 |
-
print("⚠️ No spectrogram files found for pretraining.", flush=True)
|
| 660 |
-
return np.array([])
|
| 661 |
-
|
| 662 |
-
if LOG_PRIMARY:
|
| 663 |
-
print(f"📊 [debug] scenarios_list received {len(scenario_entries)} entries", flush=True)
|
| 664 |
-
print(f"Enabled communication types: {sorted(ENABLED_COMM_TYPES)}", flush=True)
|
| 665 |
-
if ENABLED_CITY_PREFIXES:
|
| 666 |
-
print(f"Selected city prefixes: {sorted(ENABLED_CITY_PREFIXES)}", flush=True)
|
| 667 |
-
city_counts = Counter(entry[2] for entry in scenario_entries)
|
| 668 |
-
print("Using scenarios from the following city datasets:", flush=True)
|
| 669 |
-
for city_name, count in city_counts.items():
|
| 670 |
-
print(f" - {city_name}: {count} files", flush=True)
|
| 671 |
-
|
| 672 |
-
print(f"Total scenarios selected: {len(scenario_entries)}", flush=True)
|
| 673 |
-
return np.array([entry[0] for entry in scenario_entries])
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
# =============================================================================
|
| 677 |
-
# 3. SCENARIO PROPERTIES MAPPING
|
| 678 |
-
# - Map each scenario name to its corresponding properties
|
| 679 |
-
# =============================================================================
|
| 680 |
-
|
| 681 |
-
def scenario_prop():
|
| 682 |
-
scenario_entries = _get_scenario_entries()
|
| 683 |
-
|
| 684 |
-
row_column_users = {}
|
| 685 |
-
for scenario_id, file_path, city_name, _ in scenario_entries:
|
| 686 |
-
row_column_users[scenario_id] = {
|
| 687 |
-
'spectrogram_path': file_path,
|
| 688 |
-
'cache_path': os.path.join('spectrograms', city_name, 'spectrogram_cache_128x128.pkl')
|
| 689 |
-
}
|
| 690 |
-
|
| 691 |
-
return row_column_users
|
| 692 |
-
|
| 693 |
-
# =============================================================================
|
| 694 |
-
# 4. TRAINING PARAMETERS AND HYPERPARAMETERS
|
| 695 |
-
# - Set training epochs, batch sizes, learning rates, model dimensions, etc.
|
| 696 |
-
# =============================================================================
|
| 697 |
-
|
| 698 |
-
EPOCHS = 20 # Increased for better convergence
|
| 699 |
-
# Optimized batch size for A100 GPU (40GB)
|
| 700 |
-
BATCH_SIZE = 16
|
| 701 |
-
VAL_BATCH_SIZE = 16
|
| 702 |
-
WARMUP_EPOCHS = 5
|
| 703 |
-
BASE_LR = 5e-4
|
| 704 |
-
MIN_LR = 1e-8
|
| 705 |
-
# Updated for 128x128 spectrograms
|
| 706 |
-
N_ROWS = 4
|
| 707 |
-
N_COLUMNS = 4
|
| 708 |
-
ELEMENT_LENGTH = N_ROWS * N_COLUMNS # Real-valued spectrograms (no complex interleaving)
|
| 709 |
-
D_MODEL = 128
|
| 710 |
-
MAX_LEN = 1025 # (128/4)^2 + 1 = 1024 + 1 for [CLS] token
|
| 711 |
-
N_LAYERS = 12
|
| 712 |
-
device_idx = 0
|
| 713 |
-
WEIGHT_DECAY = 0.05
|
| 714 |
-
BETA1 = 0.9
|
| 715 |
-
BETA2 = 0.999
|
| 716 |
-
MASK_PERCENT = 0.6
|
| 717 |
-
N_HEADS = 8
|
| 718 |
-
DROPOUT = 0.1
|
| 719 |
-
|
| 720 |
-
# =============================================================================
|
| 721 |
-
# 5. DATA GENERATION LOOP
|
| 722 |
-
# - Iterate over scenarios to generate spectrogram samples and labels
|
| 723 |
-
# =============================================================================
|
| 724 |
-
|
| 725 |
-
scenarios = scenarios_list()
|
| 726 |
-
scenario_properties = scenario_prop()
|
| 727 |
-
|
| 728 |
-
if LOG_PRIMARY:
|
| 729 |
-
print(f"📂 Loading {len(scenarios)} scenarios...")
|
| 730 |
-
|
| 731 |
-
scenario_info_list = []
|
| 732 |
-
missing_props = []
|
| 733 |
-
for scenario in scenarios:
|
| 734 |
-
props = scenario_properties.get(scenario)
|
| 735 |
-
if props is None:
|
| 736 |
-
missing_props.append(scenario)
|
| 737 |
-
continue
|
| 738 |
-
scenario_info_list.append((scenario, props["spectrogram_path"]))
|
| 739 |
-
|
| 740 |
-
if distributed_context["is_distributed"] and len(scenario_info_list) < distributed_context["world_size"]:
|
| 741 |
-
if LOG_PRIMARY:
|
| 742 |
-
print("❌ Distributed configuration requires at least one scenario per process. "
|
| 743 |
-
f"Found {len(scenario_info_list)} scenarios for world size {distributed_context['world_size']}.")
|
| 744 |
-
raise ValueError("Insufficient scenarios for the requested distributed world size.")
|
| 745 |
-
|
| 746 |
-
if missing_props and LOG_PRIMARY:
|
| 747 |
-
print("⚠️ Missing metadata for the following scenarios; skipping:")
|
| 748 |
-
for scen in missing_props:
|
| 749 |
-
print(f" - {scen}")
|
| 750 |
-
|
| 751 |
-
if LOG_PRIMARY:
|
| 752 |
-
print(f"📂 Preparing {len(scenario_info_list)} scenarios with streaming loaders...")
|
| 753 |
-
|
| 754 |
-
if NORMALIZATION_MODE == "dataset":
|
| 755 |
-
if distributed_context["is_distributed"] and not distributed_context["is_primary"]:
|
| 756 |
-
dataset_normalization = None
|
| 757 |
-
train_sample_count = 0
|
| 758 |
-
val_sample_count = 0
|
| 759 |
-
else:
|
| 760 |
-
dataset_normalization, train_sample_count, val_sample_count = summarize_scenarios(
|
| 761 |
-
scenario_info_list,
|
| 762 |
-
NORMALIZATION_MODE,
|
| 763 |
-
)
|
| 764 |
-
if distributed_context["is_distributed"]:
|
| 765 |
-
payload = [dataset_normalization, train_sample_count, val_sample_count]
|
| 766 |
-
dataset_normalization, train_sample_count, val_sample_count = _broadcast_object(payload, src=0)
|
| 767 |
-
else:
|
| 768 |
-
train_samples_per_scenario = int(TRAIN_SPLIT_FRACTION * DEFAULT_SAMPLES_PER_SCENARIO)
|
| 769 |
-
val_samples_per_scenario = max(DEFAULT_SAMPLES_PER_SCENARIO - train_samples_per_scenario, 0)
|
| 770 |
-
dataset_normalization = {'mean': 0.0, 'std': 1.0, 'normalization': NORMALIZATION_MODE}
|
| 771 |
-
train_sample_count = len(scenario_info_list) * train_samples_per_scenario
|
| 772 |
-
val_sample_count = len(scenario_info_list) * val_samples_per_scenario
|
| 773 |
-
if LOG_PRIMARY:
|
| 774 |
-
print(f" Assuming {DEFAULT_SAMPLES_PER_SCENARIO} samples per scenario ({train_samples_per_scenario} train / {val_samples_per_scenario} val)")
|
| 775 |
-
|
| 776 |
-
if LOG_PRIMARY:
|
| 777 |
-
print(f" Training samples: {train_sample_count}")
|
| 778 |
-
print(f" Validation samples: {val_sample_count}")
|
| 779 |
-
if train_sample_count == 0:
|
| 780 |
-
raise ValueError("No training samples available after filtering scenarios.")
|
| 781 |
-
if NORMALIZATION_MODE == "dataset":
|
| 782 |
-
if LOG_PRIMARY:
|
| 783 |
-
print(f"Dataset normalization stats -> mean: {dataset_normalization['mean']:.4f}, std: {dataset_normalization['std']:.4f}")
|
| 784 |
-
else:
|
| 785 |
-
if LOG_PRIMARY:
|
| 786 |
-
print("Dataset normalization stats -> using per-sample normalization")
|
| 787 |
-
|
| 788 |
-
SEED = 42
|
| 789 |
-
torch.manual_seed(SEED)
|
| 790 |
-
np.random.seed(SEED)
|
| 791 |
-
|
| 792 |
-
world_size = max(1, distributed_context["world_size"])
|
| 793 |
-
train_samples_per_rank = math.ceil(train_sample_count / world_size) if distributed_context["is_distributed"] else train_sample_count
|
| 794 |
-
val_samples_per_rank = math.ceil(val_sample_count / world_size) if distributed_context["is_distributed"] else val_sample_count
|
| 795 |
-
|
| 796 |
-
train_dataset = StreamingMaskedSpectrogramDataset(
|
| 797 |
-
scenario_info_list,
|
| 798 |
-
split="train",
|
| 799 |
-
normalization_mode=NORMALIZATION_MODE,
|
| 800 |
-
dataset_stats=dataset_normalization,
|
| 801 |
-
mask_percent=MASK_PERCENT,
|
| 802 |
-
max_len=MAX_LEN,
|
| 803 |
-
seed=SEED,
|
| 804 |
-
shuffle=True,
|
| 805 |
-
rank=distributed_context["rank"],
|
| 806 |
-
world_size=world_size,
|
| 807 |
-
)
|
| 808 |
-
train_dataset.num_samples = train_samples_per_rank
|
| 809 |
-
|
| 810 |
-
val_dataset = StreamingMaskedSpectrogramDataset(
|
| 811 |
-
scenario_info_list,
|
| 812 |
-
split="val",
|
| 813 |
-
normalization_mode=NORMALIZATION_MODE,
|
| 814 |
-
dataset_stats=dataset_normalization,
|
| 815 |
-
mask_percent=MASK_PERCENT,
|
| 816 |
-
max_len=MAX_LEN,
|
| 817 |
-
seed=SEED,
|
| 818 |
-
shuffle=False,
|
| 819 |
-
rank=distributed_context["rank"],
|
| 820 |
-
world_size=world_size,
|
| 821 |
-
)
|
| 822 |
-
val_dataset.num_samples = val_samples_per_rank
|
| 823 |
-
|
| 824 |
-
if LOG_PRIMARY:
|
| 825 |
-
print("🔧 Creating streaming data loaders...")
|
| 826 |
-
train_loaders = {
|
| 827 |
-
'stream': DataLoader(
|
| 828 |
-
train_dataset,
|
| 829 |
-
batch_size=BATCH_SIZE,
|
| 830 |
-
shuffle=False,
|
| 831 |
-
num_workers=0,
|
| 832 |
-
pin_memory=True,
|
| 833 |
-
)
|
| 834 |
-
}
|
| 835 |
-
val_loaders = {
|
| 836 |
-
'stream': DataLoader(
|
| 837 |
-
val_dataset,
|
| 838 |
-
batch_size=VAL_BATCH_SIZE,
|
| 839 |
-
shuffle=False,
|
| 840 |
-
num_workers=0,
|
| 841 |
-
pin_memory=True,
|
| 842 |
-
)
|
| 843 |
-
}
|
| 844 |
-
if LOG_PRIMARY:
|
| 845 |
-
print("✅ Data loaders created successfully!")
|
| 846 |
-
|
| 847 |
-
# =============================================================================
|
| 848 |
-
# 9. MODEL INITIALIZATION
|
| 849 |
-
# - Instantiate the LWM transformer model and optionally load pre-trained weights
|
| 850 |
-
# - Wrap with DataParallel for multi-GPU support
|
| 851 |
-
# =============================================================================
|
| 852 |
-
|
| 853 |
-
# Device selection with HPU, CUDA, and MPS support
|
| 854 |
-
if LOG_PRIMARY:
|
| 855 |
-
print("🔧 Setting up device and accelerator configuration...")
|
| 856 |
-
|
| 857 |
-
requested_device = getattr(RUNTIME_ARGS, "device", "auto") or "auto"
|
| 858 |
-
requested_device = requested_device.lower()
|
| 859 |
-
runtime_device = requested_device
|
| 860 |
-
|
| 861 |
-
if runtime_device == "auto":
|
| 862 |
-
if HPU_AVAILABLE:
|
| 863 |
-
runtime_device = "hpu"
|
| 864 |
-
elif torch.cuda.is_available():
|
| 865 |
-
runtime_device = "cuda"
|
| 866 |
-
elif torch.backends.mps.is_available():
|
| 867 |
-
runtime_device = "mps"
|
| 868 |
-
else:
|
| 869 |
-
runtime_device = "cpu"
|
| 870 |
-
|
| 871 |
-
if runtime_device in {"hpu", "auto"} and not HPU_AVAILABLE:
|
| 872 |
-
if os.environ.get("HABANA_VISIBLE_DEVICES") and LOG_PRIMARY:
|
| 873 |
-
print("⚠️ HABANA_VISIBLE_DEVICES is set but Habana PyTorch extensions are not available.")
|
| 874 |
-
print(" Install the Habana PyTorch distribution or activate the appropriate environment.")
|
| 875 |
-
|
| 876 |
-
device = torch.device("cpu")
|
| 877 |
-
gpu_ids: list[int] = []
|
| 878 |
-
ddp_device_ids: Optional[list[int]] = None
|
| 879 |
-
|
| 880 |
-
if runtime_device == "hpu":
|
| 881 |
-
if not HPU_AVAILABLE:
|
| 882 |
-
raise RuntimeError("HPU device requested but torch.hpu is not available. "
|
| 883 |
-
"Install the Habana PyTorch distribution or select --device cpu.")
|
| 884 |
-
hpu_module = getattr(torch, "hpu", None)
|
| 885 |
-
|
| 886 |
-
# Get local rank first before any HPU operations
|
| 887 |
-
local_rank = distributed_context["local_rank"] if distributed_context["is_distributed"] else 0
|
| 888 |
-
_debug_hpu(f"Entering HPU device setup (local_rank={local_rank}, world_size={distributed_context.get('world_size')})")
|
| 889 |
-
|
| 890 |
-
# Query device count locally (safe after Habana runtime init)
|
| 891 |
-
hpu_count = max(1, _get_hpu_device_count())
|
| 892 |
-
if LOG_PRIMARY or HPU_DEBUG_LOG:
|
| 893 |
-
_debug_hpu(f"Detected {hpu_count} HPU devices via local query")
|
| 894 |
-
|
| 895 |
-
device = torch.device("hpu")
|
| 896 |
-
if hpu_module is not None and hasattr(hpu_module, "set_device"):
|
| 897 |
-
try:
|
| 898 |
-
_debug_hpu(f"Calling torch.hpu.set_device({local_rank})")
|
| 899 |
-
hpu_module.set_device(local_rank)
|
| 900 |
-
_debug_hpu("torch.hpu.set_device completed successfully")
|
| 901 |
-
except Exception as exc:
|
| 902 |
-
_debug_hpu(f"set_device raised exception: {exc}")
|
| 903 |
-
if LOG_PRIMARY:
|
| 904 |
-
print(f" ⚠️ Unable to set HPU device {local_rank}: {exc}")
|
| 905 |
-
ddp_device_ids = [local_rank] if distributed_context["is_distributed"] else None
|
| 906 |
-
if LOG_PRIMARY:
|
| 907 |
-
if hpu_count > 0:
|
| 908 |
-
print(f" HPU available: {hpu_count} device(s) detected")
|
| 909 |
-
if distributed_context["is_distributed"]:
|
| 910 |
-
print(f" Using HPU local rank: {local_rank}")
|
| 911 |
-
elif runtime_device == "cuda":
|
| 912 |
-
if not torch.cuda.is_available():
|
| 913 |
-
raise RuntimeError("CUDA device requested but torch.cuda.is_available() is False.")
|
| 914 |
-
device_count = torch.cuda.device_count()
|
| 915 |
-
if LOG_PRIMARY:
|
| 916 |
-
print(f" CUDA available: {device_count} GPU(s) detected")
|
| 917 |
-
if distributed_context["is_distributed"]:
|
| 918 |
-
local_rank = distributed_context["local_rank"]
|
| 919 |
-
torch.cuda.set_device(local_rank)
|
| 920 |
-
device = torch.device("cuda", local_rank)
|
| 921 |
-
ddp_device_ids = [local_rank]
|
| 922 |
-
if LOG_PRIMARY:
|
| 923 |
-
print(f" Using CUDA local rank: {local_rank}")
|
| 924 |
-
else:
|
| 925 |
-
device = torch.device("cuda:0")
|
| 926 |
-
gpu_ids = list(range(device_count))
|
| 927 |
-
if LOG_PRIMARY:
|
| 928 |
-
print(f" Using CUDA GPUs: {gpu_ids}")
|
| 929 |
-
for i in gpu_ids:
|
| 930 |
-
try:
|
| 931 |
-
mem_total = torch.cuda.get_device_properties(i).total_memory / 1024**3
|
| 932 |
-
mem_allocated = torch.cuda.memory_allocated(i) / 1024**3
|
| 933 |
-
if LOG_PRIMARY:
|
| 934 |
-
print(f" GPU {i}: Total: {mem_total:.1f}GB, Allocated: {mem_allocated:.1f}GB")
|
| 935 |
-
except Exception as exc:
|
| 936 |
-
if LOG_PRIMARY:
|
| 937 |
-
print(f" GPU {i}: Error getting memory info - {exc}")
|
| 938 |
-
elif runtime_device == "mps":
|
| 939 |
-
if not torch.backends.mps.is_available():
|
| 940 |
-
raise RuntimeError("MPS device requested but torch.backends.mps.is_available() is False.")
|
| 941 |
-
device = torch.device("mps")
|
| 942 |
-
if LOG_PRIMARY:
|
| 943 |
-
print(" Using MPS (Apple Silicon GPU)")
|
| 944 |
-
elif runtime_device == "cpu":
|
| 945 |
-
device = torch.device("cpu")
|
| 946 |
-
if LOG_PRIMARY:
|
| 947 |
-
print(" Using CPU")
|
| 948 |
-
else:
|
| 949 |
-
raise ValueError(f"Unsupported device selection: {runtime_device}")
|
| 950 |
-
|
| 951 |
-
distributed_context["device_type"] = device.type
|
| 952 |
-
if LOG_PRIMARY:
|
| 953 |
-
print(f" Final device: {device}")
|
| 954 |
-
if gpu_ids:
|
| 955 |
-
print(f" GPU IDs for DataParallel: {gpu_ids}")
|
| 956 |
-
|
| 957 |
-
if LOG_PRIMARY:
|
| 958 |
-
print("🤖 Initializing LWM model...")
|
| 959 |
-
print(f" Model parameters: element_length={ELEMENT_LENGTH}, d_model={D_MODEL}, n_layers={N_LAYERS}, max_len={MAX_LEN}, n_heads={N_HEADS}")
|
| 960 |
-
|
| 961 |
-
try:
|
| 962 |
-
model = pretrained_model.lwm(
|
| 963 |
-
element_length=ELEMENT_LENGTH, # Real-valued spectrograms
|
| 964 |
-
d_model=D_MODEL,
|
| 965 |
-
n_layers=N_LAYERS,
|
| 966 |
-
max_len=MAX_LEN, # Use pre-calculated value for safety
|
| 967 |
-
n_heads=N_HEADS,
|
| 968 |
-
dropout=DROPOUT
|
| 969 |
-
)
|
| 970 |
-
if LOG_PRIMARY:
|
| 971 |
-
print(" ✅ Model created successfully")
|
| 972 |
-
print(f" Moving model to device: {device}")
|
| 973 |
-
# MPS only supports float32, so set dtype
|
| 974 |
-
if 'mps' in str(device):
|
| 975 |
-
model = model.to(device).float()
|
| 976 |
-
if LOG_PRIMARY:
|
| 977 |
-
print(" ✅ Model moved to MPS device (float32)")
|
| 978 |
-
else:
|
| 979 |
-
model = model.to(device)
|
| 980 |
-
if LOG_PRIMARY:
|
| 981 |
-
print(" ✅ Model moved to device successfully")
|
| 982 |
-
|
| 983 |
-
# Synchronize all processes after moving model to device
|
| 984 |
-
# This prevents memory contention issues in multi-HPU/GPU setups
|
| 985 |
-
if distributed_context["is_distributed"]:
|
| 986 |
-
torch.distributed.barrier()
|
| 987 |
-
if LOG_PRIMARY:
|
| 988 |
-
print(" ✅ All processes synchronized after model transfer")
|
| 989 |
-
|
| 990 |
-
except Exception as e:
|
| 991 |
-
print(f" ❌ Model initialization failed: {e}")
|
| 992 |
-
import traceback
|
| 993 |
-
traceback.print_exc()
|
| 994 |
-
exit(1)
|
| 995 |
-
|
| 996 |
-
# Optional: Load pre-trained model
|
| 997 |
-
load_model = False
|
| 998 |
-
if load_model:
|
| 999 |
-
model.load_state_dict(torch.load("models/model_checkpoint.pth", map_location=device))
|
| 1000 |
-
if LOG_PRIMARY:
|
| 1001 |
-
print("Pre-trained model loaded successfully.")
|
| 1002 |
-
|
| 1003 |
-
# Wrap model for parallel/distributed execution
|
| 1004 |
-
if distributed_context["is_distributed"]:
|
| 1005 |
-
# Additional barrier before DDP wrapping to ensure all processes are ready
|
| 1006 |
-
torch.distributed.barrier()
|
| 1007 |
-
|
| 1008 |
-
ddp_kwargs: Dict[str, Any] = {"broadcast_buffers": False}
|
| 1009 |
-
if ddp_device_ids:
|
| 1010 |
-
ddp_kwargs["device_ids"] = ddp_device_ids
|
| 1011 |
-
ddp_kwargs["output_device"] = ddp_device_ids[0]
|
| 1012 |
-
else:
|
| 1013 |
-
ddp_kwargs["device_ids"] = None
|
| 1014 |
-
model = nn.parallel.DistributedDataParallel(model, **ddp_kwargs)
|
| 1015 |
-
if LOG_PRIMARY:
|
| 1016 |
-
print(f"Model wrapped with DistributedDataParallel on rank {distributed_context['rank']}")
|
| 1017 |
-
elif gpu_ids:
|
| 1018 |
-
model = nn.DataParallel(model, device_ids=gpu_ids)
|
| 1019 |
-
if LOG_PRIMARY:
|
| 1020 |
-
print(f"Model loaded successfully with DataParallel on CUDA devices {gpu_ids}")
|
| 1021 |
-
else:
|
| 1022 |
-
if LOG_PRIMARY:
|
| 1023 |
-
print(f"Model loaded successfully on {device}")
|
| 1024 |
-
n_parameters = count_parameters(model, log=LOG_PRIMARY)
|
| 1025 |
-
if LOG_PRIMARY:
|
| 1026 |
-
print(f"Number of trainable parameters: {n_parameters:,}")
|
| 1027 |
-
|
| 1028 |
-
# =============================================================================
|
| 1029 |
-
# 10. OPTIMIZER AND LEARNING RATE SCHEDULER
|
| 1030 |
-
# - Configure AdamW optimizer and a cosine-with-warmup LR schedule based on total steps
|
| 1031 |
-
# =============================================================================
|
| 1032 |
-
|
| 1033 |
-
steps_per_epoch = max(1, math.ceil(train_samples_per_rank / BATCH_SIZE))
|
| 1034 |
-
TOTAL_STEPS = steps_per_epoch * EPOCHS
|
| 1035 |
-
WARMUP_STEPS = steps_per_epoch * WARMUP_EPOCHS
|
| 1036 |
-
|
| 1037 |
-
optimizer = AdamW(
|
| 1038 |
-
model.parameters(),
|
| 1039 |
-
lr=BASE_LR,
|
| 1040 |
-
betas=(BETA1, BETA2),
|
| 1041 |
-
weight_decay=WEIGHT_DECAY
|
| 1042 |
-
)
|
| 1043 |
-
|
| 1044 |
-
def lr_lambda(current_step):
|
| 1045 |
-
if current_step < WARMUP_STEPS:
|
| 1046 |
-
return current_step / WARMUP_STEPS
|
| 1047 |
-
else:
|
| 1048 |
-
scaled_progress = (current_step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
|
| 1049 |
-
cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
|
| 1050 |
-
return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR
|
| 1051 |
-
|
| 1052 |
-
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
|
| 1053 |
-
|
| 1054 |
-
# =============================================================================
|
| 1055 |
-
# 11. PRE-TRAINING LOOP
|
| 1056 |
-
# - Call the train_lwm utility to run the pre-training epochs, logging metrics and saving models
|
| 1057 |
-
# =============================================================================
|
| 1058 |
-
|
| 1059 |
-
# Create timestamp-based save directory
|
| 1060 |
-
if distributed_context["is_distributed"]:
|
| 1061 |
-
timestamp_source = datetime.now().strftime("%Y%m%d_%H%M%S") if LOG_PRIMARY else None
|
| 1062 |
-
timestamp = _broadcast_object(timestamp_source, src=0)
|
| 1063 |
-
else:
|
| 1064 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 1065 |
-
save_dir = f"models/{timestamp}"
|
| 1066 |
-
if LOG_PRIMARY:
|
| 1067 |
-
print(f"📁 Models and logs will be saved to: {save_dir}")
|
| 1068 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 1069 |
-
|
| 1070 |
-
stats_path = os.path.join(save_dir, "dataset_stats.json")
|
| 1071 |
-
if LOG_PRIMARY:
|
| 1072 |
-
with open(stats_path, 'w') as f:
|
| 1073 |
-
json.dump(dataset_normalization, f, indent=2)
|
| 1074 |
-
print(f"📝 Saved dataset stats to {stats_path}")
|
| 1075 |
-
_barrier(distributed_context)
|
| 1076 |
-
|
| 1077 |
-
comm_selection = sorted(ENABLED_COMM_TYPES) if ENABLED_COMM_TYPES else []
|
| 1078 |
-
if comm_selection:
|
| 1079 |
-
comm_suffix = "_" + "-".join(comm_selection)
|
| 1080 |
-
else:
|
| 1081 |
-
comm_suffix = ""
|
| 1082 |
-
if comm_selection and LOG_PRIMARY:
|
| 1083 |
-
print(f"[INFO] Communication standards for this run: {', '.join(comm_selection)}")
|
| 1084 |
-
|
| 1085 |
-
if __name__ == "__main__":
|
| 1086 |
-
# Patch: Ensure patches is not a dict before converting to tensor
|
| 1087 |
-
def safe_tensor_from_patches(patches, device):
|
| 1088 |
-
if isinstance(patches, dict):
|
| 1089 |
-
key = max(patches.keys())
|
| 1090 |
-
patches = patches[key]
|
| 1091 |
-
return torch.tensor(patches, dtype=torch.float32).to(device)
|
| 1092 |
-
|
| 1093 |
-
# Pass this function to train_lwm if needed, or use inside train_lwm
|
| 1094 |
-
pretrained_model = train_lwm(
|
| 1095 |
-
model,
|
| 1096 |
-
train_loaders,
|
| 1097 |
-
val_loaders,
|
| 1098 |
-
optimizer,
|
| 1099 |
-
scheduler,
|
| 1100 |
-
EPOCHS,
|
| 1101 |
-
device=device,
|
| 1102 |
-
save_dir=save_dir,
|
| 1103 |
-
log_file="training_log.csv",
|
| 1104 |
-
checkpoint_suffix=comm_suffix,
|
| 1105 |
-
distributed_context=distributed_context,
|
| 1106 |
-
# If train_lwm needs to convert patches, use safe_tensor_from_patches
|
| 1107 |
-
)
|
| 1108 |
-
_barrier(distributed_context)
|
| 1109 |
-
if LOG_PRIMARY:
|
| 1110 |
-
print("🏁 Training run complete.")
|
| 1111 |
-
if distributed_context["is_distributed"]:
|
| 1112 |
-
dist.destroy_process_group()
|
| 1113 |
-
SNR_PATTERN = re.compile(r"SNR(-?\d+)dB")
|
| 1114 |
-
DOPPLER_MAP = {"static": 0, "pedestrian": 1, "vehicular": 2}
|
| 1115 |
-
DOPPLER_INV = {v: k for k, v in DOPPLER_MAP.items()}
|
| 1116 |
-
|
| 1117 |
-
|
| 1118 |
-
def _parse_snr_and_doppler(path: str) -> tuple[float, int]:
|
| 1119 |
-
snr_db = 0.0
|
| 1120 |
-
doppler_id = 0
|
| 1121 |
-
|
| 1122 |
-
matches = SNR_PATTERN.findall(path)
|
| 1123 |
-
if matches:
|
| 1124 |
-
try:
|
| 1125 |
-
snr_db = float(matches[-1])
|
| 1126 |
-
except ValueError:
|
| 1127 |
-
snr_db = 0.0
|
| 1128 |
-
|
| 1129 |
-
normalized_path = os.path.normpath(path)
|
| 1130 |
-
parts = normalized_path.split(os.sep)
|
| 1131 |
-
for part in parts:
|
| 1132 |
-
if part in DOPPLER_MAP:
|
| 1133 |
-
doppler_id = DOPPLER_MAP[part]
|
| 1134 |
-
break
|
| 1135 |
-
|
| 1136 |
-
return snr_db, doppler_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# UI/Hub
|
| 2 |
-
gradio==
|
| 3 |
-
huggingface_hub
|
| 4 |
|
| 5 |
# Core
|
| 6 |
torch
|
|
|
|
| 1 |
# UI/Hub
|
| 2 |
+
gradio==6.0.1
|
| 3 |
+
huggingface_hub>=0.33.5,<2.0
|
| 4 |
|
| 5 |
# Core
|
| 6 |
torch
|