Spaces:
Running
Running
| """Download, cache, and align real-world benchmark datasets.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import httpx | |
| import pandas as pd | |
| from pydantic import BaseModel, Field | |
| from dataforge.datasets.registry import DatasetMetadata, HeaderMismatch, get_dataset_metadata | |
class DatasetDownloadError(RuntimeError):
    """Signal that a benchmark dataset is neither cached nor downloadable.

    The loader raises this with a message that explains how to fetch the
    dirty/clean CSV files by hand.
    """
class GroundTruthCell(BaseModel):
    """One benchmark correction: the dirty value of a cell and its clean value.

    Instances are immutable (the model is configured as frozen).
    """

    # Frozen so scored ground-truth entries cannot be edited after creation.
    model_config = {"frozen": True}

    # Zero-based row position of the cell; negative values are rejected.
    row: int = Field(ge=0)
    # Name of the column holding the cell; must be non-empty.
    column: str = Field(min_length=1)
    # Cell text as found in the dirty file.
    dirty_value: str
    # Cell text as found in the clean file.
    clean_value: str
@dataclass
class RealWorldDataset:
    """Loaded real-world dataset with aligned dirty/clean DataFrames.

    The ``@dataclass`` decorator generates the keyword-accepting ``__init__``
    that ``load_real_world_dataset`` relies on; without it the class only
    declares annotations and cannot be constructed with field values.

    Attributes:
        metadata: Registry metadata for the dataset.
        dirty_df: The dirty table.
        clean_df: The clean table.
        canonical_columns: Column names shared by both tables.
        ground_truth: Cell-level differences between dirty and clean tables.
    """

    metadata: DatasetMetadata
    dirty_df: pd.DataFrame
    clean_df: pd.DataFrame
    canonical_columns: tuple[str, ...]
    ground_truth: tuple[GroundTruthCell, ...]
def _resolve_cache_root(cache_root: Path | None) -> Path:
    """Return the benchmark cache root, honouring an explicit override.

    Args:
        cache_root: Caller-supplied cache root, or ``None`` for the default.

    Returns:
        ``cache_root`` when given, otherwise ``~/.dataforge/cache``.
    """
    # The home directory is only looked up when no override was supplied.
    return cache_root if cache_root is not None else Path.home() / ".dataforge" / "cache"
def _dataset_cache_dir(dataset_name: str, *, cache_root: Path | None) -> Path:
    """Locate the cache directory that holds one dataset's CSV files.

    Args:
        dataset_name: Name used as the final path component.
        cache_root: Optional cache root override; ``None`` selects the default.

    Returns:
        ``<cache root>/real_world/<dataset_name>``.
    """
    base_dir = _resolve_cache_root(cache_root)
    return base_dir / "real_world" / dataset_name
def _read_cached_csv(path: Path) -> pd.DataFrame:
    """Load a cached CSV with every cell kept as its literal text.

    Args:
        path: Location of the CSV file to read.

    Returns:
        A DataFrame whose values are all strings; empty cells and markers
        such as ``NA`` are left as text instead of becoming missing values.
    """
    read_options = {
        "dtype": str,
        "keep_default_na": False,
        "na_filter": False,
    }
    return pd.read_csv(path, **read_options)
def _download_bytes(url: str) -> bytes:
    """Fetch the raw body served at ``url``.

    Args:
        url: Upstream location of a CSV file.

    Returns:
        The response body as bytes.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    with httpx.Client(timeout=60.0, follow_redirects=True) as http:
        reply = http.get(url)
        # Surface 4xx/5xx answers as exceptions instead of caching error pages.
        reply.raise_for_status()
        payload = reply.content
    return payload
def _download_to_cache(metadata: DatasetMetadata, dataset_dir: Path) -> None:
    """Download dirty/clean CSV files into the dataset cache directory.

    Both payloads are fetched before anything is written, and each file is
    staged under a temporary ``.part`` name and then moved into place. A
    failed download or an interrupted write therefore never leaves a
    truncated ``dirty.csv`` or ``clean.csv`` that a later run would mistake
    for a valid cache entry.

    Args:
        metadata: Dataset metadata; ``source_urls`` must unpack into the
            dirty URL followed by the clean URL.
        dataset_dir: Directory that receives ``dirty.csv`` and ``clean.csv``;
            created if missing.
    """
    dataset_dir.mkdir(parents=True, exist_ok=True)
    dirty_url, clean_url = metadata.source_urls
    # Fetch everything first so a failing second download writes nothing.
    payloads = {
        "dirty.csv": _download_bytes(dirty_url),
        "clean.csv": _download_bytes(clean_url),
    }
    for filename, payload in payloads.items():
        staging_path = dataset_dir / f"{filename}.part"
        staging_path.write_bytes(payload)
        # Move the finished file over the final name in a single step.
        staging_path.replace(dataset_dir / filename)
