Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import statistics | |
| import time | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from underdog_lab.domain import MatchRecord | |
| from underdog_lab.scenarios.factory import build_extractor | |
| from underdog_lab.scenarios.schemas import ScenarioExtraction | |
def factor_keys(extraction: ScenarioExtraction) -> set[tuple[str, str]]:
    """Return the distinct ``(factor type value, team)`` pairs of an extraction.

    Every entry of ``extraction.factors`` contributes the pair
    ``(entry.factor_type.value, entry.team)``; repeated pairs collapse
    because the result is a set.
    """
    pairs: set[tuple[str, str]] = set()
    for entry in extraction.factors:
        pairs.add((entry.factor_type.value, entry.team))
    return pairs
def score(path: Path) -> dict:
    """Evaluate the configured extractor against a JSONL test set.

    Each non-blank line of *path* is parsed as a JSON object. The object
    must provide ``id``, ``text``, ``home_team``, ``away_team`` and
    ``expected`` (validated as a ``ScenarioExtraction``); ``review_status``
    is optional. The extractor returned by ``build_extractor()`` is run on
    ``text`` and its output is compared with ``expected``.

    Args:
        path: Location of the UTF-8 encoded JSONL test set.

    Returns:
        A dict holding the extractor name, the number of examples, micro and
        macro factor precision/recall/F1, per-factor-type F1, team
        attribution accuracy, mean absolute severity error on matched
        factors (``None`` when nothing matched), F1 for the presence of
        unsupported claims and of ambiguities, the exact-match rate, the
        median extraction latency in milliseconds, the ``claim_ready`` flag
        with its ``warning`` text, and the per-example ``details``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a row lacks one of the required keys.
    """
    extractor = build_extractor()
    tp = fp = fn = 0
    team_correct = team_total = 0
    unsupported_tp = unsupported_fp = unsupported_fn = 0
    ambiguity_tp = ambiguity_fp = ambiguity_fn = 0
    severity_errors = []
    exact_matches = 0
    per_factor = defaultdict(lambda: {"tp": 0, "fp": 0, "fn": 0})
    latencies = []
    examples = []
    claim_ready = True

    with path.open(encoding="utf-8") as stream:
        for line in stream:
            # Blank lines (e.g. a trailing empty line) are not records;
            # json.loads would raise on them, so skip them.
            if not line.strip():
                continue
            row = json.loads(line)

            # A single row that is not explicitly approved clears the flag
            # for the whole report.
            claim_ready = claim_ready and row.get("review_status") == "approved"
            expected = ScenarioExtraction.model_validate(row["expected"])

            # Only the id and the team names come from the row; every other
            # field is a fixed placeholder.
            match = MatchRecord(
                match_id=row["id"],
                kickoff_date="2026-01-01",
                competition="Evaluation",
                stage="Test",
                home_team=row["home_team"],
                away_team=row["away_team"],
                venue="Evaluation venue",
                neutral_venue=True,
                home_goals=0,
                away_goals=0,
                pre_match_home_elo=1800,
                pre_match_away_elo=1800,
                lambda_home=1.18,
                lambda_away=1.18,
                context="Frozen extraction evaluation.",
            )

            # Time the extraction call only, in milliseconds.
            started = time.perf_counter()
            actual = extractor.extract(row["text"], match)
            latencies.append((time.perf_counter() - started) * 1000)

            expected_keys = factor_keys(expected)
            actual_keys = factor_keys(actual)
            matched_keys = expected_keys & actual_keys
            spurious_keys = actual_keys - expected_keys
            missed_keys = expected_keys - actual_keys

            tp += len(matched_keys)
            fp += len(spurious_keys)
            fn += len(missed_keys)
            for factor_type, _ in matched_keys:
                per_factor[factor_type]["tp"] += 1
            for factor_type, _ in spurious_keys:
                per_factor[factor_type]["fp"] += 1
            for factor_type, _ in missed_keys:
                per_factor[factor_type]["fn"] += 1

            # Unsupported claims and ambiguities are scored per example on
            # presence only (non-empty versus empty), not on their content.
            expected_unsupported = bool(expected.unsupported_claims)
            actual_unsupported = bool(actual.unsupported_claims)
            unsupported_tp += expected_unsupported and actual_unsupported
            unsupported_fp += not expected_unsupported and actual_unsupported
            unsupported_fn += expected_unsupported and not actual_unsupported

            expected_ambiguous = bool(expected.ambiguities)
            actual_ambiguous = bool(actual.ambiguities)
            ambiguity_tp += expected_ambiguous and actual_ambiguous
            ambiguity_fp += not expected_ambiguous and actual_ambiguous
            ambiguity_fn += expected_ambiguous and not actual_ambiguous

            # When several extracted factors share a key, the last one wins.
            actual_by_key = {
                (factor.factor_type.value, factor.team): factor
                for factor in actual.factors
            }
            for expected_factor in expected.factors:
                team_total += 1
                key = (expected_factor.factor_type.value, expected_factor.team)
                actual_factor = actual_by_key.get(key)
                if actual_factor is None:
                    continue
                # An extracted factor has the same type and team: count the
                # attribution as correct and record the severity gap.
                team_correct += 1
                severity_errors.append(
                    abs(expected_factor.severity - actual_factor.severity)
                )

            exact_matches += (
                expected_keys == actual_keys
                and expected_unsupported == actual_unsupported
                and expected_ambiguous == actual_ambiguous
            )
            examples.append(
                {
                    "id": row["id"],
                    "text": row["text"],
                    "expected": expected.model_dump(mode="json"),
                    "actual": actual.model_dump(mode="json"),
                }
            )

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = _binary_f1(tp, fp, fn)
    unsupported_f1 = _binary_f1(unsupported_tp, unsupported_fp, unsupported_fn)
    ambiguity_f1 = _binary_f1(ambiguity_tp, ambiguity_fp, ambiguity_fn)
    factor_f1s = {
        factor: _binary_f1(counts["tp"], counts["fp"], counts["fn"])
        for factor, counts in sorted(per_factor.items())
    }

    return {
        "extractor": extractor.name,
        "examples": len(examples),
        "factor_micro_precision": precision,
        "factor_micro_recall": recall,
        "factor_micro_f1": f1,
        "factor_macro_f1": (
            statistics.mean(factor_f1s.values()) if factor_f1s else 0.0
        ),
        "factor_f1_by_type": factor_f1s,
        "team_attribution_accuracy": team_correct / team_total if team_total else 0.0,
        "severity_mae_on_matched_factors": (
            statistics.mean(severity_errors) if severity_errors else None
        ),
        "unsupported_claim_f1": unsupported_f1,
        "ambiguity_detection_f1": ambiguity_f1,
        "exact_semantic_match_rate": exact_matches / len(examples) if examples else 0.0,
        "median_latency_ms": statistics.median(latencies) if latencies else 0.0,
        "claim_ready": claim_ready,
        "warning": (
            "" if claim_ready else "This test set contains unreviewed synthetic labels."
        ),
        "details": examples,
    }
