PPP committed on
Commit
e598ece
·
1 Parent(s): ef60390

feat(eval): add fallback statistics and failure summaries

Browse files
Files changed (1) hide show
  1. evaluation/run_evaluations.py +152 -7
evaluation/run_evaluations.py CHANGED
@@ -223,6 +223,30 @@ def _percentile(values: list[float], percentile: float) -> float:
223
  return ordered[index]
224
 
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  def evaluate_intent_accuracy() -> dict[str, Any]:
227
  dataset = _load_dataset("intent_accuracy")
228
  details = []
@@ -353,6 +377,8 @@ def evaluate_latency(repeats: int) -> dict[str, Any]:
353
  all_total = []
354
  fallback_total = 0
355
  total_runs = 0
 
 
356
 
357
  for scenario in dataset:
358
  runs = []
@@ -378,8 +404,25 @@ def evaluate_latency(repeats: int) -> dict[str, Any]:
378
  "engine_mode": telemetry.get("engine_mode"),
379
  }
380
  )
 
381
 
382
  total_values = [item["total_latency_ms"] for item in runs]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  scenario_summaries.append(
384
  {
385
  "id": scenario["id"],
@@ -387,14 +430,19 @@ def evaluate_latency(repeats: int) -> dict[str, Any]:
387
  "repeats": repeats,
388
  "avg_total_latency_ms": round(statistics.mean(total_values), 2),
389
  "p95_total_latency_ms": round(_percentile(total_values, 95), 2),
390
- "fallback_rate": round(
391
- sum(1 for item in runs if item["used_fallback"]) / len(runs),
392
- 4,
 
 
 
 
393
  ),
394
  "runs": runs,
395
  }
396
  )
397
 
 
398
  return {
399
  "task": "latency",
400
  "scenario_count": len(dataset),
@@ -404,6 +452,10 @@ def evaluate_latency(repeats: int) -> dict[str, Any]:
404
  "avg_total_latency_ms": round(statistics.mean(all_total), 2) if all_total else 0.0,
405
  "p95_total_latency_ms": round(_percentile(all_total, 95), 2) if all_total else 0.0,
406
  "fallback_rate": round(fallback_total / total_runs, 4) if total_runs else 0.0,
 
 
 
 
407
  "scenarios": scenario_summaries,
408
  }
409
 
@@ -412,6 +464,8 @@ def evaluate_branch_divergence() -> dict[str, Any]:
412
  dataset = _load_dataset("branch_divergence")
413
  group_summaries = []
414
  pair_scores = []
 
 
415
 
416
  for group in dataset:
417
  branch_results = []
@@ -428,6 +482,15 @@ def evaluate_branch_divergence() -> dict[str, Any]:
428
  "telemetry": run_result["final_result"].get("telemetry", {}),
429
  }
430
  )
 
 
 
 
 
 
 
 
 
431
 
432
  group_pairs = []
433
  for left, right in combinations(branch_results, 2):
@@ -457,13 +520,23 @@ def evaluate_branch_divergence() -> dict[str, Any]:
457
  pair_scores.append(pair_score)
458
  group_pairs.append(pair_detail)
459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  group_summaries.append(
461
  {
462
  "id": group["id"],
463
- "avg_pair_divergence": round(
464
- statistics.mean([pair["pair_divergence_score"] for pair in group_pairs]),
465
- 4,
466
- ) if group_pairs else 0.0,
467
  "branches": [
468
  {
469
  "label": branch["label"],
@@ -478,6 +551,7 @@ def evaluate_branch_divergence() -> dict[str, Any]:
478
  )
479
 
480
  meaningful_pairs = sum(1 for score in pair_scores if score >= 0.2)
 
481
  return {
482
  "task": "branch_divergence",
483
  "group_count": len(dataset),
@@ -486,6 +560,11 @@ def evaluate_branch_divergence() -> dict[str, Any]:
486
  meaningful_pairs / len(pair_scores),
487
  4,
488
  ) if pair_scores else 0.0,
 
 
 
 
 
489
  "groups": group_summaries,
490
  }
491
 
@@ -498,6 +577,69 @@ TASK_RUNNERS = {
498
  }
499
 
500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
  def _build_summary(results: dict[str, Any]) -> dict[str, Any]:
502
  summary = {}
503
  if "intent" in results:
@@ -507,8 +649,10 @@ def _build_summary(results: dict[str, Any]) -> dict[str, Any]:
507
  if "latency" in results:
508
  summary["avg_total_latency_ms"] = results["latency"]["avg_total_latency_ms"]
509
  summary["latency_fallback_rate"] = results["latency"]["fallback_rate"]
 
510
  if "branch" in results:
511
  summary["avg_pair_divergence"] = results["branch"]["avg_pair_divergence"]
 
512
  return summary
513
 
514
 
@@ -541,6 +685,7 @@ def main() -> int:
541
  "generated_at": datetime.now().isoformat(timespec="seconds"),
542
  "task": args.task,
543
  "summary": _build_summary(task_results),
 
544
  "results": task_results,
545
  }
546
 
 
223
  return ordered[index]
224
 
225
 
226
+ def _summarize_fallback_records(records: list[dict[str, Any]]) -> dict[str, Any]:
227
+ fallback_count = 0
228
+ reason_counter = Counter()
229
+ engine_counter = Counter()
230
+
231
+ for record in records:
232
+ if record.get("used_fallback"):
233
+ fallback_count += 1
234
+ reason_counter[str(record.get("fallback_reason") or "unknown")] += 1
235
+ engine_counter[str(record.get("engine_mode") or "unknown")] += 1
236
+
237
+ total = len(records)
238
+ return {
239
+ "fallback_count": fallback_count,
240
+ "fallback_rate": round(fallback_count / total, 4) if total else 0.0,
241
+ "fallback_reason_breakdown": dict(reason_counter),
242
+ "engine_mode_breakdown": dict(engine_counter),
243
+ }
244
+
245
+
246
+ def _limit_cases(cases: list[dict[str, Any]], limit: int = 5) -> list[dict[str, Any]]:
247
+ return cases[:limit]
248
+
249
+
250
  def evaluate_intent_accuracy() -> dict[str, Any]:
251
  dataset = _load_dataset("intent_accuracy")
252
  details = []
 
377
  all_total = []
378
  fallback_total = 0
379
  total_runs = 0
380
+ fallback_records = []
381
+ failure_cases = []
382
 
383
  for scenario in dataset:
384
  runs = []
 
404
  "engine_mode": telemetry.get("engine_mode"),
405
  }
