wi-lab commited on
Commit
329b49f
·
1 Parent(s): b12ae03

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +147 -35
app.py CHANGED
@@ -1,6 +1,9 @@
1
 
2
- import huggingface_hub as hf_hub
3
  import inspect
 
 
 
 
4
 
5
  # Gradio <-> hub compatibility shim: newer huggingface_hub removed HfFolder.
6
  if not hasattr(hf_hub, "HfFolder"):
@@ -22,8 +25,20 @@ import pandas as pd
22
  from sklearn.manifold import TSNE
23
  from sklearn.decomposition import PCA
24
  from sklearn.preprocessing import StandardScaler
 
25
  import matplotlib.pyplot as plt
26
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Load data
28
  def load_data():
29
  print("Loading data...")
@@ -41,9 +56,9 @@ def load_data():
41
  })
42
  df = pd.DataFrame(records)
43
  print(f"Loaded {len(df)} samples.")
44
- return df
45
 
46
- df = load_data()
47
 
48
  # Get unique values for filters
49
  tech_choices = sorted(list(df['tech'].unique()))
@@ -160,40 +175,137 @@ def plot_tsne(tech_filter, snr_filter, mod_filter, mob_filter, representation, c
160
  trace_info = f"traces: {len(filtered_df[color_by].unique())}"
161
  return fig, f"{status_msg} | filtered samples: {len(filtered_df)} | {coord_info} | {trace_info}"
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  # UI
164
  with gr.Blocks(title="LWM-Spectro Demo") as demo:
165
- gr.Markdown("# 🔬 LWM-Spectro Interactive t-SNE Demo")
166
- gr.Markdown("""
167
- Visualise t-SNE just like the local `task1/plot_tsne.py` script:
168
- standardised inputs, per-sample normalisation, and SNR/mod/tech/mob colour options.
169
- """)
170
-
171
- with gr.Row():
172
- with gr.Column(scale=1, min_width=300):
173
- gr.Markdown("### Filters")
174
- tech_filter = gr.CheckboxGroup(choices=tech_choices, value=tech_choices[:1], label="Technology (default: single tech)")
175
- snr_filter = gr.Dropdown(choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)")
176
- mod_filter = gr.Dropdown(choices=mod_choices, value=None, multiselect=True, label="Modulation (Empty = All)")
177
- mob_filter = gr.Dropdown(choices=mob_choices, value=None, multiselect=True, label="Mobility (Empty = All)")
178
-
179
- gr.Markdown("### Visualization Settings")
180
- representation = gr.Radio(choices=["LWM Embedding", "Raw Spectrogram"], value="LWM Embedding", label="Representation")
181
- color_by = gr.Dropdown(choices=["tech", "snr", "mod", "mob"], value="snr", label="Color By")
182
-
183
- with gr.Accordion("Advanced t-SNE Settings", open=False):
184
- perplexity = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="Perplexity")
185
- n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations")
186
-
187
- btn = gr.Button("Update Plot", variant="primary")
188
- status = gr.Textbox(label="Status", interactive=False)
189
-
190
- with gr.Column(scale=3):
191
- plot = gr.Plot(label="t-SNE Visualization")
192
-
193
- btn.click(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])
194
-
195
- # Initial load
196
- demo.load(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])
 
 
 
 
 
 
 
 
 
197
 
198
  if __name__ == "__main__":
199
  demo.launch()
 
1
 
 
2
  import inspect
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ import huggingface_hub as hf_hub
7
 
8
  # Gradio <-> hub compatibility shim: newer huggingface_hub removed HfFolder.
9
  if not hasattr(hf_hub, "HfFolder"):
 
25
  from sklearn.manifold import TSNE
26
  from sklearn.decomposition import PCA
27
  from sklearn.preprocessing import StandardScaler
28
+ from sklearn.metrics import confusion_matrix, f1_score
29
  import matplotlib.pyplot as plt
30
 