def _binary_f1(tp: int, fp: int, fn: int) -> float:
    """Return the F1 score for true-positive, false-positive and false-negative counts.

    Precision, recall and the final score each fall back to ``0.0`` when
    their denominator is zero.
    """
    predicted = tp + fp
    relevant = tp + fn
    precision = tp / predicted if predicted else 0.0
    recall = tp / relevant if relevant else 0.0
    denominator = precision + recall
    if not denominator:
        return 0.0
    return 2 * precision * recall / denominator
def main() -> None:
    """Command-line entry point: score a test set and save the JSON report.

    ``--test-set`` (default ``data/scenarios/test.jsonl``) selects the input
    and ``--output`` (default ``data/scenarios/evaluation.json``) the report
    file. The full report is written to the output file, creating its
    parent directories if needed; everything except the ``details`` entry is
    also printed to standard output.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--test-set",
        type=Path,
        default=Path("data/scenarios/test.jsonl"),
    )
    cli.add_argument(
        "--output",
        type=Path,
        default=Path("data/scenarios/evaluation.json"),
    )
    options = cli.parse_args()

    report = score(options.test_set)

    destination = options.output
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(report, indent=2, ensure_ascii=True)
    destination.write_text(serialized + "\n", encoding="utf-8")

    # The console summary leaves out the per-example details.
    summary = dict(report)
    summary.pop("details", None)
    print(json.dumps(summary, indent=2))
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()