Spaces:
Sleeping
Sleeping
| """Benchmark method implementations for DataForge.""" | |
| from __future__ import annotations | |
| import json | |
| import random | |
| import time | |
| from math import ceil | |
| from statistics import median | |
| from dataforge.bench.core import BenchmarkRepair, SeedBenchmarkResult, quota_units, score_repairs | |
| from dataforge.bench.groq_client import GroqBenchClient | |
| from dataforge.datasets.real_world import RealWorldDataset | |
| from dataforge.detectors import run_all_detectors | |
| from dataforge.repairers import propose_fixes | |
| def _reproduction_command(method: str, dataset: str, seeds: int) -> str: | |
| """Build the canonical reproduction command for one method/dataset selection.""" | |
| return f"dataforge bench --methods {method} --datasets {dataset} --seeds {seeds}" | |
def _repairs_from_proposed_fixes(dataset: RealWorldDataset) -> list[BenchmarkRepair]:
    """Run the shipped deterministic detector/repair stack on one dataset.

    Detection runs without a schema, and repair proposal runs with LLM use
    disabled and no cache directory. Each proposal's fix target (row, column,
    new value) and its reason are copied into a ``BenchmarkRepair``.
    """
    # Each stage receives its own deep copy of the dirty frame, so the
    # dataset's frame object is never handed to them directly.
    detected = run_all_detectors(dataset.dirty_df.copy(deep=True), schema=None)
    fixes = propose_fixes(
        detected,
        dataset.dirty_df.copy(deep=True),
        None,
        cache_dir=None,
        allow_llm=False,
    )
    converted: list[BenchmarkRepair] = []
    for item in fixes:
        target = item.fix
        converted.append(
            BenchmarkRepair(
                row=target.row,
                column=target.column,
                new_value=target.new_value,
                reason=item.reason,
            )
        )
    return converted
def run_heuristic_episode(dataset: RealWorldDataset, *, seed: int) -> SeedBenchmarkResult:
    """Score the current deterministic DataForge stack as the heuristic baseline.

    No LLM is queried, so call, token and quota counters are all zero.
    ``avg_steps`` is one plus the number of proposed repairs. ``runtime_s``
    covers repair generation and scoring, rounded to four decimals.
    """
    began = time.perf_counter()
    proposed = _repairs_from_proposed_fixes(dataset)
    scored = score_repairs(dataset.ground_truth, proposed)
    elapsed = round(time.perf_counter() - began, 4)
    dataset_name = dataset.metadata.name
    step_count = len(proposed) + 1
    return SeedBenchmarkResult(
        method="heuristic",
        dataset=dataset_name,
        seed=seed,
        status="ok",
        precision=scored.precision,
        recall=scored.recall,
        f1=scored.f1,
        tp=scored.tp,
        fp=scored.fp,
        fn=scored.fn,
        avg_steps=float(step_count),
        llm_calls=0,
        prompt_tokens=0,
        completion_tokens=0,
        quota_units=0.0,
        runtime_s=elapsed,
        provider="local",
        model="deterministic",
        reproduction_command=_reproduction_command("heuristic", dataset_name, 1),
    )
