File size: 9,356 Bytes
5143557 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 | """Shared benchmark types, metrics, and quota helpers."""
from __future__ import annotations
from collections import OrderedDict
from math import ceil
from pathlib import Path
from statistics import mean, stdev
from typing import Literal
from pydantic import BaseModel, Field
from dataforge.datasets.real_world import GroundTruthCell
from dataforge.datasets.registry import DATASET_REGISTRY
# Allowed values for the ``status`` field of the seed-level and aggregate result models.
BenchmarkStatus = Literal["ok", "skipped"]
class BenchmarkRepair(BaseModel):
    """A single predicted repair for one table cell."""

    # Instances are frozen once constructed.
    model_config = {"frozen": True}

    # Row index of the repaired cell; must be non-negative.
    row: int = Field(ge=0)
    # Name of the repaired column; must be non-empty.
    column: str = Field(min_length=1)
    # Value proposed as the replacement for the cell.
    new_value: str
    # Non-empty explanation attached to the repair.
    reason: str = Field(min_length=1)
class RepairScore(BaseModel):
    """Exact-match repair counts and derived metrics for one episode."""

    # Instances are frozen once constructed.
    model_config = {"frozen": True}

    # Raw counts: true positives, false positives, false negatives (all >= 0).
    tp: int = Field(ge=0)
    fp: int = Field(ge=0)
    fn: int = Field(ge=0)

    # Derived ratios, each constrained to the closed interval [0.0, 1.0].
    precision: float = Field(ge=0.0, le=1.0)
    recall: float = Field(ge=0.0, le=1.0)
    f1: float = Field(ge=0.0, le=1.0)
class SeedBenchmarkResult(BaseModel):
    """Outcome of running one method on one dataset with one seed."""

    # Identity of the run.
    method: str = Field(min_length=1)
    dataset: str = Field(min_length=1)
    seed: int = Field(ge=0)

    # Whether the run completed ("ok") or was skipped, plus an optional reason.
    status: BenchmarkStatus
    skip_reason: str | None = None

    # Repair quality metrics and counts; all default to None.
    precision: float | None = None
    recall: float | None = None
    f1: float | None = None
    tp: int | None = None
    fp: int | None = None
    fn: int | None = None
    avg_steps: float | None = None

    # Usage and cost accounting; all non-negative and defaulting to zero.
    llm_calls: int = Field(default=0, ge=0)
    prompt_tokens: int = Field(default=0, ge=0)
    completion_tokens: int = Field(default=0, ge=0)
    quota_units: float = Field(default=0.0, ge=0.0)
    runtime_s: float = Field(default=0.0, ge=0.0)

    # Optional provider/model labels for the run.
    provider: str | None = None
    model: str | None = None

    # Free-form warnings; each instance gets its own fresh list.
    warnings: list[str] = Field(default_factory=list)
    # Non-empty command string recorded for reproducing the run.
    reproduction_command: str = Field(min_length=1)
class AggregateBenchmarkResult(BaseModel):
    """Cross-seed summary for a single method/dataset pair."""

    # Identity of the aggregated pair.
    method: str = Field(min_length=1)
    dataset: str = Field(min_length=1)

    # Overall status of the pair, plus an optional skip reason.
    status: BenchmarkStatus
    skip_reason: str | None = None

    # How many seeds were asked for versus how many completed.
    seeds_requested: int = Field(ge=0)
    seeds_completed: int = Field(ge=0)

    # Mean / standard-deviation pairs per metric; all default to None.
    precision_mean: float | None = None
    precision_std: float | None = None
    recall_mean: float | None = None
    recall_std: float | None = None
    f1_mean: float | None = None
    f1_std: float | None = None
    avg_steps_mean: float | None = None
    avg_steps_std: float | None = None
    quota_units_mean: float | None = None
    quota_units_std: float | None = None
    runtime_s_mean: float | None = None
    runtime_s_std: float | None = None

    # Optional provider/model labels.
    provider: str | None = None
    model: str | None = None

    # Non-empty command string recorded for reproducing the runs.
    reproduction_command: str = Field(min_length=1)
