"""Service layer for peeking at a HF dataset's shape without pulling the whole thing down. Tries the lightweight datasets-server API first; falls back to a short streaming pull for datasets that API doesn't cover (gated, script-based, or just not indexed yet). """ from __future__ import annotations from typing import Optional import requests from datasets import load_dataset _ROWS_URL = "https://datasets-server.huggingface.co/rows" _SPLITS_URL = "https://datasets-server.huggingface.co/splits" _REQUEST_TIMEOUT_SECONDS = 15 class DatasetInspectError(Exception): """Raised when we genuinely can't get a peek at a dataset, after trying both the fast path and the streaming fallback.""" def list_splits(repo_id: str, token: Optional[str] = None) -> list: """Returns [{"config": ..., "split": ...}, ...] for the dataset.""" headers = {"Authorization": f"Bearer {token}"} if token else {} resp = requests.get( _SPLITS_URL, params={"dataset": repo_id}, headers=headers, timeout=_REQUEST_TIMEOUT_SECONDS ) if resp.status_code != 200: raise DatasetInspectError( f"Couldn't list splits for '{repo_id}' (HTTP {resp.status_code}): {resp.text[:300]}" ) data = resp.json() return [{"config": s["config"], "split": s["split"]} for s in data.get("splits", [])] def peek_rows( repo_id: str, subset: str, split: str, sample_size: int = 8, token: Optional[str] = None, ) -> list: """Returns up to `sample_size` raw rows as plain dicts.""" if not repo_id.strip(): raise DatasetInspectError("No dataset repo id given.") try: return _peek_via_datasets_server(repo_id, subset, split, sample_size, token) except DatasetInspectError: return _peek_via_streaming(repo_id, subset, split, sample_size, token) def _peek_via_datasets_server( repo_id: str, subset: str, split: str, sample_size: int, token: Optional[str] ) -> list: headers = {"Authorization": f"Bearer {token}"} if token else {} params = { "dataset": repo_id, "config": subset or "default", "split": split, "offset": 0, "length": sample_size, } resp = requests.get(_ROWS_URL, params=params, headers=headers, timeout=_REQUEST_TIMEOUT_SECONDS) if resp.status_code != 200: raise DatasetInspectError(f"datasets-server returned HTTP {resp.status_code} for '{repo_id}'") data = resp.json() rows = data.get("rows", []) if not rows: raise DatasetInspectError(f"datasets-server returned no rows for '{repo_id}'") return [r["row"] for r in rows] def _peek_via_streaming( repo_id: str, subset: str, split: str, sample_size: int, token: Optional[str] ) -> list: try: ds = load_dataset(repo_id, subset or None, split=split, streaming=True, token=token) except Exception as exc: raise DatasetInspectError(f"Couldn't load '{repo_id}': {exc}") from exc rows = [] for i, row in enumerate(ds): if i >= sample_size: break rows.append(dict(row)) if not rows: raise DatasetInspectError(f"'{repo_id}' (config={subset or 'default'}, split={split}) has no rows") return rows