31
+ # Repo root for local imports
32
+ REPO_ROOT = Path(__file__).resolve().parent.parent
33
+ if str(REPO_ROOT) not in sys.path:
34
+ sys.path.append(str(REPO_ROOT))
35
+
36
+ from mixture.train_embedding_router import MoEPredictor # type: ignore
37
+
38
+ # ------------------------------------------------------------------------------
39
+ # Data loading (t-SNE + evaluation)
40
+ # ------------------------------------------------------------------------------
41
+
42
  # Load data
43
  def load_data():
44
  print("Loading data...")
 
56
  })
57
  df = pd.DataFrame(records)
58
  print(f"Loaded {len(df)} samples.")
59
+ return df, data
60
 
61
+ df, raw_samples = load_data()
62
 
63
  # Get unique values for filters
64
  tech_choices = sorted(list(df['tech'].unique()))
 
175
  trace_info = f"traces: {len(filtered_df[color_by].unique())}"
176
  return fig, f"{status_msg} | filtered samples: {len(filtered_df)} | {coord_info} | {trace_info}"
177
 
178
+ # ------------------------------------------------------------------------------
179
+ # Evaluation utilities (confusion matrix, F1) using the MoE checkpoint
180
+ # ------------------------------------------------------------------------------
181
+
182
+ _predictor: MoEPredictor | None = None
183
+
184
+
185
+ def load_predictor() -> MoEPredictor:
186
+ global _predictor
187
+ if _predictor is not None:
188
+ return _predictor
189
+
190
+ # Prefer local checkpoint if present; otherwise pull from Hub
191
+ candidates = [
192
+ REPO_ROOT / "mixture" / "runs" / "embedding_router" / "moe_checkpoint.pth",
193
+ REPO_ROOT / "moe_checkpoint.pth",
194
+ ]
195
+ ckpt_path = None
196
+ for cand in candidates:
197
+ if cand.exists():
198
+ ckpt_path = cand
199
+ break
200
+ if ckpt_path is None:
201
+ ckpt_path = Path(
202
+ hf_hub.hf_hub_download(repo_id="wi-lab/lwm-spectro", filename="moe_checkpoint.pth")
203
+ )
204
+
205
+ _predictor = MoEPredictor.from_checkpoint(ckpt_path)
206
+ return _predictor
207
+
208
+
209
+ def _to_tensor(spec) -> torch.Tensor:
210
+ t = spec
211
+ if not isinstance(t, torch.Tensor):
212
+ t = torch.as_tensor(t)
213
+ if t.dim() == 2:
214
+ t = t.unsqueeze(0)
215
+ return t
216
+
217
+
218
+ def compute_eval(task: str):
219
+ """Compute confusion matrix + macro F1 for the small demo set."""
220
+ predictor = load_predictor()
221
+ y_true, y_pred = [], []
222
+
223
+ for sample in raw_samples:
224
+ spec = _to_tensor(sample["data"])
225
+ res = predictor.predict(spec, return_routing=True)
226
+
227
+ if task == "comm":
228
+ routing = res.get("routing") or []
229
+ pred = routing[0]["comm"] if routing else "Unknown"
230
+ true = sample["tech"]
231
+ else: # snr_mobility
232
+ pred = res.get("label", res["predicted_class"])
233
+ true = (sample["snr"], sample["mob"])
234
+ y_true.append(true)
235
+ y_pred.append(pred)
236
+
237
+ labels = sorted(list({*y_true, *y_pred}))
238
+ cm = confusion_matrix(y_true, y_pred, labels=labels)
239
+ f1 = f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0)
240
+ acc = (np.array(y_true) == np.array(y_pred)).mean()
241
+ return cm, labels, f1, acc
242
+
243
+
244
+ def plot_confusion(cm: np.ndarray, labels):
245
+ fig, ax = plt.subplots(figsize=(6, 5))
246
+ im = ax.imshow(cm, cmap="Blues")
247
+ ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
248
+ ax.set_xticks(np.arange(len(labels)), labels=labels, rotation=45, ha="right")
249
+ ax.set_yticks(np.arange(len(labels)), labels=labels)
250
+ ax.set_xlabel("Predicted")
251
+ ax.set_ylabel("True")
252
+ for i in range(cm.shape[0]):
253
+ for j in range(cm.shape[1]):
254
+ ax.text(j, i, int(cm[i, j]), ha="center", va="center", color="black")
255
+ fig.tight_layout()
256
+ return fig
257
+
258
+
259
+ def run_eval(task):
260
+ cm, labels, f1, acc = compute_eval(task)
261
+ fig = plot_confusion(cm, labels)
262
+ summary = f"Task: {task} | Accuracy: {acc:.4f} | Macro F1: {f1:.4f}"
263
+ return fig, summary
264
+
265
+
266
  # UI
