"""Benchmark method implementations for DataForge."""

from __future__ import annotations

import json
import random
import time
from math import ceil, isfinite
from statistics import median

from dataforge.bench.core import BenchmarkRepair, SeedBenchmarkResult, quota_units, score_repairs
from dataforge.bench.groq_client import GroqBenchClient
from dataforge.datasets.real_world import RealWorldDataset
from dataforge.detectors import run_all_detectors
from dataforge.repairers import propose_fixes


def _reproduction_command(method: str, dataset: str, seeds: int) -> str:
    """Build the canonical reproduction command for one method/dataset selection."""
    return f"dataforge bench --methods {method} --datasets {dataset} --seeds {seeds}"


def _is_row_index(value: object) -> bool:
    """Return True when ``value`` is a plain ``int`` usable as a row index.

    ``bool`` is a subclass of ``int`` in Python, so without the explicit
    exclusion a JSON ``true``/``false`` coming back from an LLM would be
    accepted as row 1/0.
    """
    return isinstance(value, int) and not isinstance(value, bool)


def _repairs_from_proposed_fixes(dataset: RealWorldDataset) -> list[BenchmarkRepair]:
    """Run the shipped deterministic detector/repair stack on one dataset.

    Detectors and the repairer each receive their own deep copy of
    ``dataset.dirty_df``, so the dataset's frame itself is never handed to
    them. LLM assistance is disabled (``allow_llm=False``) and no cache
    directory is used.

    Returns:
        One ``BenchmarkRepair`` per proposed fix, in proposal order.
    """
    issues = run_all_detectors(dataset.dirty_df.copy(deep=True), schema=None)
    proposals = propose_fixes(
        issues,
        dataset.dirty_df.copy(deep=True),
        None,
        cache_dir=None,
        allow_llm=False,
    )
    return [
        BenchmarkRepair(
            row=proposal.fix.row,
            column=proposal.fix.column,
            new_value=proposal.fix.new_value,
            reason=proposal.reason,
        )
        for proposal in proposals
    ]


def run_heuristic_episode(dataset: RealWorldDataset, *, seed: int) -> SeedBenchmarkResult:
    """Run the current deterministic DataForge stack as the heuristic baseline.

    The stack is deterministic, so ``seed`` is only recorded on the result.
    ``avg_steps`` is reported as ``1 + len(repairs)``; no LLM calls are made.
    """
    start = time.perf_counter()
    repairs = _repairs_from_proposed_fixes(dataset)
    metrics = score_repairs(dataset.ground_truth, repairs)
    runtime_s = round(time.perf_counter() - start, 4)
    return SeedBenchmarkResult(
        method="heuristic",
        dataset=dataset.metadata.name,
        seed=seed,
        status="ok",
        precision=metrics.precision,
        recall=metrics.recall,
        f1=metrics.f1,
        tp=metrics.tp,
        fp=metrics.fp,
        fn=metrics.fn,
        avg_steps=float(1 + len(repairs)),
        llm_calls=0,
        prompt_tokens=0,
        completion_tokens=0,
        quota_units=0.0,
        runtime_s=runtime_s,
        provider="local",
        model="deterministic",
        # NOTE(review): every runner hard-codes seeds=1 here regardless of
        # ``seed``; confirm this command actually reproduces non-zero seeds.
        reproduction_command=_reproduction_command("heuristic", dataset.metadata.name, 1),
    )


