Spaces:
Running
Running
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -220,9 +220,15 @@ def compute_eval(task: str):
|
|
| 220 |
predictor = load_predictor()
|
| 221 |
y_true, y_pred = [], []
|
| 222 |
|
| 223 |
-
|
|
|
|
| 224 |
spec = _to_tensor(sample["data"])
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
if task == "comm":
|
| 228 |
routing = res.get("routing") or []
|
|
@@ -238,7 +244,7 @@ def compute_eval(task: str):
|
|
| 238 |
cm = confusion_matrix(y_true, y_pred, labels=labels)
|
| 239 |
f1 = f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0)
|
| 240 |
acc = (np.array(y_true) == np.array(y_pred)).mean()
|
| 241 |
-
return cm, labels, f1, acc
|
| 242 |
|
| 243 |
|
| 244 |
def plot_confusion(cm: np.ndarray, labels):
|
|
@@ -257,13 +263,15 @@ def plot_confusion(cm: np.ndarray, labels):
|
|
| 257 |
|
| 258 |
|
| 259 |
def run_eval(task):
|
| 260 |
-
cm, labels, f1, acc = compute_eval(task)
|
| 261 |
fig = plot_confusion(cm, labels)
|
| 262 |
-
summary = f"Task: {task} | Accuracy: {acc:.4f} | Macro F1: {f1:.4f}"
|
| 263 |
return fig, summary
|
| 264 |
|
| 265 |
|
|
|
|
| 266 |
# UI
|
|
|
|
| 267 |
with gr.Blocks(title="LWM-Spectro Demo") as demo:
|
| 268 |
gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
|
| 269 |
gr.Markdown("Compare embeddings vs raw for t-SNE, and view quick metrics from the latest MoE checkpoint.")
|
|
@@ -297,15 +305,21 @@ with gr.Blocks(title="LWM-Spectro Demo") as demo:
|
|
| 297 |
demo.load(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])
|
| 298 |
|
| 299 |
with gr.Tab("Evaluation (MoE)"):
|
| 300 |
-
gr.Markdown("Uses the latest MoE checkpoint to score the bundled demo set
|
| 301 |
task_choice = gr.Radio(choices=["comm", "snr_mobility"], value="snr_mobility", label="Task")
|
| 302 |
eval_btn = gr.Button("Run Evaluation", variant="primary")
|
| 303 |
cm_plot = gr.Plot(label="Confusion Matrix")
|
| 304 |
eval_summary = gr.Textbox(label="Metrics", interactive=False)
|
| 305 |
|
| 306 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
# Run once on load for convenience
|
| 308 |
-
demo.load(
|
| 309 |
|
| 310 |
if __name__ == "__main__":
|
| 311 |
demo.launch()
|
|
|
|
| 220 |
predictor = load_predictor()
|
| 221 |
y_true, y_pred = [], []
|
| 222 |
|
| 223 |
+
max_samples = min(len(raw_samples), 500) # keep eval lightweight in Spaces
|
| 224 |
+
for sample in raw_samples[:max_samples]:
|
| 225 |
spec = _to_tensor(sample["data"])
|
| 226 |
+
try:
|
| 227 |
+
res = predictor.predict(spec, return_routing=True)
|
| 228 |
+
except Exception as exc:
|
| 229 |
+
# Skip problematic samples but keep going
|
| 230 |
+
print(f"[WARN] predict failed: {exc}")
|
| 231 |
+
continue
|
| 232 |
|
| 233 |
if task == "comm":
|
| 234 |
routing = res.get("routing") or []
|
|
|
|
| 244 |
cm = confusion_matrix(y_true, y_pred, labels=labels)
|
| 245 |
f1 = f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0)
|
| 246 |
acc = (np.array(y_true) == np.array(y_pred)).mean()
|
| 247 |
+
return cm, labels, f1, acc, len(y_true)
|
| 248 |
|
| 249 |
|
| 250 |
def plot_confusion(cm: np.ndarray, labels):
|
|
|
|
| 263 |
|
| 264 |
|
| 265 |
def run_eval(task):
    """Evaluate *task* on the bundled demo set.

    Returns a (matplotlib figure, summary string) pair: the confusion-matrix
    plot plus a one-line metrics summary for the Gradio textbox.
    """
    cm, labels, f1, acc, n = compute_eval(task)
    # Build the human-readable metrics line first; plotting is independent.
    summary = f"Task: {task} | Samples: {n} | Accuracy: {acc:.4f} | Macro F1: {f1:.4f}"
    return plot_confusion(cm, labels), summary
|
| 270 |
|
| 271 |
|
| 272 |
+
# ------------------------------------------------------------------------------
|
| 273 |
# UI
|
| 274 |
+
# ------------------------------------------------------------------------------
|
| 275 |
with gr.Blocks(title="LWM-Spectro Demo") as demo:
|
| 276 |
gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
|
| 277 |
gr.Markdown("Compare embeddings vs raw for t-SNE, and view quick metrics from the latest MoE checkpoint.")
|
|
|
|
| 305 |
demo.load(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])
|
| 306 |
|
| 307 |
with gr.Tab("Evaluation (MoE)"):
|
| 308 |
+
gr.Markdown("Uses the latest MoE checkpoint to score the bundled demo set.\n\n- **comm**: predicts communication type (LTE/WiFi/5G) via router gating.\n- **snr_mobility**: predicts the SNR/Mobility class via the classifier head.")
|
| 309 |
task_choice = gr.Radio(choices=["comm", "snr_mobility"], value="snr_mobility", label="Task")
|
| 310 |
eval_btn = gr.Button("Run Evaluation", variant="primary")
|
| 311 |
cm_plot = gr.Plot(label="Confusion Matrix")
|
| 312 |
eval_summary = gr.Textbox(label="Metrics", interactive=False)
|
| 313 |
|
| 314 |
+
def _safe_run(task):
|
| 315 |
+
try:
|
| 316 |
+
return run_eval(task)
|
| 317 |
+
except Exception as exc:
|
| 318 |
+
return None, f"Error during evaluation: {exc}"
|
| 319 |
+
|
| 320 |
+
eval_btn.click(_safe_run, inputs=[task_choice], outputs=[cm_plot, eval_summary])
|
| 321 |
# Run once on load for convenience
|
| 322 |
+
demo.load(_safe_run, inputs=[task_choice], outputs=[cm_plot, eval_summary])
|
| 323 |
|
| 324 |
if __name__ == "__main__":
|
| 325 |
demo.launch()
|