Spaces:
Running
Running
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -270,12 +270,19 @@ def _to_tensor(spec) -> torch.Tensor:
|
|
| 270 |
return t
|
| 271 |
|
| 272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
def compute_eval(task: str):
|
| 274 |
"""Compute confusion matrix + macro F1 for the small demo set."""
|
| 275 |
predictor = load_predictor()
|
| 276 |
y_true, y_pred = [], []
|
| 277 |
|
| 278 |
-
max_samples = min(len(raw_samples),
|
| 279 |
for sample in raw_samples[:max_samples]:
|
| 280 |
spec = _to_tensor(sample["data"])
|
| 281 |
try:
|
|
@@ -287,11 +294,12 @@ def compute_eval(task: str):
|
|
| 287 |
|
| 288 |
if task == "comm":
|
| 289 |
routing = res.get("routing") or []
|
| 290 |
-
pred = routing[0]["comm"] if routing else "Unknown"
|
| 291 |
-
true = sample["tech"]
|
| 292 |
else: # snr_mobility
|
| 293 |
-
|
| 294 |
-
|
|
|
|
| 295 |
y_true.append(true)
|
| 296 |
y_pred.append(pred)
|
| 297 |
|
|
|
|
| 270 |
return t
|
| 271 |
|
| 272 |
|
| 273 |
+
def _normalize_label(val):
|
| 274 |
+
"""Convert labels to hashable, comparable form."""
|
| 275 |
+
if isinstance(val, (list, tuple)):
|
| 276 |
+
return tuple(str(v) for v in val)
|
| 277 |
+
return str(val)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
def compute_eval(task: str):
|
| 281 |
"""Compute confusion matrix + macro F1 for the small demo set."""
|
| 282 |
predictor = load_predictor()
|
| 283 |
y_true, y_pred = [], []
|
| 284 |
|
| 285 |
+
max_samples = min(len(raw_samples), 200) # keep eval lightweight in Spaces
|
| 286 |
for sample in raw_samples[:max_samples]:
|
| 287 |
spec = _to_tensor(sample["data"])
|
| 288 |
try:
|
|
|
|
| 294 |
|
| 295 |
if task == "comm":
|
| 296 |
routing = res.get("routing") or []
|
| 297 |
+
pred = _normalize_label(routing[0]["comm"]) if routing else "Unknown"
|
| 298 |
+
true = _normalize_label(sample["tech"])
|
| 299 |
else: # snr_mobility
|
| 300 |
+
pred_raw = res.get("label", res["predicted_class"])
|
| 301 |
+
pred = _normalize_label(pred_raw)
|
| 302 |
+
true = _normalize_label((sample["snr"], sample["mob"]))
|
| 303 |
y_true.append(true)
|
| 304 |
y_pred.append(pred)
|
| 305 |
|