class BenchmarkRunOutput(BaseModel):
    """Top-level, serializable container for a complete benchmark run."""

    # Arbitrary run-level metadata keyed by string.
    metadata: dict[str, object]
    # Per-seed results.
    records: list[SeedBenchmarkResult]
    # Per method/dataset summaries.
    aggregates: list[AggregateBenchmarkResult]
def chunk_row_indices(n_rows: int, *, target_chunks: int = 20) -> tuple[tuple[int, ...], ...]:
    """Split row indices ``0 .. n_rows - 1`` into contiguous, ordered chunks.

    The chunk size is ``ceil(n_rows / target_chunks)``, so at most
    ``target_chunks`` chunks are produced. Fewer result when the rounded-up
    size exhausts the rows early (21 rows with the default target give 11
    chunks of size 2, the last one holding a single row).

    Args:
        n_rows: Number of rows to split. Zero or a negative value yields no
            chunks.
        target_chunks: Upper bound on the number of chunks. Defaults to 20.

    Returns:
        A tuple of chunks, each a tuple of consecutive row indices.

    Raises:
        ValueError: If ``target_chunks`` is less than 1.
    """
    if target_chunks < 1:
        raise ValueError("target_chunks must be at least 1")
    if n_rows <= 0:
        return ()
    chunk_size = ceil(n_rows / target_chunks)
    # The final chunk is clipped at n_rows, so it may be shorter than the rest.
    return tuple(
        tuple(range(start, min(start + chunk_size, n_rows)))
        for start in range(0, n_rows, chunk_size)
    )
def normalize_repairs(repairs: list[BenchmarkRepair]) -> list[BenchmarkRepair]:
    """Keep only the last repair submitted for each ``(row, column)`` cell.

    A cell that is repaired more than once takes the position of its final
    repair in the returned list; cells repaired once keep their relative
    order.
    """
    latest: dict[tuple[int, str], BenchmarkRepair] = {}
    for candidate in repairs:
        cell = (candidate.row, candidate.column)
        # Drop any earlier entry first so re-insertion moves the cell to the end.
        latest.pop(cell, None)
        latest[cell] = candidate
    return [*latest.values()]
def score_repairs(
    ground_truth: tuple[GroundTruthCell, ...] | list[GroundTruthCell],
    repairs: list[BenchmarkRepair],
) -> RepairScore:
    """Compare predicted repairs with the ground-truth clean values.

    Repairs are first collapsed to one prediction per cell. A prediction is a
    true positive only when its cell is in the ground truth and its value
    equals the clean value exactly; every other prediction is a false
    positive. Ground-truth cells without a matching prediction are false
    negatives. Precision, recall and F1 are rounded to four decimals.
    """
    predictions = normalize_repairs(repairs)
    expected = {(cell.row, cell.column): cell.clean_value for cell in ground_truth}

    hits: set[tuple[int, str]] = set()
    tp = 0
    fp = 0
    for prediction in predictions:
        cell = (prediction.row, prediction.column)
        target = expected.get(cell)
        if target is None or prediction.new_value != target:
            fp += 1
            continue
        tp += 1
        hits.add(cell)
    fn = len(expected) - len(hits)

    # Empty denominators fall back to 0.0 instead of dividing by zero.
    predicted_total = tp + fp
    actual_total = tp + fn
    precision = tp / predicted_total if predicted_total else 0.0
    recall = tp / actual_total if actual_total else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0

    return RepairScore(
        tp=tp,
        fp=fp,
        fn=fn,
        precision=round(precision, 4),
        recall=round(recall, 4),
        f1=round(f1, 4),
    )
def quota_units(*, llm_calls: int, prompt_tokens: int, completion_tokens: int) -> float:
    """Return the quota units consumed by one episode.

    The result is the larger of ``llm_calls / 1000`` and
    ``(prompt_tokens + completion_tokens) / 100000``, rounded to four
    decimals.
    """
    request_share = 0.0
    if llm_calls:
        request_share = llm_calls / 1000

    token_share = 0.0
    if prompt_tokens or completion_tokens:
        token_share = (prompt_tokens + completion_tokens) / 100000

    return round(max(request_share, token_share), 4)
