File size: 6,963 Bytes
18d5764
 
 
 
 
 
38b11ab
18d5764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38b11ab
 
 
 
 
18d5764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38b11ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d5764
 
 
 
 
 
 
38b11ab
 
 
 
 
 
 
 
 
 
18d5764
 
 
 
 
 
 
 
 
 
 
 
38b11ab
 
 
 
 
 
18d5764
 
 
 
 
 
38b11ab
 
 
 
18d5764
38b11ab
 
 
 
 
 
18d5764
 
 
 
 
 
 
 
 
38b11ab
 
 
 
 
 
18d5764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from __future__ import annotations

import argparse
import json
import statistics
import time
from collections import defaultdict
from pathlib import Path

from underdog_lab.domain import MatchRecord
from underdog_lab.scenarios.factory import build_extractor
from underdog_lab.scenarios.schemas import ScenarioExtraction


def factor_keys(extraction: ScenarioExtraction) -> set[tuple[str, str]]:
    """Collect the ``(factor_type.value, team)`` pair of every factor in *extraction*.

    The result is a set, so factors that repeat the same type/team pair
    collapse into a single key and are counted at most once when scoring.
    """
    keys: set[tuple[str, str]] = set()
    for item in extraction.factors:
        keys.add((item.factor_type.value, item.team))
    return keys


def _evaluation_match(row: dict) -> MatchRecord:
    """Build the placeholder match context handed to the extractor for *row*.

    Only ``id``, ``home_team`` and ``away_team`` come from the row; every other
    field is a fixed constant shared by all evaluation rows.
    """
    return MatchRecord(
        match_id=row["id"],
        kickoff_date="2026-01-01",
        competition="Evaluation",
        stage="Test",
        home_team=row["home_team"],
        away_team=row["away_team"],
        venue="Evaluation venue",
        neutral_venue=True,
        home_goals=0,
        away_goals=0,
        pre_match_home_elo=1800,
        pre_match_away_elo=1800,
        lambda_home=1.18,
        lambda_away=1.18,
        context="Frozen extraction evaluation.",
    )