406
  )
407
+ fallback_records.append(runs[-1])
408
 
409
  total_values = [item["total_latency_ms"] for item in runs]
410
+ scenario_fallback_rate = sum(1 for item in runs if item["used_fallback"]) / len(runs)
411
+ if scenario_fallback_rate > 0:
412
+ failure_cases.append(
413
+ {
414
+ "scenario_id": scenario["id"],
415
+ "input": scenario["input"],
416
+ "fallback_rate": round(scenario_fallback_rate, 4),
417
+ "fallback_reasons": dict(
418
+ Counter(
419
+ str(item.get("fallback_reason") or "unknown")
420
+ for item in runs
421
+ if item["used_fallback"]
422
+ )
423
+ ),
424
+ }
425
+ )
426
  scenario_summaries.append(
427
  {
428
  "id": scenario["id"],
 
430
  "repeats": repeats,
431
  "avg_total_latency_ms": round(statistics.mean(total_values), 2),
432
  "p95_total_latency_ms": round(_percentile(total_values, 95), 2),
433
+ "fallback_rate": round(scenario_fallback_rate, 4),
434
+ "fallback_reason_breakdown": dict(
435
+ Counter(
436
+ str(item.get("fallback_reason") or "unknown")
437
+ for item in runs
438
+ if item["used_fallback"]
439
+ )
440
  ),
441
  "runs": runs,
442
  }
443
  )
444
 
445
+ fallback_summary = _summarize_fallback_records(fallback_records)
446
  return {
447
  "task": "latency",
448
  "scenario_count": len(dataset),
 
452
  "avg_total_latency_ms": round(statistics.mean(all_total), 2) if all_total else 0.0,
453
  "p95_total_latency_ms": round(_percentile(all_total, 95), 2) if all_total else 0.0,
454
  "fallback_rate": round(fallback_total / total_runs, 4) if total_runs else 0.0,
455
+ "fallback_count": fallback_summary["fallback_count"],
456
+ "fallback_reason_breakdown": fallback_summary["fallback_reason_breakdown"],
457
+ "engine_mode_breakdown": fallback_summary["engine_mode_breakdown"],
458
+ "failure_cases": _limit_cases(failure_cases),
459
  "scenarios": scenario_summaries,
460
  }
461
 
 
464
  dataset = _load_dataset("branch_divergence")
465
  group_summaries = []
466
  pair_scores = []
467
+ fallback_records = []
468
+ low_divergence_groups = []
469
 
470
  for group in dataset:
471
  branch_results = []
 
482
  "telemetry": run_result["final_result"].get("telemetry", {}),
483
  }
484
  )
485
+ fallback_records.append(
486
+ {
487
+ "used_fallback": bool(
488
+ run_result["final_result"].get("telemetry", {}).get("used_fallback", False)
489
+ ),
490
+ "fallback_reason": run_result["final_result"].get("telemetry", {}).get("fallback_reason"),
491
+ "engine_mode": run_result["final_result"].get("telemetry", {}).get("engine_mode"),
492
+ }
493
+ )
494
 
495
  group_pairs = []
496
  for left, right in combinations(branch_results, 2):
 
520
  pair_scores.append(pair_score)
521
  group_pairs.append(pair_detail)
522
 
