Pranesh
deploy: sync staged DataForge Space
66b1c50
"""Download, cache, and align real-world benchmark datasets."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import httpx
import pandas as pd
from pydantic import BaseModel, Field
from dataforge.datasets.registry import DatasetMetadata, HeaderMismatch, get_dataset_metadata
class DatasetDownloadError(RuntimeError):
    """Signal that a benchmark dataset is unavailable.

    Raised when the dataset files are missing from the local cache and
    fetching them from upstream did not succeed.
    """
class GroundTruthCell(BaseModel):
    """One cell whose dirty value differs from its clean counterpart.

    Benchmark scoring works on collections of these cell-level
    corrections. The model is configured as frozen.
    """

    model_config = {"frozen": True}

    # Row position of the cell; negative values are rejected (ge=0).
    row: int = Field(ge=0)
    # Name of the column holding the cell; must be a non-empty string.
    column: str = Field(min_length=1)
    # Cell content as it appears in the dirty table.
    dirty_value: str
    # Cell content as it appears in the clean table.
    clean_value: str
@dataclass(frozen=True, kw_only=True)
class RealWorldDataset:
"""Loaded real-world dataset with aligned dirty/clean DataFrames."""
metadata: DatasetMetadata
dirty_df: pd.DataFrame
clean_df: pd.DataFrame
canonical_columns: tuple[str, ...]
ground_truth: tuple[GroundTruthCell, ...]
def _resolve_cache_root(cache_root: Path | None) -> Path:
    """Pick the directory under which benchmark caches live.

    Args:
        cache_root: Explicit override; ``None`` selects the default.

    Returns:
        ``cache_root`` unchanged when one is given, otherwise the
        ``.dataforge/cache`` directory inside the user's home directory.
    """
    if cache_root is None:
        return Path.home().joinpath(".dataforge", "cache")
    return cache_root
def _dataset_cache_dir(dataset_name: str, *, cache_root: Path | None) -> Path:
    """Locate the cache directory belonging to a single dataset.

    Args:
        dataset_name: Name used as the final path component.
        cache_root: Optional cache root override; ``None`` uses the default.

    Returns:
        ``<cache root>/real_world/<dataset_name>``.
    """
    base_dir = _resolve_cache_root(cache_root)
    return base_dir.joinpath("real_world", dataset_name)
def _read_cached_csv(path: Path) -> pd.DataFrame:
    """Load a cached CSV file with every cell kept as text.

    Args:
        path: Location of the CSV file to read.

    Returns:
        A DataFrame whose columns are all read as strings.
    """
    # Every column is parsed as ``str`` and NA detection is switched off, so
    # blank cells and markers such as "NA" arrive exactly as written.
    reader_options = {"dtype": str, "keep_default_na": False, "na_filter": False}
    return pd.read_csv(path, **reader_options)
def _download_bytes(url: str) -> bytes:
    """Fetch the raw body served at ``url``.

    Redirects are followed and the client uses a 60-second timeout.

    Args:
        url: Upstream location of the CSV file.

    Returns:
        The response body as bytes.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    with httpx.Client(timeout=60.0, follow_redirects=True) as session:
        reply = session.get(url)
        # Turn 4xx/5xx answers into exceptions instead of returning an error page.
        reply.raise_for_status()
        payload = reply.content
    return payload
def _download_to_cache(metadata: DatasetMetadata, dataset_dir: Path) -> None:
    """Download the dirty/clean CSV pair into the dataset cache directory.

    Both payloads are fetched before anything is written, and each file is
    published through a staging file followed by a rename. A failed download
    or an interrupted write therefore cannot leave a truncated ``dirty.csv``
    or ``clean.csv`` that a later run would treat as a valid cached copy.

    Args:
        metadata: Dataset metadata whose ``source_urls`` holds the dirty URL
            followed by the clean URL.
        dataset_dir: Cache directory for the dataset; created if missing.

    Raises:
        Exception: Whatever the download or filesystem operations raise; the
            final cache files are left untouched in that case.
    """
    dataset_dir.mkdir(parents=True, exist_ok=True)
    dirty_url, clean_url = metadata.source_urls

    # Fetch everything first so a failure on the second URL does not leave a
    # half-populated cache behind.
    payloads = {
        "dirty.csv": _download_bytes(dirty_url),
        "clean.csv": _download_bytes(clean_url),
    }

    for filename, payload in payloads.items():
        target = dataset_dir / filename
        staging = target.with_name(f"{filename}.part")
        try:
            staging.write_bytes(payload)
            # Renaming over the target swaps the complete file in one step.
            staging.replace(target)
        finally:
            # No-op after a successful rename; removes leftovers on failure.
            staging.unlink(missing_ok=True)
