Spaces:
Running
Running
File size: 6,963 Bytes
18d5764 38b11ab 18d5764 38b11ab 18d5764 38b11ab 18d5764 38b11ab 18d5764 38b11ab 18d5764 38b11ab 18d5764 38b11ab 18d5764 38b11ab 18d5764 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | from __future__ import annotations
import argparse
import json
import statistics
import time
from collections import defaultdict
from pathlib import Path
from underdog_lab.domain import MatchRecord
from underdog_lab.scenarios.factory import build_extractor
from underdog_lab.scenarios.schemas import ScenarioExtraction
def factor_keys(extraction: ScenarioExtraction) -> set[tuple[str, str]]:
    """Return the distinct ``(factor_type.value, team)`` pairs of ``extraction.factors``.

    Duplicate factors collapse into a single pair.
    """
    keys: set[tuple[str, str]] = set()
    for item in extraction.factors:
        keys.add((item.factor_type.value, item.team))
    return keys
def _evaluation_match(row: dict) -> MatchRecord:
    """Build the placeholder ``MatchRecord`` handed to the extractor for one test row.

    Only ``id``, ``home_team`` and ``away_team`` are taken from the row; every other
    field is a fixed constant shared by all rows of the evaluation.
    """
    return MatchRecord(
        match_id=row["id"],
        kickoff_date="2026-01-01",
        competition="Evaluation",
        stage="Test",
        home_team=row["home_team"],
        away_team=row["away_team"],
        venue="Evaluation venue",
        neutral_venue=True,
        home_goals=0,
        away_goals=0,
        pre_match_home_elo=1800,
        pre_match_away_elo=1800,
        lambda_home=1.18,
        lambda_away=1.18,
        context="Frozen extraction evaluation.",
    )