def run_random_episode(dataset: RealWorldDataset, *, seed: int) -> SeedBenchmarkResult:
    """Run the bounded random baseline on one dataset.

    A ``random.Random(seed)`` generator draws ``budget`` repairs. For each
    repair it picks a random row position, then a random canonical column,
    then a replacement value sampled from that column's own stringified
    values. ``budget`` is a tenth of the ground-truth size (rounded up),
    clamped to the range [25, 200].
    """
    generator = random.Random(seed)
    began = time.perf_counter()
    frame = dataset.dirty_df
    columns = dataset.canonical_columns
    budget = min(200, max(25, ceil(len(dataset.ground_truth) / 10)))
    pools = {name: [str(cell) for cell in frame[name].tolist()] for name in columns}
    n_rows = len(frame.index)

    def _draw() -> BenchmarkRepair:
        # Draw order (row, column, value) fixes the RNG stream per seed.
        row_position = generator.randrange(n_rows)
        picked_column = generator.choice(columns)
        return BenchmarkRepair(
            row=row_position,
            column=picked_column,
            new_value=generator.choice(pools[picked_column]),
            reason="random baseline",
        )

    guesses = [_draw() for _ in range(budget)]
    scored = score_repairs(dataset.ground_truth, guesses)
    elapsed = round(time.perf_counter() - began, 4)
    dataset_name = dataset.metadata.name
    return SeedBenchmarkResult(
        method="random",
        dataset=dataset_name,
        seed=seed,
        status="ok",
        precision=scored.precision,
        recall=scored.recall,
        f1=scored.f1,
        tp=scored.tp,
        fp=scored.fp,
        fn=scored.fn,
        avg_steps=float(budget),
        llm_calls=0,
        prompt_tokens=0,
        completion_tokens=0,
        quota_units=0.0,
        runtime_s=elapsed,
        provider="local",
        model="random",
        reproduction_command=_reproduction_command("random", dataset_name, 1),
    )
