# underdog-lab/scripts/evaluate_extractor.py
# Moftah
# Harden dataset evaluation and expose runtime fallback
# 38b11ab
# Raw
# History Blame Contribute Delete
# 6.96 kB
from __future__ import annotations
import argparse
import json
import statistics
import time
from collections import defaultdict
from pathlib import Path
from underdog_lab.domain import MatchRecord
from underdog_lab.scenarios.factory import build_extractor
from underdog_lab.scenarios.schemas import ScenarioExtraction
def factor_keys(extraction: ScenarioExtraction) -> set[tuple[str, str]]:
    """Collect the distinct ``(factor type value, team)`` pairs of an extraction.

    Args:
        extraction: Object whose ``factors`` are iterated; each factor
            supplies ``factor_type.value`` and ``team``.

    Returns:
        A set with one pair per distinct combination, so repeated factors
        with the same type and team collapse into a single entry.
    """
    pairs: set[tuple[str, str]] = set()
    for item in extraction.factors:
        pairs.add((item.factor_type.value, item.team))
    return pairs
def score(path: Path) -> dict:
    """Run the configured extractor over a JSONL test set and summarise accuracy.

    Every non-blank line of *path* is parsed as a JSON object that must
    provide ``id``, ``home_team``, ``away_team``, ``text`` and ``expected``;
    ``review_status`` is optional. The ``expected`` payload is validated as a
    ``ScenarioExtraction`` and compared with what the extractor returns for
    ``text``. Blank or whitespace-only lines (for example a trailing empty
    line) are skipped instead of being handed to the JSON parser.

    Args:
        path: JSONL file holding one labelled example per line.

    Returns:
        A report dict with the extractor name, the example count, micro and
        macro factor precision/recall/F1, per-factor-type F1, team
        attribution accuracy, mean absolute severity error over matched
        factors (``None`` when nothing matched), unsupported-claim and
        ambiguity F1, the exact-match rate, the median extraction latency in
        milliseconds, the ``claim_ready`` flag with its ``warning`` text, and
        the per-example ``details``. ``claim_ready`` is true only when every
        row has ``review_status == "approved"`` (and is therefore also true
        for a file without rows).

    Raises:
        OSError: *path* cannot be opened.
        json.JSONDecodeError: A non-blank line is not valid JSON.
        KeyError: A row lacks one of the required fields.

    Errors raised by ``ScenarioExtraction.model_validate``, ``MatchRecord``
    or the extractor itself propagate unchanged.
    """
    extractor = build_extractor()
    tp = fp = fn = 0
    team_correct = team_total = 0
    unsupported_tp = unsupported_fp = unsupported_fn = 0
    ambiguity_tp = ambiguity_fp = ambiguity_fn = 0
    severity_errors = []
    exact_matches = 0
    per_factor = defaultdict(lambda: {"tp": 0, "fp": 0, "fn": 0})
    latencies = []
    examples = []
    claim_ready = True

    with path.open(encoding="utf-8") as stream:
        for line in stream:
            # A stray empty line must not abort the whole evaluation.
            if not line.strip():
                continue
            row = json.loads(line)
            claim_ready = claim_ready and row.get("review_status") == "approved"
            expected = ScenarioExtraction.model_validate(row["expected"])
            # Only the id and the team names come from the row; every other
            # match field is a fixed placeholder value.
            match = MatchRecord(
                match_id=row["id"],
                kickoff_date="2026-01-01",
                competition="Evaluation",
                stage="Test",
                home_team=row["home_team"],
                away_team=row["away_team"],
                venue="Evaluation venue",
                neutral_venue=True,
                home_goals=0,
                away_goals=0,
                pre_match_home_elo=1800,
                pre_match_away_elo=1800,
                lambda_home=1.18,
                lambda_away=1.18,
                context="Frozen extraction evaluation.",
            )

            # Time only the extractor call, in milliseconds.
            started = time.perf_counter()
            actual = extractor.extract(row["text"], match)
            latencies.append((time.perf_counter() - started) * 1000)

            expected_keys = factor_keys(expected)
            actual_keys = factor_keys(actual)
            matched_keys = expected_keys & actual_keys
            spurious_keys = actual_keys - expected_keys
            missed_keys = expected_keys - actual_keys

            tp += len(matched_keys)
            fp += len(spurious_keys)
            fn += len(missed_keys)
            for factor_type, _team in matched_keys:
                per_factor[factor_type]["tp"] += 1
            for factor_type, _team in spurious_keys:
                per_factor[factor_type]["fp"] += 1
            for factor_type, _team in missed_keys:
                per_factor[factor_type]["fn"] += 1

            # Unsupported claims and ambiguities are scored per example as
            # present/absent flags; the bool results add as 0 or 1.
            expected_unsupported = bool(expected.unsupported_claims)
            actual_unsupported = bool(actual.unsupported_claims)
            unsupported_tp += expected_unsupported and actual_unsupported
            unsupported_fp += not expected_unsupported and actual_unsupported
            unsupported_fn += expected_unsupported and not actual_unsupported

            expected_ambiguous = bool(expected.ambiguities)
            actual_ambiguous = bool(actual.ambiguities)
            ambiguity_tp += expected_ambiguous and actual_ambiguous
            ambiguity_fp += not expected_ambiguous and actual_ambiguous
            ambiguity_fn += expected_ambiguous and not actual_ambiguous

            # When the extractor repeats a (type, team) pair the last
            # occurrence wins here.
            actual_by_key = {
                (factor.factor_type.value, factor.team): factor
                for factor in actual.factors
            }
            for expected_factor in expected.factors:
                key = (expected_factor.factor_type.value, expected_factor.team)
                matched_factor = actual_by_key.get(key)
                team_total += 1
                # NOTE(review): an expected factor counts as correctly
                # attributed only when both type and team match an extracted
                # factor, so this tracks key-level recall closely - confirm
                # that is the intended definition of the metric.
                team_correct += matched_factor is not None
                if matched_factor is not None:
                    severity_errors.append(
                        abs(expected_factor.severity - matched_factor.severity)
                    )

            exact_matches += (
                expected_keys == actual_keys
                and expected_unsupported == actual_unsupported
                and expected_ambiguous == actual_ambiguous
            )
            examples.append(
                {
                    "id": row["id"],
                    "text": row["text"],
                    "expected": expected.model_dump(mode="json"),
                    "actual": actual.model_dump(mode="json"),
                }
            )

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    unsupported_f1 = _binary_f1(unsupported_tp, unsupported_fp, unsupported_fn)
    ambiguity_f1 = _binary_f1(ambiguity_tp, ambiguity_fp, ambiguity_fn)
    factor_f1s = {
        factor: _binary_f1(counts["tp"], counts["fp"], counts["fn"])
        for factor, counts in sorted(per_factor.items())
    }
    return {
        "extractor": extractor.name,
        "examples": len(examples),
        "factor_micro_precision": precision,
        "factor_micro_recall": recall,
        "factor_micro_f1": f1,
        "factor_macro_f1": (
            statistics.mean(factor_f1s.values()) if factor_f1s else 0.0
        ),
        "factor_f1_by_type": factor_f1s,
        "team_attribution_accuracy": team_correct / team_total if team_total else 0.0,
        "severity_mae_on_matched_factors": (
            statistics.mean(severity_errors) if severity_errors else None
        ),
        "unsupported_claim_f1": unsupported_f1,
        "ambiguity_detection_f1": ambiguity_f1,
        "exact_semantic_match_rate": exact_matches / len(examples) if examples else 0.0,
        "median_latency_ms": statistics.median(latencies) if latencies else 0.0,
        "claim_ready": claim_ready,
        "warning": (
            "" if claim_ready else "This test set contains unreviewed synthetic labels."
        ),
        "details": examples,
    }
def _binary_f1(tp: int, fp: int, fn: int) -> float:
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
return 2 * precision * recall / (precision + recall) if precision + recall else 0.0
def main() -> None:
    """Command-line entry point: score a test set and persist the report.

    Accepts ``--test-set`` (JSONL input, default
    ``data/scenarios/test.jsonl``) and ``--output`` (JSON report, default
    ``data/scenarios/evaluation.json``). The full report is written to the
    output path, creating parent directories as needed, and the same report
    without its ``details`` entry is printed to standard output.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--test-set", type=Path, default=Path("data/scenarios/test.jsonl"))
    cli.add_argument("--output", type=Path, default=Path("data/scenarios/evaluation.json"))
    options = cli.parse_args()

    report = score(options.test_set)

    destination = options.output
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(report, indent=2, ensure_ascii=True)
    destination.write_text(serialized + "\n", encoding="utf-8")

    # The console summary leaves out the per-example details.
    summary = dict(report)
    summary.pop("details", None)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()