wi-lab committed on
Commit
0205171
·
1 Parent(s): 6979df5

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +22 -8
app.py CHANGED
@@ -220,9 +220,15 @@ def compute_eval(task: str):
220
  predictor = load_predictor()
221
  y_true, y_pred = [], []
222
 
223
- for sample in raw_samples:
 
224
  spec = _to_tensor(sample["data"])
225
- res = predictor.predict(spec, return_routing=True)
 
 
 
 
 
226
 
227
  if task == "comm":
228
  routing = res.get("routing") or []
@@ -238,7 +244,7 @@ def compute_eval(task: str):
238
  cm = confusion_matrix(y_true, y_pred, labels=labels)
239
  f1 = f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0)
240
  acc = (np.array(y_true) == np.array(y_pred)).mean()
241
- return cm, labels, f1, acc
242
 
243
 
244
  def plot_confusion(cm: np.ndarray, labels):
@@ -257,13 +263,15 @@ def plot_confusion(cm: np.ndarray, labels):
257
 
258
 
259
  def run_eval(task):
260
- cm, labels, f1, acc = compute_eval(task)
261
  fig = plot_confusion(cm, labels)
262
- summary = f"Task: {task} | Accuracy: {acc:.4f} | Macro F1: {f1:.4f}"
263
  return fig, summary
264
 
265
 
 
266
  # UI
 
267
  with gr.Blocks(title="LWM-Spectro Demo") as demo:
268
  gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
269
  gr.Markdown("Compare embeddings vs raw for t-SNE, and view quick metrics from the latest MoE checkpoint.")
@@ -297,15 +305,21 @@ with gr.Blocks(title="LWM-Spectro Demo") as demo:
297
  demo.load(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])
298
 
299
  with gr.Tab("Evaluation (MoE)"):
300
- gr.Markdown("Uses the latest MoE checkpoint to score the bundled demo set. Communication uses router gating; SNR/Mobility uses the classifier head.")
301
  task_choice = gr.Radio(choices=["comm", "snr_mobility"], value="snr_mobility", label="Task")
302
  eval_btn = gr.Button("Run Evaluation", variant="primary")
303
  cm_plot = gr.Plot(label="Confusion Matrix")
304
  eval_summary = gr.Textbox(label="Metrics", interactive=False)
305
 
306
- eval_btn.click(run_eval, inputs=[task_choice], outputs=[cm_plot, eval_summary])
 
 
 
 
 
 
307
  # Run once on load for convenience
308
- demo.load(run_eval, inputs=[task_choice], outputs=[cm_plot, eval_summary])
309
 
310
  if __name__ == "__main__":
311
  demo.launch()
 
220
  predictor = load_predictor()
221
  y_true, y_pred = [], []
222
 
223
+ max_samples = min(len(raw_samples), 500) # keep eval lightweight in Spaces
224
+ for sample in raw_samples[:max_samples]:
225
  spec = _to_tensor(sample["data"])
226
+ try:
227
+ res = predictor.predict(spec, return_routing=True)
228
+ except Exception as exc:
229
+ # Skip problematic samples but keep going
230
+ print(f"[WARN] predict failed: {exc}")
231
+ continue
232
 
233
  if task == "comm":
234
  routing = res.get("routing") or []
 
244
  cm = confusion_matrix(y_true, y_pred, labels=labels)
245
  f1 = f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0)
246
  acc = (np.array(y_true) == np.array(y_pred)).mean()
247
+ return cm, labels, f1, acc, len(y_true)
248
 
249
 
250
  def plot_confusion(cm: np.ndarray, labels):
 
263
 
264
 
265
  def run_eval(task):
266
+ cm, labels, f1, acc, n = compute_eval(task)
267
  fig = plot_confusion(cm, labels)
268
+ summary = f"Task: {task} | Samples: {n} | Accuracy: {acc:.4f} | Macro F1: {f1:.4f}"
269
  return fig, summary
270
 
271
 
272
+ # ------------------------------------------------------------------------------
273
  # UI
274
+ # ------------------------------------------------------------------------------
275
  with gr.Blocks(title="LWM-Spectro Demo") as demo:
276
  gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
277
  gr.Markdown("Compare embeddings vs raw for t-SNE, and view quick metrics from the latest MoE checkpoint.")
 
305
  demo.load(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])
306
 
307
  with gr.Tab("Evaluation (MoE)"):
308
+ gr.Markdown("Uses the latest MoE checkpoint to score the bundled demo set.\n\n- **comm**: predicts communication type (LTE/WiFi/5G) via router gating.\n- **snr_mobility**: predicts the SNR/Mobility class via the classifier head.")
309
  task_choice = gr.Radio(choices=["comm", "snr_mobility"], value="snr_mobility", label="Task")
310
  eval_btn = gr.Button("Run Evaluation", variant="primary")
311
  cm_plot = gr.Plot(label="Confusion Matrix")
312
  eval_summary = gr.Textbox(label="Metrics", interactive=False)
313
 
314
+ def _safe_run(task):
315
+ try:
316
+ return run_eval(task)
317
+ except Exception as exc:
318
+ return None, f"Error during evaluation: {exc}"
319
+
320
+ eval_btn.click(_safe_run, inputs=[task_choice], outputs=[cm_plot, eval_summary])
321
  # Run once on load for convenience
322
+ demo.load(_safe_run, inputs=[task_choice], outputs=[cm_plot, eval_summary])
323
 
324
  if __name__ == "__main__":
325
  demo.launch()