| def _chunk_records(dataset: RealWorldDataset, row_indices: tuple[int, ...]) -> list[dict[str, str]]: | |
| """Serialize one row chunk for prompting.""" | |
| records: list[dict[str, str]] = [] | |
| for row_index in row_indices: | |
| row_payload: dict[str, str] = {"_row": str(row_index)} | |
| for column in dataset.canonical_columns: | |
| row_payload[column] = str(dataset.dirty_df.iloc[row_index][column]) | |
| records.append(row_payload) | |
| return records | |
| def _column_stats( | |
| dataset: RealWorldDataset, columns: list[str] | |
| ) -> dict[str, dict[str, str | float | int]]: | |
| """Return simple benchmark-local column statistics for ReAct prompting.""" | |
| stats: dict[str, dict[str, str | float | int]] = {} | |
| for column in columns: | |
| series = dataset.dirty_df[column].astype(str) | |
| non_empty = [value for value in series.tolist() if value != ""] | |
| numeric_values: list[float] = [] | |
| for value in non_empty: | |
| try: | |
| numeric_values.append(float(value)) | |
| except ValueError: | |
| continue | |
| stats[column] = { | |
| "non_empty_count": len(non_empty), | |
| "unique_count": len(set(non_empty)), | |
| } | |
| if numeric_values: | |
| stats[column]["median"] = round(float(median(numeric_values)), 4) | |
| return stats | |
| def _extract_json_object(text: str) -> dict[str, object] | None: | |
| """Parse the first JSON object found in an LLM response string.""" | |
| stripped = text.strip() | |
| if stripped.startswith("```"): | |
| stripped = stripped.strip("`") | |
| if stripped.lower().startswith("json"): | |
| stripped = stripped[4:].strip() | |
| decoder = json.JSONDecoder() | |
| for offset, char in enumerate(stripped): | |
| if char != "{": | |
| continue | |
| try: | |
| payload, _ = decoder.raw_decode(stripped[offset:]) | |
| except json.JSONDecodeError: | |
| continue | |
| if isinstance(payload, dict): | |
| return payload | |
| return None | |
| def _repairs_from_payload(payload: dict[str, object]) -> list[BenchmarkRepair]: | |
| """Convert a parsed JSON payload into benchmark repairs.""" | |
| raw_repairs = payload.get("repairs", []) | |
| if not isinstance(raw_repairs, list): | |
| return [] | |
| repairs: list[BenchmarkRepair] = [] | |
| for raw_repair in raw_repairs: | |
| if not isinstance(raw_repair, dict): | |
| continue | |
| row = raw_repair.get("row") | |
| column = raw_repair.get("column") | |
| new_value = raw_repair.get("new_value") | |
| reason = raw_repair.get("reason", "LLM repair") | |
| if ( | |
| not isinstance(row, int) | |
| or not isinstance(column, str) | |
| or not isinstance(new_value, str) | |
| ): | |
| continue | |
| repairs.append( | |
| BenchmarkRepair( | |
| row=row, | |
| column=column, | |
| new_value=new_value, | |
| reason=str(reason), | |
| ) | |
| ) | |
| return repairs | |
def run_llm_zeroshot_episode(
    dataset: RealWorldDataset,
    *,
    seed: int,
    client: GroqBenchClient,
) -> SeedBenchmarkResult:
    """Run the zero-shot Groq baseline across fixed contiguous row chunks.

    Each chunk from ``chunk_row_indices`` costs exactly one completion. The
    system prompt asks for a strict-JSON ``repairs`` list, and the user
    message is the sorted-key JSON of the dataset name, the canonical columns
    and the chunk's serialized rows. Replies that contain no parseable JSON
    object contribute no repairs. ``seed`` is recorded on the result but is
    not passed to the client.
    """
    began = time.perf_counter()
    system_prompt = (
        "You are benchmarking tabular data cleaning. Reply with strict JSON: "
        '{"repairs":[{"row":0,"column":"Column","new_value":"value","reason":"why"}]}.'
    )
    calls = 0
    tokens_in = 0
    tokens_out = 0
    notes: list[str] = []
    collected: list[BenchmarkRepair] = []

    for chunk in chunk_row_indices(len(dataset.dirty_df.index)):
        rows = _chunk_records(dataset, chunk)
        user_prompt = json.dumps(
            {
                "dataset": dataset.metadata.name,
                "columns": list(dataset.canonical_columns),
                "rows": rows,
            },
            sort_keys=True,
        )
        reply = client.complete(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]
        )
        calls += 1
        tokens_in += reply.prompt_tokens
        tokens_out += reply.completion_tokens
        notes.extend(reply.warnings)
        decoded = _extract_json_object(reply.text)
        if decoded is not None:
            collected.extend(_repairs_from_payload(decoded))

    scored = score_repairs(dataset.ground_truth, collected)
    elapsed = round(time.perf_counter() - began, 4)
    dataset_name = dataset.metadata.name
    return SeedBenchmarkResult(
        method="llm_zeroshot",
        dataset=dataset_name,
        seed=seed,
        status="ok",
        precision=scored.precision,
        recall=scored.recall,
        f1=scored.f1,
        tp=scored.tp,
        fp=scored.fp,
        fn=scored.fn,
        avg_steps=float(calls),
        llm_calls=calls,
        prompt_tokens=tokens_in,
        completion_tokens=tokens_out,
        quota_units=quota_units(
            llm_calls=calls,
            prompt_tokens=tokens_in,
            completion_tokens=tokens_out,
        ),
        runtime_s=elapsed,
        provider="groq",
        model=client.model,
        warnings=notes,
        reproduction_command=_reproduction_command("llm_zeroshot", dataset_name, 1),
    )