def score(path: Path) -> dict:
    """Run the configured extractor over a JSONL test set and compute metrics.

    Each non-blank line of *path* must be a JSON object with the keys ``id``,
    ``text``, ``home_team``, ``away_team`` and ``expected`` (a serialized
    ``ScenarioExtraction``), and may carry ``review_status``.

    Metrics in the returned report:

    * factor micro precision/recall/F1 over ``(factor type, team)`` keys,
      summed across all rows; macro F1 is the mean of per-type F1 scores;
    * team attribution accuracy: among expected factors whose *type* the
      extractor detected (for any team), the fraction where it also named
      the expected team;
    * severity MAE over expected factors whose key was matched (``None``
      when nothing matched);
    * unsupported-claim and ambiguity F1, treating each row as a binary
      "any vs. none" decision;
    * exact semantic match rate: rows whose factor keys and both binary
      flags all agree;
    * median extractor latency in milliseconds.

    ``claim_ready`` is true only when the set is non-empty and every row has
    ``review_status == "approved"``.

    Args:
        path: JSONL file containing the labelled examples.

    Returns:
        A JSON-serializable report dict; ``details`` holds per-example
        expected/actual dumps.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a row is missing a required key.
        Whatever ``ScenarioExtraction.model_validate`` raises for an invalid
        ``expected`` payload.
    """
    extractor = build_extractor()
    tp = fp = fn = 0
    team_correct = team_total = 0
    unsupported_tp = unsupported_fp = unsupported_fn = 0
    ambiguity_tp = ambiguity_fp = ambiguity_fn = 0
    severity_errors = []
    exact_matches = 0
    per_factor: defaultdict[str, dict[str, int]] = defaultdict(
        lambda: {"tp": 0, "fp": 0, "fn": 0}
    )
    latencies: list[float] = []
    examples: list[dict] = []
    claim_ready = True

    with path.open(encoding="utf-8") as stream:
        for line in stream:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF)
                # instead of failing inside json.loads.
                continue
            row = json.loads(line)
            claim_ready = claim_ready and row.get("review_status") == "approved"
            expected = ScenarioExtraction.model_validate(row["expected"])
            match = _evaluation_match(row)

            started = time.perf_counter()
            actual = extractor.extract(row["text"], match)
            latencies.append((time.perf_counter() - started) * 1000)

            expected_keys = factor_keys(expected)
            actual_keys = factor_keys(actual)
            outcomes = {
                "tp": expected_keys & actual_keys,
                "fp": actual_keys - expected_keys,
                "fn": expected_keys - actual_keys,
            }
            tp += len(outcomes["tp"])
            fp += len(outcomes["fp"])
            fn += len(outcomes["fn"])
            for bucket, keys in outcomes.items():
                for factor_type, _team in keys:
                    per_factor[factor_type][bucket] += 1

            expected_unsupported = bool(expected.unsupported_claims)
            actual_unsupported = bool(actual.unsupported_claims)
            unsupported_tp += expected_unsupported and actual_unsupported
            unsupported_fp += not expected_unsupported and actual_unsupported
            unsupported_fn += expected_unsupported and not actual_unsupported

            expected_ambiguous = bool(expected.ambiguities)
            actual_ambiguous = bool(actual.ambiguities)
            ambiguity_tp += expected_ambiguous and actual_ambiguous
            ambiguity_fp += not expected_ambiguous and actual_ambiguous
            ambiguity_fn += expected_ambiguous and not actual_ambiguous

            actual_by_key = {
                (factor.factor_type.value, factor.team): factor
                for factor in actual.factors
            }
            predicted_types = {factor_type for factor_type, _team in actual_keys}
            for expected_factor in expected.factors:
                key = (expected_factor.factor_type.value, expected_factor.team)
                # Team attribution is judged only when the factor type itself
                # was detected; scoring every expected factor would just
                # reproduce factor recall.
                if key[0] in predicted_types:
                    team_total += 1
                    team_correct += key in actual_keys
                if key in actual_by_key:
                    severity_errors.append(
                        abs(expected_factor.severity - actual_by_key[key].severity)
                    )
            exact_matches += (
                expected_keys == actual_keys
                and expected_unsupported == actual_unsupported
                and expected_ambiguous == actual_ambiguous
            )
            examples.append(
                {
                    "id": row["id"],
                    "text": row["text"],
                    "expected": expected.model_dump(mode="json"),
                    "actual": actual.model_dump(mode="json"),
                }
            )

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    factor_f1s = {
        factor: _binary_f1(counts["tp"], counts["fp"], counts["fn"])
        for factor, counts in sorted(per_factor.items())
    }

    # An empty test set must never be reported as ready for claims.
    if not examples:
        claim_ready = False
        warning = "This test set is empty."
    elif claim_ready:
        warning = ""
    else:
        warning = "This test set contains unreviewed synthetic labels."

    return {
        "extractor": extractor.name,
        "examples": len(examples),
        "factor_micro_precision": precision,
        "factor_micro_recall": recall,
        "factor_micro_f1": _binary_f1(tp, fp, fn),
        "factor_macro_f1": (
            statistics.mean(factor_f1s.values()) if factor_f1s else 0.0
        ),
        "factor_f1_by_type": factor_f1s,
        "team_attribution_accuracy": team_correct / team_total if team_total else 0.0,
        "severity_mae_on_matched_factors": (
            statistics.mean(severity_errors) if severity_errors else None
        ),
        "unsupported_claim_f1": _binary_f1(unsupported_tp, unsupported_fp, unsupported_fn),
        "ambiguity_detection_f1": _binary_f1(ambiguity_tp, ambiguity_fp, ambiguity_fn),
        "exact_semantic_match_rate": exact_matches / len(examples) if examples else 0.0,
        "median_latency_ms": statistics.median(latencies) if latencies else 0.0,
        "claim_ready": claim_ready,
        "warning": warning,
        "details": examples,
    }


def _binary_f1(tp: int, fp: int, fn: int) -> float:
    """Return the F1 score for raw true-positive/false-positive/false-negative counts.

    Precision and recall each fall back to 0.0 when their denominator is zero,
    and the F1 score is 0.0 when both are zero.
    """
    predicted = tp + fp
    relevant = tp + fn
    p = tp / predicted if predicted else 0.0
    r = tp / relevant if relevant else 0.0
    harmonic_denominator = p + r
    if not harmonic_denominator:
        return 0.0
    return 2 * p * r / harmonic_denominator


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: score a test set and write the report.

    The full report (including per-example ``details``) is written as JSON to
    ``--output``, creating parent directories as needed; a summary without
    ``details`` is printed to stdout.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps normal command-line behaviour; passing a list makes
            the entry point callable from code and tests.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate the scenario extractor against a labelled JSONL test set.",
    )
    parser.add_argument(
        "--test-set",
        type=Path,
        default=Path("data/scenarios/test.jsonl"),
        help="JSONL file with one labelled example per line (default: %(default)s).",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("data/scenarios/evaluation.json"),
        help="Where to write the full JSON report (default: %(default)s).",
    )
    args = parser.parse_args(argv)
    report = score(args.test_set)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=True keeps the written file pure ASCII (non-ASCII text is \u-escaped).
    args.output.write_text(
        json.dumps(report, indent=2, ensure_ascii=True) + "\n",
        encoding="utf-8",
    )
    summary = {key: value for key, value in report.items() if key != "details"}
    print(json.dumps(summary, indent=2))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()