"""Dataset loading — flat, config-per-model, PR-based. OCR column discovery.""" from __future__ import annotations import json import structlog from datasets import Dataset, get_dataset_config_names, load_dataset from huggingface_hub import HfApi logger = structlog.get_logger() class DatasetError(Exception): """Raised when dataset loading or column discovery fails.""" # --------------------------------------------------------------------------- # OCR column discovery # --------------------------------------------------------------------------- def discover_ocr_columns(dataset: Dataset) -> dict[str, str]: """Discover OCR output columns and their model names from a dataset. Strategy: 1. Parse ``inference_info`` JSON from the first row (list or single entry). 2. Fallback: heuristic column-name matching (``markdown``, ``ocr``, ``text``). 3. Disambiguate duplicate model names by appending the column name. Returns: Mapping of ``column_name → model_name``. Raises: DatasetError: If no OCR columns can be found. """ columns: dict[str, str] = {} try: if "inference_info" not in dataset.column_names: raise KeyError("no inference_info column") info_raw = dataset["inference_info"][0] # column access avoids image decode if info_raw: info = json.loads(info_raw) if not isinstance(info, list): info = [info] for entry in info: col = entry.get("column_name", "") model = entry.get("model_id", entry.get("model_name", "unknown")) if col and col in dataset.column_names: columns[col] = model except (json.JSONDecodeError, TypeError, KeyError) as exc: logger.warning("could_not_parse_inference_info", error=str(exc)) # Fallback: heuristic if not columns: for col in dataset.column_names: lower = col.lower() if "markdown" in lower or "ocr" in lower or col == "text": columns[col] = col if not columns: raise DatasetError(f"No OCR columns found. Available columns: {dataset.column_names}") # Disambiguate duplicates model_counts: dict[str, int] = {} for model in columns.values(): model_counts[model] = model_counts.get(model, 0) + 1 disambiguated: dict[str, str] = {} for col, model in columns.items(): if model_counts[model] > 1: short = model.split("/")[-1] if "/" in model else model disambiguated[col] = f"{short} ({col})" else: disambiguated[col] = model return disambiguated # --------------------------------------------------------------------------- # PR-based config discovery # --------------------------------------------------------------------------- def discover_pr_configs( repo_id: str, merge: bool = False, api: HfApi | None = None, ) -> tuple[list[str], dict[str, str]]: """Discover dataset configs from open PRs on a Hub dataset repo. PR titles must end with ``[config_name]`` to be detected. Args: repo_id: HF dataset repo id. merge: If True, merge each discovered PR before loading. api: Optional pre-configured HfApi instance. Returns: Tuple of (config_names, {config_name: pr_revision}). """ if api is None: api = HfApi() config_names: list[str] = [] revisions: dict[str, str] = {} discussions = api.get_repo_discussions(repo_id, repo_type="dataset") for disc in discussions: if not disc.is_pull_request or disc.status != "open": continue title = disc.title if "[" in title and title.endswith("]"): config = title[title.rindex("[") + 1 : -1].strip() if config: if merge: api.merge_pull_request(repo_id, disc.num, repo_type="dataset") logger.info("merged_pr", pr=disc.num, config=config) else: revisions[config] = f"refs/pr/{disc.num}" config_names.append(config) return config_names, revisions def discover_configs(repo_id: str) -> list[str]: """List non-default configs from the main branch of a Hub dataset. Returns: Config names excluding "default", or empty list if none found. """ try: configs = get_dataset_config_names(repo_id) except Exception as exc: logger.info("no_configs_on_main", repo=repo_id, reason=str(exc)) return [] return [c for c in configs if c != "default"] # --------------------------------------------------------------------------- # Config-per-model loading # --------------------------------------------------------------------------- def load_config_dataset( repo_id: str, config_names: list[str], split: str = "train", pr_revisions: dict[str, str] | None = None, ) -> tuple[Dataset, dict[str, str]]: """Load multiple configs from a Hub dataset and merge into one. Each config becomes a column whose name is the config name and whose value is the OCR text (from the first column matching heuristics, or ``markdown``). Args: repo_id: HF dataset repo id. config_names: List of config names to load. split: Dataset split to load. pr_revisions: Optional mapping of config_name → revision for PR-based loading. Returns: Tuple of (unified Dataset, {column_name: model_id}). """ if not config_names: raise DatasetError("No config names provided") pr_revisions = pr_revisions or {} unified: Dataset | None = None ocr_columns: dict[str, str] = {} for config in config_names: revision = pr_revisions.get(config) kwargs: dict = {"path": repo_id, "name": config, "split": split} if revision: kwargs["revision"] = revision ds = load_dataset(**kwargs) # Find the OCR text column in this config text_col = _find_text_column(ds) if text_col is None: logger.warning("no_text_column_in_config", config=config) continue # Extract model_id from inference_info if available model_id = _extract_model_id(ds, config) ocr_columns[config] = model_id # Build unified dataset using Arrow-level ops (no per-row image decode) text_values = ds[text_col] # column access — no image decoding if unified is None: # First config: keep all columns except text_col, add text as config name drop = [text_col] if text_col != config else [] unified = ds.remove_columns(drop) if drop else ds if config != text_col: unified = unified.add_column(config, text_values) # Also rename text_col to config if they differ and text_col was kept else: if len(ds) != len(unified): logger.warning( "config_length_mismatch", config=config, expected=len(unified), got=len(ds), ) text_values = text_values[: len(unified)] unified = unified.add_column(config, text_values) if unified is None: raise DatasetError("No configs loaded successfully") return unified, ocr_columns def _extract_model_id(ds: Dataset, config: str) -> str: """Extract model_id from inference_info in first row, falling back to config name.""" if "inference_info" not in ds.column_names: return config try: info_raw = ds["inference_info"][0] # column access avoids image decode if info_raw: info = json.loads(info_raw) if isinstance(info, list): info = info[0] return info.get("model_id", info.get("model_name", config)) except (json.JSONDecodeError, TypeError, KeyError, IndexError): pass return config def _find_text_column(ds: Dataset) -> str | None: """Find the likely OCR text column in a dataset. Priority: 1. ``inference_info[0]["column_name"]`` if present and exists in dataset. 2. First column matching ``markdown`` (case-insensitive). 3. First column matching ``ocr`` (case-insensitive). 4. Column named exactly ``text``. """ # Try inference_info first (column access avoids image decoding) if "inference_info" in ds.column_names: try: info_raw = ds["inference_info"][0] if info_raw: info = json.loads(info_raw) if isinstance(info, list): info = info[0] col_name = info.get("column_name", "") if col_name and col_name in ds.column_names: return col_name except (json.JSONDecodeError, TypeError, KeyError, IndexError): pass # Prioritized heuristic: markdown > ocr > text for pattern in ["markdown", "ocr"]: for col in ds.column_names: if pattern in col.lower(): return col if "text" in ds.column_names: return "text" return None # --------------------------------------------------------------------------- # Flat dataset loading # --------------------------------------------------------------------------- def load_flat_dataset( repo_id: str, split: str = "train", columns: list[str] | None = None, ) -> tuple[Dataset, dict[str, str]]: """Load a flat dataset from Hub and discover OCR columns. Args: repo_id: HF dataset repo id. split: Dataset split. columns: If given, use these as OCR columns (maps col→col). Returns: Tuple of (Dataset, {column_name: model_name}). """ ds = load_dataset(repo_id, split=split) if columns: # Validate columns exist for col in columns: if col not in ds.column_names: raise DatasetError(f"Column '{col}' not found. Available: {ds.column_names}") ocr_columns = {col: col for col in columns} else: ocr_columns = discover_ocr_columns(ds) return ds, ocr_columns