def estimate_llm_calls(*, methods: list[str], datasets: list[str], seeds: int) -> int:
    """Estimate how many LLM calls the selected run configuration will make.

    For each dataset, ``"llm_zeroshot"`` counts one call per row chunk per
    seed and ``"llm_react"`` counts two. Any other method adds nothing.
    """
    calls_per_chunk = {"llm_zeroshot": 1, "llm_react": 2}
    total = 0
    for name in datasets:
        n_chunks = len(chunk_row_indices(DATASET_REGISTRY[name].n_rows))
        for method in methods:
            total += n_chunks * calls_per_chunk.get(method, 0) * seeds
    return total
def validate_estimated_calls(*, estimated_calls: int, really_run_big_bench: bool) -> None:
    """Reject runs estimated at more than 500 LLM calls unless explicitly allowed.

    Raises:
        ValueError: If ``estimated_calls`` exceeds 500 and
            ``really_run_big_bench`` is false.
    """
    if estimated_calls <= 500 or really_run_big_bench:
        return
    message = (
        "Estimated benchmark size exceeds 500 free-tier LLM calls. "
        "Pass --really-run-big-bench to continue."
    )
    raise ValueError(message)
def aggregate_seed_results(
    records: list[SeedBenchmarkResult],
    *,
    seeds_requested: int,
) -> list[AggregateBenchmarkResult]:
    """Summarise seed-level records into one aggregate per method/dataset pair.

    Pairs appear in the order they are first seen in ``records``. A pair with
    no ``"ok"`` records becomes a ``"skipped"`` aggregate that copies its skip
    reason, provider, model and reproduction command from the pair's first
    record. Otherwise the aggregate holds the mean and sample standard
    deviation of each metric over the ``"ok"`` records, rounded to four
    decimals, with missing metric values counted as 0.0.
    """

    def _summarize(samples: list[float]) -> tuple[float, float]:
        # stdev() needs at least two samples, so a lone sample gets a spread of 0.0.
        if len(samples) == 1:
            return round(samples[0], 4), 0.0
        return round(mean(samples), 4), round(stdev(samples), 4)

    # Bucket the records per (method, dataset), preserving first-seen order.
    buckets: dict[tuple[str, str], list[SeedBenchmarkResult]] = {}
    for record in records:
        pair = (record.method, record.dataset)
        if pair not in buckets:
            buckets[pair] = []
        buckets[pair].append(record)

    summaries: list[AggregateBenchmarkResult] = []
    for (method, dataset), runs in buckets.items():
        completed = [run for run in runs if run.status == "ok"]

        if not completed:
            first = runs[0]
            skipped = AggregateBenchmarkResult(
                method=method,
                dataset=dataset,
                status="skipped",
                skip_reason=first.skip_reason,
                seeds_requested=seeds_requested,
                seeds_completed=0,
                provider=first.provider,
                model=first.model,
                reproduction_command=first.reproduction_command,
            )
            summaries.append(skipped)
            continue

        samples_by_metric: dict[str, list[float]] = {
            "precision": [run.precision or 0.0 for run in completed],
            "recall": [run.recall or 0.0 for run in completed],
            "f1": [run.f1 or 0.0 for run in completed],
            "avg_steps": [run.avg_steps or 0.0 for run in completed],
            "quota_units": [run.quota_units for run in completed],
            "runtime_s": [run.runtime_s for run in completed],
        }
        # Expand each metric into its ``<metric>_mean`` / ``<metric>_std`` fields.
        stats: dict[str, float] = {}
        for metric, samples in samples_by_metric.items():
            stats[f"{metric}_mean"], stats[f"{metric}_std"] = _summarize(samples)

        first_ok = completed[0]
        summaries.append(
            AggregateBenchmarkResult(
                method=method,
                dataset=dataset,
                status="ok",
                skip_reason=None,
                seeds_requested=seeds_requested,
                seeds_completed=len(completed),
                provider=first_ok.provider,
                model=first_ok.model,
                reproduction_command=first_ok.reproduction_command,
                **stats,
            )
        )
    return summaries
def write_run_output(output: BenchmarkRunOutput, path: Path) -> None:
    """Write ``output`` to ``path`` as two-space-indented UTF-8 JSON.

    Missing parent directories of ``path`` are created first.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    serialized = output.model_dump_json(indent=2)
    path.write_text(serialized, encoding="utf-8")
|