def run_random_episode(dataset: RealWorldDataset, *, seed: int) -> SeedBenchmarkResult:
    """Run the bounded random baseline on one dataset.

    Draws ``budget`` repairs. Each one targets a uniformly random (positional
    row, canonical column) cell and sets it to a value sampled from that
    column's existing stringified dirty values. The budget is
    ``ceil(len(dataset.ground_truth) / 10)`` clamped to ``[25, 200]``. All
    draws come from ``random.Random(seed)``, so results are reproducible per
    seed.

    A dataset with no rows or no canonical columns yields zero repairs (and
    ``avg_steps == 0``) instead of raising from ``randrange``/``choice``.
    """
    rng = random.Random(seed)
    start = time.perf_counter()
    n_rows = len(dataset.dirty_df.index)
    budget = min(200, max(25, ceil(len(dataset.ground_truth) / 10)))
    # len() rather than truthiness: canonical_columns may be a container whose
    # truth value is ambiguous (e.g. a pandas Index).
    if n_rows == 0 or len(dataset.canonical_columns) == 0:
        # Nothing to sample from: randrange(0) / choice(empty) would raise.
        budget = 0
    column_values = {
        column: [str(value) for value in dataset.dirty_df[column].tolist()]
        for column in dataset.canonical_columns
    }
    repairs: list[BenchmarkRepair] = []
    for _ in range(budget):
        row_index = rng.randrange(n_rows)
        column = rng.choice(dataset.canonical_columns)
        new_value = rng.choice(column_values[column])
        repairs.append(
            BenchmarkRepair(
                row=row_index,
                column=column,
                new_value=new_value,
                reason="random baseline",
            )
        )
    metrics = score_repairs(dataset.ground_truth, repairs)
    runtime_s = round(time.perf_counter() - start, 4)
    return SeedBenchmarkResult(
        method="random",
        dataset=dataset.metadata.name,
        seed=seed,
        status="ok",
        precision=metrics.precision,
        recall=metrics.recall,
        f1=metrics.f1,
        tp=metrics.tp,
        fp=metrics.fp,
        fn=metrics.fn,
        avg_steps=float(budget),
        llm_calls=0,
        prompt_tokens=0,
        completion_tokens=0,
        quota_units=0.0,
        runtime_s=runtime_s,
        provider="local",
        model="random",
        reproduction_command=_reproduction_command("random", dataset.metadata.name, 1),
    )


def _chunk_records(dataset: RealWorldDataset, row_indices: tuple[int, ...]) -> list[dict[str, str]]:
    """Serialize one row chunk for prompting.

    Each record maps ``"_row"`` to the positional row index and every
    canonical column to the stringified dirty value at that position.
    """
    records: list[dict[str, str]] = []
    for row_index in row_indices:
        # Fetch the row once instead of once per column.
        row = dataset.dirty_df.iloc[row_index]
        row_payload: dict[str, str] = {"_row": str(row_index)}
        for column in dataset.canonical_columns:
            row_payload[column] = str(row[column])
        records.append(row_payload)
    return records


def _column_stats(
    dataset: RealWorldDataset, columns: list[str]
) -> dict[str, dict[str, str | float | int]]:
    """Return simple benchmark-local column statistics for ReAct prompting.

    For each requested column, values are stringified and empty strings are
    dropped. The result reports ``non_empty_count`` and ``unique_count``.
    When at least one value parses as a finite float, it also reports
    ``median``, rounded to 4 decimals.
    """
    stats: dict[str, dict[str, str | float | int]] = {}
    for column in columns:
        series = dataset.dirty_df[column].astype(str)
        non_empty = [value for value in series.tolist() if value != ""]
        numeric_values: list[float] = []
        for value in non_empty:
            try:
                number = float(value)
            except ValueError:
                continue
            # float() also accepts "nan"/"inf" (e.g. stringified missing
            # values). They make the sort-based median unreliable and would be
            # serialized by json.dumps as the non-standard tokens NaN/Infinity
            # in the tool result sent back to the model.
            if isfinite(number):
                numeric_values.append(number)
        stats[column] = {
            "non_empty_count": len(non_empty),
            "unique_count": len(set(non_empty)),
        }
        if numeric_values:
            stats[column]["median"] = round(float(median(numeric_values)), 4)
    return stats


def _extract_json_object(text: str) -> dict[str, object] | None:
    """Parse the first JSON object found in an LLM response string.

    A leading Markdown code fence (optionally tagged ``json``) is stripped.
    The text is then scanned for each ``{`` in order, and the first position
    that decodes to a JSON object is returned; any trailing text is ignored.

    Returns:
        The decoded dict, or ``None`` if no position yields a JSON object.
    """
    stripped = text.strip()
    if stripped.startswith("```"):
        stripped = stripped.strip("`")
        if stripped.lower().startswith("json"):
            stripped = stripped[4:].strip()
    decoder = json.JSONDecoder()
    for offset, char in enumerate(stripped):
        if char != "{":
            continue
        try:
            # raw_decode tolerates trailing data after the object.
            payload, _ = decoder.raw_decode(stripped[offset:])
        except json.JSONDecodeError:
            continue
        if isinstance(payload, dict):
            return payload
    return None


