wi-lab commited on
Commit
d9e7eb5
·
1 Parent(s): be20f10

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +24 -4
app.py CHANGED
@@ -278,17 +278,37 @@ def _normalize_label(val):
278
 
279
 
280
  def compute_eval(task: str):
281
- """Compute confusion matrix + macro F1 for the small demo set."""
282
  predictor = load_predictor()
283
  y_true, y_pred = [], []
284
 
285
- max_samples = min(len(raw_samples), 200) # keep eval lightweight in Spaces
286
- for sample in raw_samples[:max_samples]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  spec = _to_tensor(sample["data"])
288
  try:
289
  res = predictor.predict(spec, return_routing=True)
290
  except Exception as exc:
291
- # Skip problematic samples but keep going
292
  print(f"[WARN] predict failed: {exc}")
293
  continue
294
 
 
278
 
279
 
280
  def compute_eval(task: str):
281
+ """Compute confusion matrix + macro F1 with balanced sampling per class."""
282
  predictor = load_predictor()
283
  y_true, y_pred = [], []
284
 
285
+ # Balanced sampling per class
286
+ rng = random.Random(42)
287
+ per_class_target = 100
288
+
289
+ def class_key(sample):
290
+ if task == "comm":
291
+ return _normalize_label(sample["tech"])
292
+ return _normalize_label((sample["snr"], sample["mob"]))
293
+
294
+ buckets = {}
295
+ for s in raw_samples:
296
+ key = class_key(s)
297
+ buckets.setdefault(key, []).append(s)
298
+
299
+ selected = []
300
+ for key, items in buckets.items():
301
+ rng.shuffle(items)
302
+ take = min(per_class_target, len(items))
303
+ selected.extend(items[:take])
304
+
305
+ rng.shuffle(selected)
306
+
307
+ for sample in selected:
308
  spec = _to_tensor(sample["data"])
309
  try:
310
  res = predictor.predict(spec, return_routing=True)
311
  except Exception as exc:
 
312
  print(f"[WARN] predict failed: {exc}")
313
  continue
314