def score(path: Path) -> dict:
    """Run the configured extractor over a JSONL test set and compute evaluation metrics.

    Each non-blank line of ``path`` must be a JSON object with ``id``, ``text``,
    ``home_team``, ``away_team`` and ``expected`` (validated as a
    ``ScenarioExtraction``), plus an optional ``review_status``. A factor is
    identified by its ``(factor_type, team)`` pair (see ``factor_keys``).

    Args:
        path: Path to the UTF-8 JSONL test set.

    Returns:
        A JSON-serialisable dict containing:

        - factor micro precision, recall and F1
        - macro and per-type factor F1
        - team attribution accuracy
        - severity MAE over matched factors
        - unsupported-claim and ambiguity detection F1
        - exact semantic match rate
        - median extraction latency in milliseconds
        - ``claim_ready`` and ``warning``
        - per-example ``details``

        ``claim_ready`` is True only when the set is non-empty and every row has
        ``review_status == "approved"``.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a row lacks one of the required fields.
    """
    extractor = build_extractor()
    tp = fp = fn = 0
    team_correct = team_total = 0
    unsupported_tp = unsupported_fp = unsupported_fn = 0
    ambiguity_tp = ambiguity_fp = ambiguity_fn = 0
    severity_errors: list[float] = []
    exact_matches = 0
    per_factor: defaultdict[str, dict[str, int]] = defaultdict(
        lambda: {"tp": 0, "fp": 0, "fn": 0}
    )
    latencies: list[float] = []
    examples: list[dict] = []
    all_approved = True
    with path.open(encoding="utf-8") as stream:
        for line in stream:
            # A blank line (typically a trailing newline at EOF) is not a record;
            # json.loads("") would raise JSONDecodeError.
            if not line.strip():
                continue
            row = json.loads(line)
            all_approved = all_approved and row.get("review_status") == "approved"
            expected = ScenarioExtraction.model_validate(row["expected"])
            match = _evaluation_match(row)
            started = time.perf_counter()
            actual = extractor.extract(row["text"], match)
            latencies.append((time.perf_counter() - started) * 1000)

            # Factor detection over distinct (type, team) keys.
            expected_keys = factor_keys(expected)
            actual_keys = factor_keys(actual)
            matched = expected_keys & actual_keys
            spurious = actual_keys - expected_keys
            missed = expected_keys - actual_keys
            tp += len(matched)
            fp += len(spurious)
            fn += len(missed)
            for factor_type, _team in matched:
                per_factor[factor_type]["tp"] += 1
            for factor_type, _team in spurious:
                per_factor[factor_type]["fp"] += 1
            for factor_type, _team in missed:
                per_factor[factor_type]["fn"] += 1

            # Row-level binary detection: does the row have any unsupported claim /
            # any ambiguity at all?
            expected_unsupported = bool(expected.unsupported_claims)
            actual_unsupported = bool(actual.unsupported_claims)
            unsupported_tp += expected_unsupported and actual_unsupported
            unsupported_fp += not expected_unsupported and actual_unsupported
            unsupported_fn += expected_unsupported and not actual_unsupported
            expected_ambiguous = bool(expected.ambiguities)
            actual_ambiguous = bool(actual.ambiguities)
            ambiguity_tp += expected_ambiguous and actual_ambiguous
            ambiguity_fp += not expected_ambiguous and actual_ambiguous
            ambiguity_fn += expected_ambiguous and not actual_ambiguous

            # If the extractor emits the same key twice, the last one wins here.
            actual_by_key = {
                (factor.factor_type.value, factor.team): factor
                for factor in actual.factors
            }
            actual_teams_by_type: defaultdict[str, set] = defaultdict(set)
            for factor in actual.factors:
                actual_teams_by_type[factor.factor_type.value].add(factor.team)
            for expected_factor in expected.factors:
                factor_type = expected_factor.factor_type.value
                # Team attribution is scored only when the extractor found this factor
                # type at all. A missed type is already a recall error, and counting it
                # here would make this metric a duplicate of factor recall.
                detected_teams = actual_teams_by_type.get(factor_type)
                if detected_teams:
                    team_total += 1
                    team_correct += expected_factor.team in detected_teams
                key = (factor_type, expected_factor.team)
                if key in actual_by_key:
                    severity_errors.append(
                        abs(expected_factor.severity - actual_by_key[key].severity)
                    )

            exact_matches += (
                expected_keys == actual_keys
                and expected_unsupported == actual_unsupported
                and expected_ambiguous == actual_ambiguous
            )
            examples.append(
                {
                    "id": row["id"],
                    "text": row["text"],
                    "expected": expected.model_dump(mode="json"),
                    "actual": actual.model_dump(mode="json"),
                }
            )

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    factor_f1s = {
        factor: _binary_f1(counts["tp"], counts["fp"], counts["fn"])
        for factor, counts in sorted(per_factor.items())
    }
    # An empty test set cannot back any claim, so it is never claim-ready.
    claim_ready = bool(examples) and all_approved
    if claim_ready:
        warning = ""
    elif not examples:
        warning = "This test set contains no examples."
    else:
        warning = "This test set contains unreviewed synthetic labels."
    return {
        "extractor": extractor.name,
        "examples": len(examples),
        "factor_micro_precision": precision,
        "factor_micro_recall": recall,
        "factor_micro_f1": _binary_f1(tp, fp, fn),
        "factor_macro_f1": (
            statistics.mean(factor_f1s.values()) if factor_f1s else 0.0
        ),
        "factor_f1_by_type": factor_f1s,
        "team_attribution_accuracy": team_correct / team_total if team_total else 0.0,
        "severity_mae_on_matched_factors": (
            statistics.mean(severity_errors) if severity_errors else None
        ),
        "unsupported_claim_f1": _binary_f1(
            unsupported_tp, unsupported_fp, unsupported_fn
        ),
        "ambiguity_detection_f1": _binary_f1(ambiguity_tp, ambiguity_fp, ambiguity_fn),
        "exact_semantic_match_rate": exact_matches / len(examples) if examples else 0.0,
        "median_latency_ms": statistics.median(latencies) if latencies else 0.0,
        "claim_ready": claim_ready,
        "warning": warning,
        "details": examples,
    }
def _binary_f1(tp: int, fp: int, fn: int) -> float:
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
return 2 * precision * recall / (precision + recall) if precision + recall else 0.0
def main() -> None:
    """Command-line entry point: score a test set, save the report and print a summary.

    ``--test-set`` defaults to ``data/scenarios/test.jsonl``. ``--output`` defaults to
    ``data/scenarios/evaluation.json``, and its parent directory is created if missing.
    The full report is written as indented ASCII-escaped JSON. The summary printed to
    stdout omits the per-example ``details`` entry.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--test-set", type=Path, default=Path("data/scenarios/test.jsonl"))
    cli.add_argument(
        "--output", type=Path, default=Path("data/scenarios/evaluation.json")
    )
    options = cli.parse_args()

    report = score(options.test_set)

    destination = options.output
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(report, indent=2, ensure_ascii=True)
    destination.write_text(f"{serialized}\n", encoding="utf-8")

    # Shallow copy so the written report keeps its details; key order is preserved.
    summary = dict(report)
    summary.pop("details", None)
    print(json.dumps(summary, indent=2))
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|