267
  with gr.Blocks(title="LWM-Spectro Demo") as demo:
268
+ gr.Markdown("# 🔬 LWM-Spectro Interactive Demo")
269
+ gr.Markdown("Compare embeddings vs raw for t-SNE, and view quick metrics from the latest MoE checkpoint.")
270
+
271
+ with gr.Tab("t-SNE"):
272
+ with gr.Row():
273
+ with gr.Column(scale=1, min_width=300):
274
+ gr.Markdown("### Filters")
275
+ tech_filter = gr.CheckboxGroup(choices=tech_choices, value=tech_choices[:1], label="Technology (default: single tech)")
276
+ snr_filter = gr.Dropdown(choices=snr_choices, value=None, multiselect=True, label="SNR (Empty = All)")
277
+ mod_filter = gr.Dropdown(choices=mod_choices, value=None, multiselect=True, label="Modulation (Empty = All)")
278
+ mob_filter = gr.Dropdown(choices=mob_choices, value=None, multiselect=True, label="Mobility (Empty = All)")
279
+
280
+ gr.Markdown("### Visualization Settings")
281
+ representation = gr.Radio(choices=["LWM Embedding", "Raw Spectrogram"], value="LWM Embedding", label="Representation")
282
+ color_by = gr.Dropdown(choices=["tech", "snr", "mod", "mob"], value="snr", label="Color By")
283
+
284
+ with gr.Accordion("Advanced t-SNE Settings", open=False):
285
+ perplexity = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="Perplexity")
286
+ n_iter = gr.Slider(minimum=250, maximum=2000, value=1000, step=50, label="Iterations")
287
+
288
+ btn = gr.Button("Update Plot", variant="primary")
289
+ status = gr.Textbox(label="Status", interactive=False)
290
+
291
+ with gr.Column(scale=3):
292
+ plot = gr.Plot(label="t-SNE Visualization")
293
+
294
+ btn.click(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])
295
+
296
+ # Initial load
297
+ demo.load(plot_tsne, inputs=[tech_filter, snr_filter, mod_filter, mob_filter, representation, color_by, perplexity, n_iter], outputs=[plot, status])
298
+
299
+ with gr.Tab("Evaluation (MoE)"):
300
+ gr.Markdown("Uses the latest MoE checkpoint to score the bundled demo set. Communication uses router gating; SNR/Mobility uses the classifier head.")
301
+ task_choice = gr.Radio(choices=["comm", "snr_mobility"], value="snr_mobility", label="Task")
302
+ eval_btn = gr.Button("Run Evaluation", variant="primary")
303
+ cm_plot = gr.Plot(label="Confusion Matrix")
304
+ eval_summary = gr.Textbox(label="Metrics", interactive=False)
305
+
306
+ eval_btn.click(run_eval, inputs=[task_choice], outputs=[cm_plot, eval_summary])
307
+ # Run once on load for convenience
308
+ demo.load(run_eval, inputs=[task_choice], outputs=[cm_plot, eval_summary])
309
 
310
  if __name__ == "__main__":
311
  demo.launch()