523
+ avg_pair_divergence = round(
524
+ statistics.mean([pair["pair_divergence_score"] for pair in group_pairs]),
525
+ 4,
526
+ ) if group_pairs else 0.0
527
+ if avg_pair_divergence < 0.2:
528
+ low_divergence_groups.append(
529
+ {
530
+ "group_id": group["id"],
531
+ "avg_pair_divergence": avg_pair_divergence,
532
+ "branch_labels": [branch["label"] for branch in branch_results],
533
+ }
534
+ )
535
+
536
  group_summaries.append(
537
  {
538
  "id": group["id"],
539
+ "avg_pair_divergence": avg_pair_divergence,
 
 
 
540
  "branches": [
541
  {
542
  "label": branch["label"],
 
551
  )
552
 
553
  meaningful_pairs = sum(1 for score in pair_scores if score >= 0.2)
554
+ fallback_summary = _summarize_fallback_records(fallback_records)
555
  return {
556
  "task": "branch_divergence",
557
  "group_count": len(dataset),
 
560
  meaningful_pairs / len(pair_scores),
561
  4,
562
  ) if pair_scores else 0.0,
563
+ "fallback_count": fallback_summary["fallback_count"],
564
+ "fallback_rate": fallback_summary["fallback_rate"],
565
+ "fallback_reason_breakdown": fallback_summary["fallback_reason_breakdown"],
566
+ "engine_mode_breakdown": fallback_summary["engine_mode_breakdown"],
567
+ "failure_cases": _limit_cases(low_divergence_groups),
568
  "groups": group_summaries,
569
  }
570
 
 
577
  }
578
 
579
 
580
+ def _build_failure_summary(results: dict[str, Any]) -> dict[str, Any]:
581
+ failure_summary: dict[str, Any] = {}
582
+
583
+ if "intent" in results:
584
+ intent_failures = [
585
+ {
586
+ "id": detail["id"],
587
+ "input": detail["input"],
588
+ "expected_intent": detail["expected_intent"],
589
+ "predicted_intent": detail["predicted_intent"],
590
+ "parser_source": detail["parser_source"],
591
+ }
592
+ for detail in results["intent"]["details"]
593
+ if not detail["intent_correct"]
594
+ ]
595
+ failure_summary["intent_failures"] = {
596
+ "count": len(intent_failures),
597
+ "cases": _limit_cases(intent_failures),
598
+ }
599
+
600
+ if "consistency" in results:
601
+ consistency_failures = [
602
+ {
603
+ "id": detail["id"],
604
+ "type": "action_guard",
605
+ "expected_valid": detail["expected_valid"],
606
+ "predicted_valid": detail["predicted_valid"],
607
+ "rejection_reason": detail["rejection_reason"],
608
+ }
609
+ for detail in results["consistency"]["action_guard_details"]
610
+ if not detail["correct"]
611
+ ]
612
+ consistency_failures.extend(
613
+ {
614
+ "id": detail["id"],
615
+ "type": "state_check",
616
+ "expected_contradiction": detail["expected_contradiction"],
617
+ "predicted_contradiction": detail["predicted_contradiction"],
618
+ "contradictions": detail["contradictions"],
619
+ }
620
+ for detail in results["consistency"]["state_check_details"]
621
+ if not detail["correct"]
622
+ )
623
+ failure_summary["consistency_failures"] = {
624
+ "count": len(consistency_failures),
625
+ "cases": _limit_cases(consistency_failures),
626
+ }
627
+
628
+ if "latency" in results:
629
+ failure_summary["latency_failures"] = {
630
+ "count": len(results["latency"].get("failure_cases", [])),
631
+ "cases": _limit_cases(results["latency"].get("failure_cases", [])),
632
+ }
633
+
634
+ if "branch" in results:
635
+ failure_summary["branch_failures"] = {
636
+ "count": len(results["branch"].get("failure_cases", [])),
637
+ "cases": _limit_cases(results["branch"].get("failure_cases", [])),
638
+ }
639
+
640
+ return failure_summary
641
+
642
+
643
  def _build_summary(results: dict[str, Any]) -> dict[str, Any]:
644
  summary = {}
645
  if "intent" in results:
 
649
  if "latency" in results:
650
  summary["avg_total_latency_ms"] = results["latency"]["avg_total_latency_ms"]
651
  summary["latency_fallback_rate"] = results["latency"]["fallback_rate"]
652
+ summary["latency_fallback_count"] = results["latency"]["fallback_count"]
653
  if "branch" in results:
654
  summary["avg_pair_divergence"] = results["branch"]["avg_pair_divergence"]
655
+ summary["branch_fallback_rate"] = results["branch"]["fallback_rate"]
656
  return summary
657
 
658
 
 
685
  "generated_at": datetime.now().isoformat(timespec="seconds"),
686
  "task": args.task,
687
  "summary": _build_summary(task_results),
688
+ "failure_summary": _build_failure_summary(task_results),
689
  "results": task_results,
690
  }
691