Spaces:
Running
Running
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
|
| 2 |
-
import huggingface_hub as hf_hub
|
| 3 |
import inspect
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
# Gradio <-> hub compatibility shim: newer huggingface_hub removed HfFolder.
|
| 6 |
if not hasattr(hf_hub, "HfFolder"):
|
|
@@ -22,8 +25,20 @@ import pandas as pd
|
|
| 22 |
from sklearn.manifold import TSNE
|
| 23 |
from sklearn.decomposition import PCA
|
| 24 |
from sklearn.preprocessing import StandardScaler
|
|
|
|
| 25 |
import matplotlib.pyplot as plt
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# Load data
|
| 28 |
def load_data():
|
| 29 |
print("Loading data...")
|
|
@@ -41,9 +56,9 @@ def load_data():
|
|
| 41 |
})
|
| 42 |
df = pd.DataFrame(records)
|
| 43 |
print(f"Loaded {len(df)} samples.")
|
| 44 |
-
return df
|
| 45 |
|
| 46 |
-
df = load_data()
|
| 47 |
|
| 48 |
# Get unique values for filters
|
| 49 |
tech_choices = sorted(list(df['tech'].unique()))
|
|
@@ -160,40 +175,137 @@ def plot_tsne(tech_filter, snr_filter, mod_filter, mob_filter, representation, c
|
|
| 160 |
trace_info = f"traces: {len(filtered_df[color_by].unique())}"
|
| 161 |
return fig, f"{status_msg} | filtered samples: {len(filtered_df)} | {coord_info} | {trace_info}"
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
# UI
|
| 164 |
with gr.Blocks(title="LWM-Spectro Demo") as demo:
|
| 165 |
-
gr.Markdown("# 🔬 LWM-Spectro Interactive
|
| 166 |
-
gr.Markdown(""
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
if __name__ == "__main__":
|
| 199 |
demo.launch()
|
|
|
|
| 1 |
|
|
|
|
| 2 |
import inspect
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import huggingface_hub as hf_hub
|
| 7 |
|
| 8 |
# Gradio <-> hub compatibility shim: newer huggingface_hub removed HfFolder.
|
| 9 |
if not hasattr(hf_hub, "HfFolder"):
|
|
|
|
| 25 |
from sklearn.manifold import TSNE
|
| 26 |
from sklearn.decomposition import PCA
|
| 27 |
from sklearn.preprocessing import StandardScaler
|
| 28 |
+
from sklearn.metrics import confusion_matrix, f1_score
|
| 29 |
import matplotlib.pyplot as plt
|
| 30 |
|
| 31 |
+
# Repo root for local imports
|
| 32 |
+
REPO_ROOT = Path(__file__).resolve().parent.parent
|
| 33 |
+
if str(REPO_ROOT) not in sys.path:
|
| 34 |
+
sys.path.append(str(REPO_ROOT))
|
| 35 |
+
|
| 36 |
+
from mixture.train_embedding_router import MoEPredictor # type: ignore
|
| 37 |
+
|
| 38 |
+
# ------------------------------------------------------------------------------
|
| 39 |
+
# Data loading (t-SNE + evaluation)
|
| 40 |
+
# ------------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
# Load data
|
| 43 |
def load_data():
|
| 44 |
print("Loading data...")
|
|
|
|
| 56 |
})
|
| 57 |
df = pd.DataFrame(records)
|
| 58 |
print(f"Loaded {len(df)} samples.")
|
| 59 |
+
return df, data
|
| 60 |
|
| 61 |
+
df, raw_samples = load_data()
|
| 62 |
|
| 63 |
# Get unique values for filters
|
| 64 |
tech_choices = sorted(list(df['tech'].unique()))
|
|
|
|
| 175 |
trace_info = f"traces: {len(filtered_df[color_by].unique())}"
|
| 176 |
return fig, f"{status_msg} | filtered samples: {len(filtered_df)} | {coord_info} | {trace_info}"
|
| 177 |
|
| 178 |
+
# ------------------------------------------------------------------------------
|
| 179 |
+
# Evaluation utilities (confusion matrix, F1) using the MoE checkpoint
|
| 180 |
+
# ------------------------------------------------------------------------------
|
| 181 |
+
|
| 182 |
+
_predictor: MoEPredictor | None = None
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def load_predictor() -> MoEPredictor:
    """Return the shared MoE predictor, creating it on the first call.

    The instance is stored in the module-level ``_predictor`` so that later
    calls hand back the same object without resolving a checkpoint again.

    Returns:
        The object produced by ``MoEPredictor.from_checkpoint``.
    """
    global _predictor
    if _predictor is None:
        # Local checkpoints are tried in this order; the Hub download is
        # only the fallback when neither file exists.
        local_candidates = (
            REPO_ROOT / "mixture" / "runs" / "embedding_router" / "moe_checkpoint.pth",
            REPO_ROOT / "moe_checkpoint.pth",
        )
        checkpoint = next((path for path in local_candidates if path.exists()), None)
        if checkpoint is None:
            downloaded = hf_hub.hf_hub_download(
                repo_id="wi-lab/lwm-spectro", filename="moe_checkpoint.pth"
            )
            checkpoint = Path(downloaded)
        _predictor = MoEPredictor.from_checkpoint(checkpoint)
    return _predictor
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _to_tensor(spec) -> torch.Tensor:
    """Coerce *spec* to a ``torch.Tensor``, adding a leading axis to 2-D input.

    Args:
        spec: A tensor, or anything ``torch.as_tensor`` accepts.

    Returns:
        The tensor itself when it is not 2-D; otherwise the tensor with a
        new size-1 dimension inserted at position 0.
    """
    tensor = spec if isinstance(spec, torch.Tensor) else torch.as_tensor(spec)
    return tensor.unsqueeze(0) if tensor.dim() == 2 else tensor
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def _label_key(value) -> str:
    """Collapse a label into one plain string usable as a class key.

    Tuples and lists are joined with ``_`` (``("10", "static")`` becomes
    ``"10_static"``); every other value is passed through ``str``.
    """
    if isinstance(value, (tuple, list)):
        return "_".join(str(part) for part in value)
    return str(value)


def compute_eval(task: str):
    """Compute confusion matrix, macro F1 and accuracy for the demo set.

    Every entry of the module-level ``raw_samples`` is run through the
    predictor returned by ``load_predictor()``.

    Args:
        task: ``"comm"`` compares the first routing entry's ``"comm"`` value
            with the sample's ``"tech"``; any other value compares the
            predicted label with the sample's ``("snr", "mob")`` pair.

    Returns:
        A tuple ``(cm, labels, f1, acc)``: the confusion matrix, the sorted
        string class labels indexing its rows and columns, the macro F1
        score, and the accuracy. With no samples the result is an empty
        0x0 matrix, an empty label list and ``0.0`` for both scores.
    """
    predictor = load_predictor()
    y_true, y_pred = [], []

    for sample in raw_samples:
        spec = _to_tensor(sample["data"])
        res = predictor.predict(spec, return_routing=True)

        if task == "comm":
            routing = res.get("routing") or []
            pred = routing[0]["comm"] if routing else "Unknown"
            true = sample["tech"]
        else:  # snr_mobility
            # dict.get evaluates its default eagerly, so the fallback key is
            # only looked up here when "label" is really missing.
            pred = res["label"] if "label" in res else res["predicted_class"]
            true = (sample["snr"], sample["mob"])
        # Plain string keys keep targets and predictions sortable together and
        # one-dimensional for the metric functions.
        # NOTE(review): assumes the predictor's label text has the form
        # "<snr>_<mob>" — confirm against MoEPredictor.predict.
        y_true.append(_label_key(true))
        y_pred.append(_label_key(pred))

    if not y_true:
        # Nothing to score: avoid a nan accuracy and metric calls on no data.
        return np.zeros((0, 0), dtype=int), [], 0.0, 0.0

    labels = sorted({*y_true, *y_pred})
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    f1 = f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0)
    acc = float(np.mean([t == p for t, p in zip(y_true, y_pred)]))
    return cm, labels, f1, acc
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def plot_confusion(cm: np.ndarray, labels):
    """Draw a confusion matrix as a heatmap with the count written in each cell.

    Args:
        cm: 2-D matrix of counts; rows go on the "True" axis and columns on
            the "Predicted" axis.
        labels: Tick labels applied to both axes, in matrix order.

    Returns:
        The matplotlib figure containing the heatmap.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    heatmap = ax.imshow(cm, cmap="Blues")
    fig.colorbar(heatmap, ax=ax, fraction=0.046, pad=0.04)

    tick_positions = np.arange(len(labels))
    ax.set_xticks(tick_positions, labels=labels, rotation=45, ha="right")
    ax.set_yticks(tick_positions, labels=labels)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    # Write each count in the middle of its cell (x is the column, y the row).
    for row, col in np.ndindex(cm.shape[0], cm.shape[1]):
        ax.text(col, row, int(cm[row, col]), ha="center", va="center", color="black")

    fig.tight_layout()
    return fig
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def run_eval(task):
    """Evaluate *task* and return the confusion-matrix figure plus a metrics line.

    Args:
        task: Task name forwarded unchanged to ``compute_eval``.

    Returns:
        A ``(figure, summary)`` pair: the figure from ``plot_confusion`` and a
        one-line string reporting the task, accuracy and macro F1.
    """
    matrix, class_names, macro_f1, accuracy = compute_eval(task)
    figure = plot_confusion(matrix, class_names)
    metrics_line = f"Task: {task} | Accuracy: {accuracy:.4f} | Macro F1: {macro_f1:.4f}"
    return figure, metrics_line
|
| 264 |
+
|
| 265 |
+
|
| 266 |
# UI
|
| 267 |
# Two-tab Gradio app: an interactive t-SNE explorer and an MoE evaluation view.
# NOTE(review): the nesting of the `with` blocks below was reconstructed from a
# view that had lost indentation — confirm the layout against the real file.
with gr.Blocks(title="LWM-Spectro Demo") as demo:
    gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
    gr.Markdown("Compare embeddings vs raw for t-SNE, and view quick metrics from the latest MoE checkpoint.")

    with gr.Tab("t-SNE"):
        with gr.Row():
            # Narrow column: filters and plot settings.
            with gr.Column(scale=1, min_width=300):
                gr.Markdown("### Filters")
                # Starts with only the first technology selected.
                tech_filter = gr.CheckboxGroup(choices=tech_choices, value=tech_choices[:1], label="Technology (default: single tech)")
                snr_filter = gr.Dropdown(choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)")
                mod_filter = gr.Dropdown(choices=mod_choices, value=None, multiselect=True, label="Modulation (Empty = All)")
                mob_filter = gr.Dropdown(choices=mob_choices, value=None, multiselect=True, label="Mobility (Empty = All)")

                gr.Markdown("### Visualization Settings")
                representation = gr.Radio(choices=["LWM Embedding", "Raw Spectrogram"], value="LWM Embedding", label="Representation")
                color_by = gr.Dropdown(choices=["tech", "snr", "mod", "mob"], value="snr", label="Color By")

                with gr.Accordion("Advanced t-SNE Settings", open=False):
                    perplexity = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="Perplexity")
                    n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations")

                btn = gr.Button("Update Plot", variant="primary")
                status = gr.Textbox(label="Status", interactive=False)

            # Wide column: the plot itself.
            with gr.Column(scale=3):
                plot = gr.Plot(label="t-SNE Visualization")

        # plot_tsne receives the eight controls in this order and fills the
        # plot and the status box.
        btn.click(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])

        # Initial load
        demo.load(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])

    with gr.Tab("Evaluation (MoE)"):
        gr.Markdown("Uses the latest MoE checkpoint to score the bundled demo set. Communication uses router gating; SNR/Mobility uses the classifier head.")
        task_choice = gr.Radio(choices=["comm", "snr_mobility"], value="snr_mobility", label="Task")
        eval_btn = gr.Button("Run Evaluation", variant="primary")
        cm_plot = gr.Plot(label="Confusion Matrix")
        eval_summary = gr.Textbox(label="Metrics", interactive=False)

        # run_eval returns (figure, summary text) for the selected task.
        eval_btn.click(run_eval, inputs=[task_choice], outputs=[cm_plot, eval_summary])
        # Run once on load for convenience
        demo.load(run_eval, inputs=[task_choice], outputs=[cm_plot, eval_summary])
|
| 309 |
|
| 310 |
if __name__ == "__main__":
|
| 311 |
demo.launch()
|