Spaces:
Running
Running
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -278,17 +278,37 @@ def _normalize_label(val):
|
|
| 278 |
|
| 279 |
|
| 280 |
def compute_eval(task: str):
|
| 281 |
-
"""Compute confusion matrix + macro F1
|
| 282 |
predictor = load_predictor()
|
| 283 |
y_true, y_pred = [], []
|
| 284 |
|
| 285 |
-
|
| 286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
spec = _to_tensor(sample["data"])
|
| 288 |
try:
|
| 289 |
res = predictor.predict(spec, return_routing=True)
|
| 290 |
except Exception as exc:
|
| 291 |
-
# Skip problematic samples but keep going
|
| 292 |
print(f"[WARN] predict failed: {exc}")
|
| 293 |
continue
|
| 294 |
|
|
|
|
| 278 |
|
| 279 |
|
| 280 |
def compute_eval(task: str):
|
| 281 |
+
"""Compute confusion matrix + macro F1 with balanced sampling per class."""
|
| 282 |
predictor = load_predictor()
|
| 283 |
y_true, y_pred = [], []
|
| 284 |
|
| 285 |
+
# Balanced sampling per class
|
| 286 |
+
rng = random.Random(42)
|
| 287 |
+
per_class_target = 100
|
| 288 |
+
|
| 289 |
+
def class_key(sample):
|
| 290 |
+
if task == "comm":
|
| 291 |
+
return _normalize_label(sample["tech"])
|
| 292 |
+
return _normalize_label((sample["snr"], sample["mob"]))
|
| 293 |
+
|
| 294 |
+
buckets = {}
|
| 295 |
+
for s in raw_samples:
|
| 296 |
+
key = class_key(s)
|
| 297 |
+
buckets.setdefault(key, []).append(s)
|
| 298 |
+
|
| 299 |
+
selected = []
|
| 300 |
+
for key, items in buckets.items():
|
| 301 |
+
rng.shuffle(items)
|
| 302 |
+
take = min(per_class_target, len(items))
|
| 303 |
+
selected.extend(items[:take])
|
| 304 |
+
|
| 305 |
+
rng.shuffle(selected)
|
| 306 |
+
|
| 307 |
+
for sample in selected:
|
| 308 |
spec = _to_tensor(sample["data"])
|
| 309 |
try:
|
| 310 |
res = predictor.predict(spec, return_routing=True)
|
| 311 |
except Exception as exc:
|
|
|
|
| 312 |
print(f"[WARN] predict failed: {exc}")
|
| 313 |
continue
|
| 314 |
|