def _repairs_from_payload(payload: dict[str, object]) -> list[BenchmarkRepair]:
    """Convert a parsed JSON payload into benchmark repairs.

    Reads ``payload["repairs"]`` and keeps only entries that are dicts with an
    integer (non-bool) ``row``, a string ``column``, and a string
    ``new_value``. Malformed entries are skipped silently. ``reason``
    defaults to ``"LLM repair"`` and is coerced to ``str``.
    """
    raw_repairs = payload.get("repairs", [])
    if not isinstance(raw_repairs, list):
        return []
    repairs: list[BenchmarkRepair] = []
    for raw_repair in raw_repairs:
        if not isinstance(raw_repair, dict):
            continue
        row = raw_repair.get("row")
        column = raw_repair.get("column")
        new_value = raw_repair.get("new_value")
        reason = raw_repair.get("reason", "LLM repair")
        if (
            not _is_row_index(row)
            or not isinstance(column, str)
            or not isinstance(new_value, str)
        ):
            continue
        repairs.append(
            BenchmarkRepair(
                row=row,
                column=column,
                new_value=new_value,
                reason=str(reason),
            )
        )
    return repairs


def run_llm_zeroshot_episode(
    dataset: RealWorldDataset,
    *,
    seed: int,
    client: GroqBenchClient,
) -> SeedBenchmarkResult:
    """Run the zero-shot Groq baseline across fixed contiguous row chunks.

    Makes exactly one ``client.complete`` call per chunk returned by
    ``chunk_row_indices``. Responses that contain no parsable JSON object
    contribute no repairs. Token usage and client warnings are accumulated
    across all calls.
    """
    start = time.perf_counter()
    llm_calls = 0
    prompt_tokens = 0
    completion_tokens = 0
    warnings: list[str] = []
    repairs: list[BenchmarkRepair] = []
    for row_indices in chunk_row_indices(len(dataset.dirty_df.index)):
        chunk_payload = _chunk_records(dataset, row_indices)
        messages = [
            {
                "role": "system",
                "content": (
                    "You are benchmarking tabular data cleaning. Reply with strict JSON: "
                    '{"repairs":[{"row":0,"column":"Column","new_value":"value","reason":"why"}]}.'
                ),
            },
            {
                "role": "user",
                "content": json.dumps(
                    {
                        "dataset": dataset.metadata.name,
                        "columns": list(dataset.canonical_columns),
                        "rows": chunk_payload,
                    },
                    sort_keys=True,
                ),
            },
        ]
        completion = client.complete(messages)
        llm_calls += 1
        prompt_tokens += completion.prompt_tokens
        completion_tokens += completion.completion_tokens
        warnings.extend(completion.warnings)
        parsed = _extract_json_object(completion.text)
        if parsed is not None:
            repairs.extend(_repairs_from_payload(parsed))
    metrics = score_repairs(dataset.ground_truth, repairs)
    runtime_s = round(time.perf_counter() - start, 4)
    return SeedBenchmarkResult(
        method="llm_zeroshot",
        dataset=dataset.metadata.name,
        seed=seed,
        status="ok",
        precision=metrics.precision,
        recall=metrics.recall,
        f1=metrics.f1,
        tp=metrics.tp,
        fp=metrics.fp,
        fn=metrics.fn,
        avg_steps=float(llm_calls),
        llm_calls=llm_calls,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        quota_units=quota_units(
            llm_calls=llm_calls,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        ),
        runtime_s=runtime_s,
        provider="groq",
        model=client.model,
        warnings=warnings,
        reproduction_command=_reproduction_command("llm_zeroshot", dataset.metadata.name, 1),
    )


