File size: 9,356 Bytes
5143557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""Shared benchmark types, metrics, and quota helpers."""

from __future__ import annotations

from collections import OrderedDict
from math import ceil
from pathlib import Path
from statistics import mean, stdev
from typing import Literal

from pydantic import BaseModel, Field

from dataforge.datasets.real_world import GroundTruthCell
from dataforge.datasets.registry import DATASET_REGISTRY

# Allowed values for the ``status`` field of the benchmark result models
# defined below: a run either completed ("ok") or was "skipped".
BenchmarkStatus = Literal["ok", "skipped"]


class BenchmarkRepair(BaseModel):
    """One benchmark repair prediction.

    A repair targets a single cell, identified by ``(row, column)``, and
    carries the proposed replacement value plus a non-empty justification.
    Instances are frozen (see ``model_config``).
    """

    # Non-negative row index of the targeted cell.
    row: int = Field(ge=0)
    # Column identifier of the targeted cell; must be a non-empty string.
    column: str = Field(min_length=1)
    # Proposed replacement value; unconstrained, so the empty string is valid.
    new_value: str
    # Justification for the repair; must be a non-empty string.
    reason: str = Field(min_length=1)

    # Frozen model: pydantic rejects attribute assignment after construction.
    model_config = {"frozen": True}


class RepairScore(BaseModel):
    """Exact-match cell repair metrics for one episode.

    Holds the raw confusion counts together with the derived precision,
    recall, and F1 values. Instances are frozen (see ``model_config``).
    """

    # True positives, false positives, and false negatives; all non-negative.
    tp: int = Field(ge=0)
    fp: int = Field(ge=0)
    fn: int = Field(ge=0)
    # Derived ratios, each validated to lie within [0.0, 1.0].
    precision: float = Field(ge=0.0, le=1.0)
    recall: float = Field(ge=0.0, le=1.0)
    f1: float = Field(ge=0.0, le=1.0)

    # Frozen model: pydantic rejects attribute assignment after construction.
    model_config = {"frozen": True}


class SeedBenchmarkResult(BaseModel):
    """Benchmark result for one dataset/method/seed run.

    The metric fields default to ``None`` and the usage counters default to
    zero, so a record can be built with only ``method``, ``dataset``,
    ``seed``, ``status``, and ``reproduction_command``.
    """

    # Identifiers of the run; method and dataset must be non-empty strings.
    method: str = Field(min_length=1)
    dataset: str = Field(min_length=1)
    seed: int = Field(ge=0)
    # "ok" or "skipped" (see BenchmarkStatus).
    status: BenchmarkStatus
    # Presumably populated only when status is "skipped" — not enforced here.
    skip_reason: str | None = None
    # Repair metrics; None when not recorded. aggregate_seed_results() counts
    # a None precision/recall/f1/avg_steps on an "ok" record as 0.0.
    precision: float | None = None
    recall: float | None = None
    f1: float | None = None
    tp: int | None = None
    fp: int | None = None
    fn: int | None = None
    avg_steps: float | None = None
    # LLM usage counters; non-negative, defaulting to zero.
    llm_calls: int = Field(ge=0, default=0)
    prompt_tokens: int = Field(ge=0, default=0)
    completion_tokens: int = Field(ge=0, default=0)
    # Presumably the value produced by quota_units() — confirm with the caller.
    quota_units: float = Field(ge=0.0, default=0.0)
    # Non-negative run time; presumably seconds, going by the ``_s`` suffix.
    runtime_s: float = Field(ge=0.0, default=0.0)
    # Optional provider and model labels for the run.
    provider: str | None = None
    model: str | None = None
    # Free-form warning messages; each instance gets its own fresh list.
    warnings: list[str] = Field(default_factory=list)
    # Non-empty command string for reproducing this run.
    reproduction_command: str = Field(min_length=1)


class AggregateBenchmarkResult(BaseModel):
    """Aggregated benchmark result across seeds for one method/dataset pair.

    Every ``*_mean`` / ``*_std`` pair defaults to ``None``;
    aggregate_seed_results() fills them only for aggregates whose status is
    "ok" and leaves them unset for "skipped" aggregates.
    """

    # Identifiers of the aggregated pair; both must be non-empty strings.
    method: str = Field(min_length=1)
    dataset: str = Field(min_length=1)
    # "ok" or "skipped" (see BenchmarkStatus).
    status: BenchmarkStatus
    skip_reason: str | None = None
    # Number of seeds asked for versus the number that produced an "ok" record.
    seeds_requested: int = Field(ge=0)
    seeds_completed: int = Field(ge=0)
    # Mean and standard deviation of each per-seed metric.
    precision_mean: float | None = None
    precision_std: float | None = None
    recall_mean: float | None = None
    recall_std: float | None = None
    f1_mean: float | None = None
    f1_std: float | None = None
    avg_steps_mean: float | None = None
    avg_steps_std: float | None = None
    quota_units_mean: float | None = None
    quota_units_std: float | None = None
    runtime_s_mean: float | None = None
    runtime_s_std: float | None = None
    # Optional provider and model labels, copied from one of the seed records.
    provider: str | None = None
    model: str | None = None
    # Non-empty command string for reproducing the runs.
    reproduction_command: str = Field(min_length=1)


class BenchmarkRunOutput(BaseModel):
    """Serializable benchmark run output.

    Bundles free-form run metadata with the per-seed records and their
    aggregates; write_run_output() dumps this model to JSON.
    """

    # Arbitrary run-level metadata keyed by string.
    metadata: dict[str, object]
    # One entry per dataset/method/seed run.
    records: list[SeedBenchmarkResult]
    # One entry per method/dataset pair.
    aggregates: list[AggregateBenchmarkResult]


def chunk_row_indices(n_rows: int, *, target_chunks: int = 20) -> tuple[tuple[int, ...], ...]:
    """Split row indices ``0 .. n_rows - 1`` into contiguous, near-equal chunks.

    The chunk size is ``ceil(n_rows / target_chunks)``, so the result holds at
    most ``target_chunks`` chunks. It can hold fewer: 21 rows with the default
    target give 11 chunks of size 2 (the last one shorter). Every chunk except
    possibly the last has the same length.

    Args:
        n_rows: Number of rows to partition. Zero or negative yields ``()``.
        target_chunks: Upper bound on the number of chunks; defaults to 20.

    Returns:
        A tuple of chunks, each a tuple of consecutive row indices.

    Raises:
        ValueError: If ``target_chunks`` is smaller than 1.
    """
    if target_chunks < 1:
        raise ValueError("target_chunks must be at least 1")
    if n_rows <= 0:
        return ()
    # Integer ceiling division stays exact for arbitrarily large row counts,
    # whereas ceil(n_rows / target_chunks) would round-trip through a float.
    chunk_size = -(-n_rows // target_chunks)
    return tuple(
        tuple(range(start, min(start + chunk_size, n_rows)))
        for start in range(0, n_rows, chunk_size)
    )


def normalize_repairs(repairs: list[BenchmarkRepair]) -> list[BenchmarkRepair]:
    """Keep only the final repair proposed for each ``(row, column)`` cell.

    Later repairs override earlier ones for the same cell. The returned list
    is ordered by where each cell's *last* repair appeared in ``repairs``.
    """
    latest: OrderedDict[tuple[int, str], BenchmarkRepair] = OrderedDict()
    for candidate in repairs:
        cell = (candidate.row, candidate.column)
        latest[cell] = candidate
        # Overwriting an existing key keeps its old slot, so push the cell to
        # the end to reflect the position of its most recent repair.
        latest.move_to_end(cell)
    return list(latest.values())


def score_repairs(
    ground_truth: tuple[GroundTruthCell, ...] | list[GroundTruthCell],
    repairs: list[BenchmarkRepair],
) -> RepairScore:
    """Compute exact-match precision, recall, and F1 for a set of repairs.

    Repairs are first collapsed to one per cell (last write wins). A repair
    is a true positive only if it targets a ground-truth cell and its
    ``new_value`` equals that cell's ``clean_value``; every other repair is a
    false positive. Ground-truth cells left without a correct repair are
    false negatives. A ground-truth cell whose ``clean_value`` is ``None``
    can never be matched.

    Each ratio is 0.0 when its denominator is zero, and all three are rounded
    to four decimal places.
    """
    final_repairs = normalize_repairs(repairs)
    expected = {(cell.row, cell.column): cell.clean_value for cell in ground_truth}

    # Cells are unique after normalization, so the number of correctly
    # repaired cells is exactly the true-positive count.
    repaired_cells: set[tuple[int, str]] = set()
    for repair in final_repairs:
        cell = (repair.row, repair.column)
        target = expected.get(cell)
        if target is None or repair.new_value != target:
            continue
        repaired_cells.add(cell)

    tp = len(repaired_cells)
    fp = len(final_repairs) - tp
    fn = len(expected) - tp

    predicted_total = tp + fp
    actual_total = tp + fn
    precision = 0.0
    if predicted_total:
        precision = tp / predicted_total
    recall = 0.0
    if actual_total:
        recall = tp / actual_total
    ratio_sum = precision + recall
    f1 = 0.0
    if ratio_sum:
        f1 = 2 * precision * recall / ratio_sum

    return RepairScore(
        tp=tp,
        fp=fp,
        fn=fn,
        precision=round(precision, 4),
        recall=round(recall, 4),
        f1=round(f1, 4),
    )


def quota_units(*, llm_calls: int, prompt_tokens: int, completion_tokens: int) -> float:
    """Return the free-tier quota consumed by one episode.

    Usage is measured two ways — requests out of 1000 and total tokens out of
    100000 — and the larger fraction is returned, rounded to four decimals.
    """
    request_share = 0.0
    if llm_calls:
        request_share = llm_calls / 1000

    token_share = 0.0
    if prompt_tokens or completion_tokens:
        token_share = (prompt_tokens + completion_tokens) / 100000

    # Whichever limit is closer to exhaustion determines the quota used.
    binding_share = max(request_share, token_share)
    return round(binding_share, 4)


def estimate_llm_calls(*, methods: list[str], datasets: list[str], seeds: int) -> int:
    """Estimate how many LLM calls the selected configuration will issue.

    Per seed, ``llm_zeroshot`` is budgeted one call per row chunk and
    ``llm_react`` two calls per row chunk; any other method contributes
    nothing. Chunk counts come from each dataset's ``n_rows`` in
    ``DATASET_REGISTRY``.
    """
    calls_per_chunk_by_method = {"llm_zeroshot": 1, "llm_react": 2}
    calls_per_chunk = sum(calls_per_chunk_by_method.get(method, 0) for method in methods)
    total_chunks = sum(
        len(chunk_row_indices(DATASET_REGISTRY[dataset_name].n_rows))
        for dataset_name in datasets
    )
    return total_chunks * calls_per_chunk * seeds


def validate_estimated_calls(*, estimated_calls: int, really_run_big_bench: bool) -> None:
    """Reject runs that exceed the free-tier call budget unless overridden.

    Raises:
        ValueError: If ``estimated_calls`` is above 500 and
            ``really_run_big_bench`` is false.
    """
    within_budget = estimated_calls <= 500
    if within_budget or really_run_big_bench:
        return
    message = (
        "Estimated benchmark size exceeds 500 free-tier LLM calls. "
        "Pass --really-run-big-bench to continue."
    )
    raise ValueError(message)


def aggregate_seed_results(
    records: list[SeedBenchmarkResult],
    *,
    seeds_requested: int,
) -> list[AggregateBenchmarkResult]:
    """Summarize per-seed records into one aggregate per (method, dataset) pair.

    Aggregates are returned in the order in which each pair first appears in
    ``records``. A pair with no ``"ok"`` record yields a ``"skipped"``
    aggregate that copies the skip reason, provider, model, and reproduction
    command from the pair's first record. Otherwise the mean and sample
    standard deviation of every metric are computed over the ``"ok"`` records
    only and rounded to four decimals; a ``None`` precision, recall, f1, or
    avg_steps counts as 0.0.
    """

    def _summarize(samples: list[float]) -> tuple[float, float]:
        # statistics.stdev needs two or more samples, so a lone sample is
        # reported with a spread of exactly zero.
        if len(samples) != 1:
            return round(mean(samples), 4), round(stdev(samples), 4)
        return round(samples[0], 4), 0.0

    # Plain dicts preserve insertion order, which fixes the output order.
    by_pair: dict[tuple[str, str], list[SeedBenchmarkResult]] = {}
    for record in records:
        pair = (record.method, record.dataset)
        if pair not in by_pair:
            by_pair[pair] = []
        by_pair[pair].append(record)

    summaries: list[AggregateBenchmarkResult] = []
    for (method, dataset), runs in by_pair.items():
        completed = [run for run in runs if run.status == "ok"]

        if completed:
            exemplar = completed[0]
            stats: dict[str, float] = {}
            for metric, samples in (
                ("precision", [run.precision or 0.0 for run in completed]),
                ("recall", [run.recall or 0.0 for run in completed]),
                ("f1", [run.f1 or 0.0 for run in completed]),
                ("avg_steps", [run.avg_steps or 0.0 for run in completed]),
                ("quota_units", [run.quota_units for run in completed]),
                ("runtime_s", [run.runtime_s for run in completed]),
            ):
                stats[f"{metric}_mean"], stats[f"{metric}_std"] = _summarize(samples)
            summary = AggregateBenchmarkResult(
                method=method,
                dataset=dataset,
                status="ok",
                skip_reason=None,
                seeds_requested=seeds_requested,
                seeds_completed=len(completed),
                provider=exemplar.provider,
                model=exemplar.model,
                reproduction_command=exemplar.reproduction_command,
                **stats,
            )
        else:
            # Nothing completed: report the pair as skipped, using its first
            # record for the descriptive fields.
            exemplar = runs[0]
            summary = AggregateBenchmarkResult(
                method=method,
                dataset=dataset,
                status="skipped",
                skip_reason=exemplar.skip_reason,
                seeds_requested=seeds_requested,
                seeds_completed=0,
                provider=exemplar.provider,
                model=exemplar.model,
                reproduction_command=exemplar.reproduction_command,
            )
        summaries.append(summary)
    return summaries


def write_run_output(output: BenchmarkRunOutput, path: Path) -> None:
    """Write ``output`` to ``path`` as two-space-indented UTF-8 JSON.

    Missing parent directories of ``path`` are created first.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    payload = output.model_dump_json(indent=2)
    path.write_text(payload, encoding="utf-8")