Dataset-Creator / hf_inspect.py
TitleOS's picture
Upload 9 files
390cebe verified
Raw
History Blame Contribute Delete
3.21 kB
"""Service layer for peeking at a HF dataset's shape without pulling the
whole thing down. Tries the lightweight datasets-server API first; falls
back to a short streaming pull for datasets that API doesn't cover
(gated, script-based, or just not indexed yet).
"""
from __future__ import annotations

from itertools import islice
from typing import Optional

import requests
from datasets import load_dataset
# datasets-server endpoint queried by _peek_via_datasets_server for a window of rows.
_ROWS_URL = "https://datasets-server.huggingface.co/rows"
# datasets-server endpoint queried by list_splits for a dataset's config/split pairs.
_SPLITS_URL = "https://datasets-server.huggingface.co/splits"
# Passed as `timeout=` to every requests.get call in this module.
_REQUEST_TIMEOUT_SECONDS = 15
class DatasetInspectError(Exception):
    """Error type this module raises when a dataset can't be inspected.

    Raised by ``list_splits`` on a non-200 response, by ``peek_rows`` for a
    blank repo id, and by the two peek helpers when they can't produce rows.
    ``peek_rows`` catches it from the datasets-server helper to trigger the
    streaming fallback, so callers only see it once both paths have failed.
    """
def list_splits(repo_id: str, token: Optional[str] = None) -> list:
    """Return the config/split pairs the datasets-server reports for a dataset.

    Args:
        repo_id: Dataset id, sent as the ``dataset`` query parameter.
        token: Optional access token; when truthy it is sent as a Bearer
            ``Authorization`` header.

    Returns:
        ``[{"config": ..., "split": ...}, ...]``. Empty when the response
        has no ``splits`` entries.

    Raises:
        DatasetInspectError: If the request itself fails (timeout, connection
            error), the server answers with a non-200 status, or the body is
            not JSON of the expected shape.
    """
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    try:
        resp = requests.get(
            _SPLITS_URL,
            params={"dataset": repo_id},
            headers=headers,
            timeout=_REQUEST_TIMEOUT_SECONDS,
        )
    except requests.RequestException as exc:
        # Surface transport failures as the module's own error type so callers
        # only need to handle DatasetInspectError.
        raise DatasetInspectError(f"Couldn't list splits for '{repo_id}': {exc}") from exc

    if resp.status_code != 200:
        raise DatasetInspectError(
            f"Couldn't list splits for '{repo_id}' (HTTP {resp.status_code}): {resp.text[:300]}"
        )

    try:
        data = resp.json()
        # `or []` also covers an explicit `"splits": null` in the payload.
        return [
            {"config": s["config"], "split": s["split"]}
            for s in data.get("splits") or []
        ]
    except (ValueError, KeyError, TypeError, AttributeError) as exc:
        # ValueError: body isn't JSON. The others: JSON of an unexpected shape.
        raise DatasetInspectError(
            f"Couldn't list splits for '{repo_id}': unexpected response body ({exc!r})"
        ) from exc
def peek_rows(
    repo_id: str,
    subset: str,
    split: str,
    sample_size: int = 8,
    token: Optional[str] = None,
) -> list:
    """Fetch a small sample of rows from one split of a dataset.

    The datasets-server helper is tried first. If it raises
    ``DatasetInspectError``, the same request is retried through the
    streaming helper, whose result (or error) is what the caller gets.

    Args:
        repo_id: Dataset id; must contain something other than whitespace.
        subset: Config name, forwarded unchanged to both helpers.
        split: Split name, forwarded unchanged to both helpers.
        sample_size: Upper bound on the number of rows returned.
        token: Optional access token, forwarded unchanged to both helpers.

    Returns:
        A list of rows as produced by whichever helper succeeded.

    Raises:
        DatasetInspectError: If ``repo_id`` is blank, or the streaming
            fallback fails after the datasets-server attempt did.
    """
    if not repo_id.strip():
        raise DatasetInspectError("No dataset repo id given.")

    request = (repo_id, subset, split, sample_size, token)
    try:
        sample = _peek_via_datasets_server(*request)
    except DatasetInspectError:
        # Fast path gave up; retry the identical request over streaming.
        sample = _peek_via_streaming(*request)
    return sample
