"""Shared benchmark types, metrics, and quota helpers.""" from __future__ import annotations from collections import OrderedDict from math import ceil from pathlib import Path from statistics import mean, stdev from typing import Literal from pydantic import BaseModel, Field from dataforge.datasets.real_world import GroundTruthCell from dataforge.datasets.registry import DATASET_REGISTRY BenchmarkStatus = Literal["ok", "skipped"] class BenchmarkRepair(BaseModel): """One benchmark repair prediction.""" row: int = Field(ge=0) column: str = Field(min_length=1) new_value: str reason: str = Field(min_length=1) model_config = {"frozen": True} class RepairScore(BaseModel): """Exact-match cell repair metrics for one episode.""" tp: int = Field(ge=0) fp: int = Field(ge=0) fn: int = Field(ge=0) precision: float = Field(ge=0.0, le=1.0) recall: float = Field(ge=0.0, le=1.0) f1: float = Field(ge=0.0, le=1.0) model_config = {"frozen": True} class SeedBenchmarkResult(BaseModel): """Benchmark result for one dataset/method/seed run.""" method: str = Field(min_length=1) dataset: str = Field(min_length=1) seed: int = Field(ge=0) status: BenchmarkStatus skip_reason: str | None = None precision: float | None = None recall: float | None = None f1: float | None = None tp: int | None = None fp: int | None = None fn: int | None = None avg_steps: float | None = None llm_calls: int = Field(ge=0, default=0) prompt_tokens: int = Field(ge=0, default=0) completion_tokens: int = Field(ge=0, default=0) quota_units: float = Field(ge=0.0, default=0.0) runtime_s: float = Field(ge=0.0, default=0.0) provider: str | None = None model: str | None = None warnings: list[str] = Field(default_factory=list) reproduction_command: str = Field(min_length=1) class AggregateBenchmarkResult(BaseModel): """Aggregated benchmark result across seeds for one method/dataset pair.""" method: str = Field(min_length=1) dataset: str = Field(min_length=1) status: BenchmarkStatus skip_reason: str | None = None seeds_requested: int = Field(ge=0) seeds_completed: int = Field(ge=0) precision_mean: float | None = None precision_std: float | None = None recall_mean: float | None = None recall_std: float | None = None f1_mean: float | None = None f1_std: float | None = None avg_steps_mean: float | None = None avg_steps_std: float | None = None quota_units_mean: float | None = None quota_units_std: float | None = None runtime_s_mean: float | None = None runtime_s_std: float | None = None provider: str | None = None model: str | None = None reproduction_command: str = Field(min_length=1) class BenchmarkRunOutput(BaseModel): """Serializable benchmark run output.""" metadata: dict[str, object] records: list[SeedBenchmarkResult] aggregates: list[AggregateBenchmarkResult] def chunk_row_indices(n_rows: int) -> tuple[tuple[int, ...], ...]: """Split rows into contiguous chunks with a target of twenty chunks.""" if n_rows <= 0: return () chunk_size = ceil(n_rows / 20) chunks: list[tuple[int, ...]] = [] for start in range(0, n_rows, chunk_size): stop = min(start + chunk_size, n_rows) chunks.append(tuple(range(start, stop))) return tuple(chunks) def normalize_repairs(repairs: list[BenchmarkRepair]) -> list[BenchmarkRepair]: """Collapse repairs to one final prediction per cell using last-write-wins.""" by_cell: OrderedDict[tuple[int, str], BenchmarkRepair] = OrderedDict() for repair in repairs: key = (repair.row, repair.column) if key in by_cell: del by_cell[key] by_cell[key] = repair return list(by_cell.values()) def score_repairs( ground_truth: tuple[GroundTruthCell, ...] | list[GroundTruthCell], repairs: list[BenchmarkRepair], ) -> RepairScore: """Score repaired cells against exact dirty-to-clean ground truth.""" normalized = normalize_repairs(repairs) ground_truth_map = {(cell.row, cell.column): cell.clean_value for cell in ground_truth} matched: set[tuple[int, str]] = set() tp = 0 fp = 0 for repair in normalized: key = (repair.row, repair.column) clean_value = ground_truth_map.get(key) if clean_value is not None and repair.new_value == clean_value: tp += 1 matched.add(key) else: fp += 1 fn = len(ground_truth_map) - len(matched) precision = tp / (tp + fp) if (tp + fp) else 0.0 recall = tp / (tp + fn) if (tp + fn) else 0.0 f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0 return RepairScore( tp=tp, fp=fp, fn=fn, precision=round(precision, 4), recall=round(recall, 4), f1=round(f1, 4), ) def quota_units(*, llm_calls: int, prompt_tokens: int, completion_tokens: int) -> float: """Compute free-tier quota units consumed by one episode.""" request_fraction = llm_calls / 1000 if llm_calls else 0.0 token_fraction = ( (prompt_tokens + completion_tokens) / 100000 if (prompt_tokens or completion_tokens) else 0.0 ) return round(max(request_fraction, token_fraction), 4) def estimate_llm_calls(*, methods: list[str], datasets: list[str], seeds: int) -> int: """Estimate total LLM calls for the selected run configuration.""" estimated = 0 for dataset_name in datasets: chunks = len(chunk_row_indices(DATASET_REGISTRY[dataset_name].n_rows)) for method in methods: if method == "llm_zeroshot": estimated += chunks * seeds elif method == "llm_react": estimated += chunks * 2 * seeds return estimated def validate_estimated_calls(*, estimated_calls: int, really_run_big_bench: bool) -> None: """Enforce the free-tier call budget.""" if estimated_calls > 500 and not really_run_big_bench: raise ValueError( "Estimated benchmark size exceeds 500 free-tier LLM calls. " "Pass --really-run-big-bench to continue." ) def aggregate_seed_results( records: list[SeedBenchmarkResult], *, seeds_requested: int, ) -> list[AggregateBenchmarkResult]: """Aggregate seed-level results into method/dataset summaries.""" grouped: OrderedDict[tuple[str, str], list[SeedBenchmarkResult]] = OrderedDict() for record in records: grouped.setdefault((record.method, record.dataset), []).append(record) def _mean_std(values: list[float]) -> tuple[float, float]: if len(values) == 1: return round(values[0], 4), 0.0 return round(mean(values), 4), round(stdev(values), 4) aggregates: list[AggregateBenchmarkResult] = [] for (method, dataset), rows in grouped.items(): ok_rows = [row for row in rows if row.status == "ok"] if not ok_rows: aggregates.append( AggregateBenchmarkResult( method=method, dataset=dataset, status="skipped", skip_reason=rows[0].skip_reason, seeds_requested=seeds_requested, seeds_completed=0, provider=rows[0].provider, model=rows[0].model, reproduction_command=rows[0].reproduction_command, ) ) continue precision_mean, precision_std = _mean_std([row.precision or 0.0 for row in ok_rows]) recall_mean, recall_std = _mean_std([row.recall or 0.0 for row in ok_rows]) f1_mean, f1_std = _mean_std([row.f1 or 0.0 for row in ok_rows]) avg_steps_mean, avg_steps_std = _mean_std([row.avg_steps or 0.0 for row in ok_rows]) quota_mean, quota_std = _mean_std([row.quota_units for row in ok_rows]) runtime_mean, runtime_std = _mean_std([row.runtime_s for row in ok_rows]) aggregates.append( AggregateBenchmarkResult( method=method, dataset=dataset, status="ok", skip_reason=None, seeds_requested=seeds_requested, seeds_completed=len(ok_rows), precision_mean=precision_mean, precision_std=precision_std, recall_mean=recall_mean, recall_std=recall_std, f1_mean=f1_mean, f1_std=f1_std, avg_steps_mean=avg_steps_mean, avg_steps_std=avg_steps_std, quota_units_mean=quota_mean, quota_units_std=quota_std, runtime_s_mean=runtime_mean, runtime_s_std=runtime_std, provider=ok_rows[0].provider, model=ok_rows[0].model, reproduction_command=ok_rows[0].reproduction_command, ) ) return aggregates def write_run_output(output: BenchmarkRunOutput, path: Path) -> None: """Serialize benchmark run output to JSON.""" path.parent.mkdir(parents=True, exist_ok=True) path.write_text(output.model_dump_json(indent=2), encoding="utf-8")