def run_llm_react_episode(
    dataset: RealWorldDataset,
    *,
    seed: int,
    client: GroqBenchClient,
) -> SeedBenchmarkResult:
    """Run the constrained ReAct-style Groq baseline with one optional tool step.

    For each chunk, the first completion must return a JSON action object:

    * ``submit_repairs``: its repairs are taken directly.
    * ``finish``, an unknown action, or unparsable output: the chunk is
      skipped.
    * ``inspect_rows`` / ``column_stats``: the tool runs, restricted to rows
      in the current chunk and to canonical columns respectively. Its result
      is appended to the conversation and a second completion is requested.
      Only a ``submit_repairs`` action in that second reply is honored.

    ``avg_steps`` is reported as ``llm_calls + tool_calls``.
    """
    start = time.perf_counter()
    llm_calls = 0
    tool_calls = 0
    prompt_tokens = 0
    completion_tokens = 0
    warnings: list[str] = []
    repairs: list[BenchmarkRepair] = []
    for row_indices in chunk_row_indices(len(dataset.dirty_df.index)):
        chunk_payload = _chunk_records(dataset, row_indices)
        schema_summary = {
            "dataset": dataset.metadata.name,
            "columns": list(dataset.canonical_columns),
            "chunk_rows": len(row_indices),
        }
        messages = [
            {
                "role": "system",
                "content": (
                    "You are benchmarking tabular data cleaning with a constrained tool loop. "
                    "Respond with one JSON action object. Allowed actions: "
                    "inspect_rows, column_stats, submit_repairs, finish."
                ),
            },
            {
                "role": "user",
                "content": json.dumps(
                    {
                        "schema_summary": schema_summary,
                        "rows": chunk_payload,
                    },
                    sort_keys=True,
                ),
            },
        ]
        first = client.complete(messages)
        llm_calls += 1
        prompt_tokens += first.prompt_tokens
        completion_tokens += first.completion_tokens
        warnings.extend(first.warnings)
        first_payload = _extract_json_object(first.text)
        if first_payload is None:
            continue
        action = first_payload.get("action")
        if action == "submit_repairs":
            repairs.extend(_repairs_from_payload(first_payload))
            continue
        if action == "finish":
            continue
        tool_result: dict[str, object]
        if action == "inspect_rows":
            requested_rows = first_payload.get("row_indices", [])
            if not isinstance(requested_rows, list):
                requested_rows = []
            # Only rows belonging to the current chunk may be inspected.
            safe_rows = [
                row
                for row in requested_rows
                if _is_row_index(row) and row in row_indices
            ]
            tool_result = {"rows": _chunk_records(dataset, tuple(safe_rows))}
        elif action == "column_stats":
            requested_columns = first_payload.get("columns", [])
            if not isinstance(requested_columns, list):
                requested_columns = []
            safe_columns = [
                column
                for column in requested_columns
                if isinstance(column, str) and column in dataset.canonical_columns
            ]
            tool_result = {"column_stats": _column_stats(dataset, safe_columns)}
        else:
            continue
        tool_calls += 1
        messages.append({"role": "assistant", "content": first.text})
        messages.append({"role": "user", "content": json.dumps(tool_result, sort_keys=True)})
        second = client.complete(messages)
        llm_calls += 1
        prompt_tokens += second.prompt_tokens
        completion_tokens += second.completion_tokens
        warnings.extend(second.warnings)
        second_payload = _extract_json_object(second.text)
        if second_payload is not None and second_payload.get("action") == "submit_repairs":
            repairs.extend(_repairs_from_payload(second_payload))
    metrics = score_repairs(dataset.ground_truth, repairs)
    runtime_s = round(time.perf_counter() - start, 4)
    return SeedBenchmarkResult(
        method="llm_react",
        dataset=dataset.metadata.name,
        seed=seed,
        status="ok",
        precision=metrics.precision,
        recall=metrics.recall,
        f1=metrics.f1,
        tp=metrics.tp,
        fp=metrics.fp,
        fn=metrics.fn,
        avg_steps=float(llm_calls + tool_calls),
        llm_calls=llm_calls,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        quota_units=quota_units(
            llm_calls=llm_calls,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        ),
        runtime_s=runtime_s,
        provider="groq",
        model=client.model,
        warnings=warnings,
        reproduction_command=_reproduction_command("llm_react", dataset.metadata.name, 1),
    )


def chunk_row_indices(n_rows: int) -> tuple[tuple[int, ...], ...]:
    """Split ``n_rows`` positional rows into chunks via ``dataforge.bench.core``.

    The import is deferred to call time; the stated rationale is to avoid
    circular imports in the LLM helpers. NOTE(review): this module already
    imports other names from ``dataforge.bench.core`` at top level, so
    confirm the deferral is still required.
    """
    from dataforge.bench.core import chunk_row_indices as _chunk_row_indices

    return _chunk_row_indices(n_rows)