def _manual_download_message(metadata: DatasetMetadata, dataset_dir: Path, cause: Exception) -> str:
    """Compose the instructions shown when a dataset cannot be fetched.

    Args:
        metadata: Dataset metadata supplying the name and the two source URLs.
        dataset_dir: Cache directory where the files are expected.
        cause: The exception that made the download fail.

    Returns:
        A multi-line message with the failure cause, the cache target, both
        URLs, and manual download steps.
    """
    dirty_url, clean_url = metadata.source_urls
    message_lines = [
        f"Could not download dataset '{metadata.name}' and no cached copy was found.",
        "",
        f"Cause: {cause}",
        f"Cache target: {dataset_dir}",
        f"Dirty URL: {dirty_url}",
        f"Clean URL: {clean_url}",
        "",
        "How to download manually:",
        f"1. Download both CSV files from the URLs above into '{dataset_dir}'.",
        "2. Save them exactly as 'dirty.csv' and 'clean.csv', then rerun the benchmark.",
    ]
    return "\n".join(message_lines)
| def _header_mismatches( | |
| dirty_columns: list[str], clean_columns: list[str] | |
| ) -> tuple[HeaderMismatch, ...]: | |
| """Collect header-name mismatches across aligned dirty/clean columns.""" | |
| mismatches: list[HeaderMismatch] = [] | |
| for dirty_name, clean_name in zip(dirty_columns, clean_columns, strict=True): | |
| if dirty_name != clean_name: | |
| mismatches.append(HeaderMismatch(dirty_name=dirty_name, clean_name=clean_name)) | |
| return tuple(mismatches) | |
| def _compute_ground_truth( | |
| dirty_df: pd.DataFrame, | |
| clean_df: pd.DataFrame, | |
| ) -> tuple[GroundTruthCell, ...]: | |
| """Compute cell-level dirty-to-clean diffs across aligned DataFrames.""" | |
| ground_truth: list[GroundTruthCell] = [] | |
| for row_index, (dirty_row, clean_row) in enumerate( | |
| zip( | |
| dirty_df.itertuples(index=False, name=None), | |
| clean_df.itertuples(index=False, name=None), | |
| strict=True, | |
| ) | |
| ): | |
| for column, dirty_value, clean_value in zip( | |
| clean_df.columns, | |
| dirty_row, | |
| clean_row, | |
| strict=True, | |
| ): | |
| dirty_text = str(dirty_value) | |
| clean_text = str(clean_value) | |
| if dirty_text != clean_text: | |
| ground_truth.append( | |
| GroundTruthCell( | |
| row=row_index, | |
| column=str(column), | |
| dirty_value=dirty_text, | |
| clean_value=clean_text, | |
| ) | |
| ) | |
| return tuple(ground_truth) | |
def load_real_world_dataset(
    name: str,
    *,
    cache_root: Path | None = None,
) -> RealWorldDataset:
    """Return a benchmark dataset, downloading it first if it is not cached.

    The dirty table's headers are replaced by the clean table's headers, so
    both frames share the same column names; any header names that differed
    are recorded on the returned metadata.

    Args:
        name: Canonical dataset name.
        cache_root: Optional cache root override, mainly for tests.

    Returns:
        The aligned dirty/clean dataset bundle.

    Raises:
        DatasetDownloadError: If the dataset is not cached and download fails.
        ValueError: If dirty/clean files disagree on row or column count.
    """
    metadata = get_dataset_metadata(name)
    dataset_dir = _dataset_cache_dir(name, cache_root=cache_root)
    dirty_path, clean_path = dataset_dir / "dirty.csv", dataset_dir / "clean.csv"

    # Both files must be present for the cache to count as usable.
    if not (dirty_path.exists() and clean_path.exists()):
        try:
            _download_to_cache(metadata, dataset_dir)
        except Exception as exc:  # pragma: no cover - exercised through tests via monkeypatch
            message = _manual_download_message(metadata, dataset_dir, exc)
            raise DatasetDownloadError(message) from exc

    dirty_df = _read_cached_csv(dirty_path)
    clean_df = _read_cached_csv(clean_path)

    dirty_row_count, dirty_column_count = dirty_df.shape
    clean_row_count, clean_column_count = clean_df.shape
    if dirty_row_count != clean_row_count:
        raise ValueError(f"Dataset '{name}' dirty/clean row counts do not match.")
    if dirty_column_count != clean_column_count:
        raise ValueError(f"Dataset '{name}' dirty/clean column counts do not match.")

    clean_columns = [str(label) for label in clean_df.columns]
    dirty_headers = [str(label) for label in dirty_df.columns]
    # Record the differing names before the dirty headers are overwritten.
    mismatches = _header_mismatches(dirty_headers, clean_columns)

    # The clean file's header names become the shared column names.
    dirty_df.columns = clean_columns
    clean_df.columns = clean_columns

    metadata_updates = {
        "n_rows": clean_row_count,
        "n_columns": len(clean_columns),
        "header_mismatches": mismatches,
    }
    loaded_metadata = metadata.model_copy(update=metadata_updates)
    ground_truth = _compute_ground_truth(dirty_df, clean_df)
    return RealWorldDataset(
        metadata=loaded_metadata,
        dirty_df=dirty_df,
        clean_df=clean_df,
        canonical_columns=tuple(clean_columns),
        ground_truth=ground_truth,
    )