def run_llm_react_episode(
    dataset: RealWorldDataset,
    *,
    seed: int,
    client: GroqBenchClient,
) -> SeedBenchmarkResult:
    """Run the constrained ReAct-style Groq baseline with one optional tool step.

    For each row chunk, the model is asked for one JSON action object:

    * ``submit_repairs``: the repairs in that object are kept and the chunk ends.
    * ``inspect_rows`` / ``column_stats``: the tool runs, but only on rows in
      the current chunk or on canonical columns. Its sorted-key JSON result is
      appended to the conversation, and the model gets exactly one more turn.
      That turn's repairs are kept only if it answers ``submit_repairs``.
    * ``finish``, an unknown action, or an unparseable reply: the chunk ends
      with no repairs.

    ``avg_steps`` is LLM calls plus executed tool calls. ``seed`` is recorded
    on the result but is not passed to the client.
    """
    began = time.perf_counter()
    system_prompt = (
        "You are benchmarking tabular data cleaning with a constrained tool loop. "
        "Respond with one JSON action object. Allowed actions: "
        "inspect_rows, column_stats, submit_repairs, finish."
    )
    calls = 0
    tools_run = 0
    tokens_in = 0
    tokens_out = 0
    notes: list[str] = []
    collected: list[BenchmarkRepair] = []

    for chunk in chunk_row_indices(len(dataset.dirty_df.index)):
        rows = _chunk_records(dataset, chunk)
        summary = {
            "dataset": dataset.metadata.name,
            "columns": list(dataset.canonical_columns),
            "chunk_rows": len(chunk),
        }
        conversation = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": json.dumps(
                    {"schema_summary": summary, "rows": rows},
                    sort_keys=True,
                ),
            },
        ]

        opening = client.complete(conversation)
        calls += 1
        tokens_in += opening.prompt_tokens
        tokens_out += opening.completion_tokens
        notes.extend(opening.warnings)

        plan = _extract_json_object(opening.text)
        if plan is None:
            continue
        verb = plan.get("action")
        if verb == "submit_repairs":
            collected.extend(_repairs_from_payload(plan))
            continue

        if verb == "inspect_rows":
            wanted_rows = plan.get("row_indices", [])
            if not isinstance(wanted_rows, list):
                wanted_rows = []
            # Only rows belonging to the current chunk may be inspected.
            visible_rows = tuple(
                candidate
                for candidate in wanted_rows
                if isinstance(candidate, int) and candidate in chunk
            )
            observation: dict[str, object] = {"rows": _chunk_records(dataset, visible_rows)}
        elif verb == "column_stats":
            wanted_columns = plan.get("columns", [])
            if not isinstance(wanted_columns, list):
                wanted_columns = []
            known_columns = [
                candidate
                for candidate in wanted_columns
                if isinstance(candidate, str) and candidate in dataset.canonical_columns
            ]
            observation = {"column_stats": _column_stats(dataset, known_columns)}
        else:
            # "finish" and any unrecognised action end this chunk.
            continue

        tools_run += 1
        conversation.append({"role": "assistant", "content": opening.text})
        conversation.append({"role": "user", "content": json.dumps(observation, sort_keys=True)})

        follow_up = client.complete(conversation)
        calls += 1
        tokens_in += follow_up.prompt_tokens
        tokens_out += follow_up.completion_tokens
        notes.extend(follow_up.warnings)

        verdict = _extract_json_object(follow_up.text)
        if verdict is not None and verdict.get("action") == "submit_repairs":
            collected.extend(_repairs_from_payload(verdict))

    scored = score_repairs(dataset.ground_truth, collected)
    elapsed = round(time.perf_counter() - began, 4)
    dataset_name = dataset.metadata.name
    return SeedBenchmarkResult(
        method="llm_react",
        dataset=dataset_name,
        seed=seed,
        status="ok",
        precision=scored.precision,
        recall=scored.recall,
        f1=scored.f1,
        tp=scored.tp,
        fp=scored.fp,
        fn=scored.fn,
        avg_steps=float(calls + tools_run),
        llm_calls=calls,
        prompt_tokens=tokens_in,
        completion_tokens=tokens_out,
        quota_units=quota_units(
            llm_calls=calls,
            prompt_tokens=tokens_in,
            completion_tokens=tokens_out,
        ),
        runtime_s=elapsed,
        provider="groq",
        model=client.model,
        warnings=notes,
        reproduction_command=_reproduction_command("llm_react", dataset_name, 1),
    )
def chunk_row_indices(n_rows: int) -> tuple[tuple[int, ...], ...]:
    """Delegate to ``dataforge.bench.core.chunk_row_indices``.

    The core helper is imported when this function is called, not when the
    module is imported. NOTE(review): this is presumably meant to sidestep an
    import-order cycle with ``dataforge.bench.core``; confirm before hoisting
    the import to the top of the file.
    """
    from dataforge.bench.core import chunk_row_indices as _core_chunker

    return _core_chunker(n_rows)