def _manual_download_message(metadata: DatasetMetadata, dataset_dir: Path, cause: Exception) -> str:
    """Compose the error text explaining how to fetch a dataset by hand.

    Args:
        metadata: Dataset metadata providing ``name`` and ``source_urls``.
        dataset_dir: Cache directory the files should be placed in.
        cause: The exception that made the automatic download fail.

    Returns:
        A multi-line message naming the cause, the cache target, both source
        URLs and the manual download steps.
    """
    source_dirty, source_clean = metadata.source_urls
    fragments = [
        f"Could not download dataset '{metadata.name}' and no cached copy was found.\n\n",
        f"Cause: {cause}\n",
        f"Cache target: {dataset_dir}\n",
        f"Dirty URL: {source_dirty}\n",
        f"Clean URL: {source_clean}\n\n",
        "How to download manually:\n",
        f"1. Download both CSV files from the URLs above into '{dataset_dir}'.\n",
        "2. Save them exactly as 'dirty.csv' and 'clean.csv', then rerun the benchmark.",
    ]
    return "".join(fragments)
def _header_mismatches(
dirty_columns: list[str], clean_columns: list[str]
) -> tuple[HeaderMismatch, ...]:
"""Collect header-name mismatches across aligned dirty/clean columns."""
mismatches: list[HeaderMismatch] = []
for dirty_name, clean_name in zip(dirty_columns, clean_columns, strict=True):
if dirty_name != clean_name:
mismatches.append(HeaderMismatch(dirty_name=dirty_name, clean_name=clean_name))
return tuple(mismatches)
def _compute_ground_truth(
dirty_df: pd.DataFrame,
clean_df: pd.DataFrame,
) -> tuple[GroundTruthCell, ...]:
"""Compute cell-level dirty-to-clean diffs across aligned DataFrames."""
ground_truth: list[GroundTruthCell] = []
for row_index, (dirty_row, clean_row) in enumerate(
zip(
dirty_df.itertuples(index=False, name=None),
clean_df.itertuples(index=False, name=None),
strict=True,
)
):
for column, dirty_value, clean_value in zip(
clean_df.columns,
dirty_row,
clean_row,
strict=True,
):
dirty_text = str(dirty_value)
clean_text = str(clean_value)
if dirty_text != clean_text:
ground_truth.append(
GroundTruthCell(
row=row_index,
column=str(column),
dirty_value=dirty_text,
clean_value=clean_text,
)
)
return tuple(ground_truth)
def load_real_world_dataset(
    name: str,
    *,
    cache_root: Path | None = None,
) -> RealWorldDataset:
    """Return a benchmark dataset, downloading it first when it is not cached.

    The dirty and clean CSV files are looked up in the dataset's cache
    directory; if either one is missing, both are downloaded. The two tables
    must agree in row and column count. Both are then labelled with the clean
    table's header names, and any header-name differences are recorded on the
    returned metadata.

    Args:
        name: Canonical dataset name.
        cache_root: Optional cache root override, mainly for tests.

    Returns:
        The aligned dirty/clean dataset bundle.

    Raises:
        DatasetDownloadError: If the dataset is not cached and download fails.
        ValueError: If dirty/clean files disagree on row or column count.
    """
    metadata = get_dataset_metadata(name)
    cache_dir = _dataset_cache_dir(name, cache_root=cache_root)
    dirty_path, clean_path = cache_dir / "dirty.csv", cache_dir / "clean.csv"

    is_cached = dirty_path.exists() and clean_path.exists()
    if not is_cached:
        try:
            _download_to_cache(metadata, cache_dir)
        except Exception as exc:  # pragma: no cover - exercised through tests via monkeypatch
            message = _manual_download_message(metadata, cache_dir, exc)
            raise DatasetDownloadError(message) from exc

    dirty_df = _read_cached_csv(dirty_path)
    clean_df = _read_cached_csv(clean_path)

    # The tables are compared cell by cell, so their shapes must agree.
    if dirty_df.shape[0] != clean_df.shape[0]:
        raise ValueError(f"Dataset '{name}' dirty/clean row counts do not match.")
    if dirty_df.shape[1] != clean_df.shape[1]:
        raise ValueError(f"Dataset '{name}' dirty/clean column counts do not match.")

    canonical = [str(label) for label in clean_df.columns]
    dirty_labels = [str(label) for label in dirty_df.columns]
    # Record header differences before the dirty headers are overwritten.
    renamed_headers = _header_mismatches(dirty_labels, canonical)

    dirty_df.columns = canonical
    clean_df.columns = canonical

    observed = {
        "n_rows": clean_df.shape[0],
        "n_columns": len(canonical),
        "header_mismatches": renamed_headers,
    }
    loaded_metadata = metadata.model_copy(update=observed)
    cell_differences = _compute_ground_truth(dirty_df, clean_df)

    return RealWorldDataset(
        metadata=loaded_metadata,
        dirty_df=dirty_df,
        clean_df=clean_df,
        canonical_columns=tuple(canonical),
        ground_truth=cell_differences,
    )