def _peek_via_datasets_server(
    repo_id: str, subset: str, split: str, sample_size: int, token: Optional[str]
) -> list:
    """Fast path: read the first rows of a split from the datasets-server.

    Args:
        repo_id: Dataset id, sent as the ``dataset`` query parameter.
        subset: Config name; a falsy value is sent as ``"default"``.
        split: Split name.
        sample_size: Number of rows requested (sent as ``length``, from
            ``offset`` 0).
        token: Optional access token; when truthy it is sent as a Bearer
            ``Authorization`` header.

    Returns:
        The ``"row"`` value of each entry in the response's ``"rows"`` list.

    Raises:
        DatasetInspectError: On any failure of this path: transport error,
            non-200 status, a body that isn't JSON of the expected shape, or
            an empty row list. Every failure is reported with this one type
            because ``peek_rows`` catches only ``DatasetInspectError`` to
            decide whether to try the streaming fallback.
    """
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    params = {
        "dataset": repo_id,
        "config": subset or "default",
        "split": split,
        "offset": 0,
        "length": sample_size,
    }
    try:
        resp = requests.get(
            _ROWS_URL, params=params, headers=headers, timeout=_REQUEST_TIMEOUT_SECONDS
        )
    except requests.RequestException as exc:
        # Without this, a timeout or connection error would bypass the
        # streaming fallback in peek_rows entirely.
        raise DatasetInspectError(
            f"datasets-server request failed for '{repo_id}': {exc}"
        ) from exc

    if resp.status_code != 200:
        raise DatasetInspectError(f"datasets-server returned HTTP {resp.status_code} for '{repo_id}'")

    try:
        # `or []` also covers an explicit `"rows": null` in the payload.
        rows = [entry["row"] for entry in resp.json().get("rows") or []]
    except (ValueError, KeyError, TypeError, AttributeError) as exc:
        # ValueError: body isn't JSON. The others: JSON of an unexpected shape.
        raise DatasetInspectError(
            f"datasets-server returned an unexpected payload for '{repo_id}'"
        ) from exc

    if not rows:
        raise DatasetInspectError(f"datasets-server returned no rows for '{repo_id}'")
    return rows
def _peek_via_streaming(
    repo_id: str, subset: str, split: str, sample_size: int, token: Optional[str]
) -> list:
    """Fallback path: stream the split and keep only the first rows.

    Args:
        repo_id: Dataset id passed to ``load_dataset``.
        subset: Config name; a falsy value is passed as ``None``.
        split: Split name.
        sample_size: Maximum number of rows to read from the stream. Values
            below zero are treated as zero.
        token: Optional access token passed to ``load_dataset``.

    Returns:
        Up to ``sample_size`` rows, each converted with ``dict(row)``.

    Raises:
        DatasetInspectError: If the dataset can't be opened, reading from the
            stream fails, or no rows were read.
    """
    try:
        ds = load_dataset(repo_id, subset or None, split=split, streaming=True, token=token)
    except Exception as exc:
        raise DatasetInspectError(f"Couldn't load '{repo_id}': {exc}") from exc

    try:
        # islice stops after exactly `sample_size` rows; the clamp keeps a
        # negative size from raising ValueError inside islice.
        rows = [dict(row) for row in islice(ds, max(sample_size, 0))]
    except Exception as exc:
        # Opening a streaming dataset can succeed while reading from it still
        # fails, so iteration needs the same wrapping as load_dataset.
        raise DatasetInspectError(f"Couldn't read rows from '{repo_id}': {exc}") from exc

    if not rows:
        raise DatasetInspectError(f"'{repo_id}' (config={subset or 'default'}, split={split}) has no rows")
    return rows