#!/usr/bin/env python3
"""Inspect a Hugging Face dataset using the Dataset Viewer API.

This mirrors the useful parts of upstream ml-intern's hf_inspect_dataset tool:
status, configs/splits, schema, sample rows, and parquet file availability.
"""

from __future__ import annotations

import argparse
import json
import os
import sys
import urllib.error
import urllib.parse
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from typing import Any

BASE_URL = "https://datasets-server.huggingface.co"
MAX_SAMPLE_VALUE_LEN = 150
# Longest rendering of an unrecognised / nested feature in the schema table.
MAX_TYPE_NAME_LEN = 120
# Feature ``_type`` values whose ``feature`` key holds the element type of a list column.
LIST_FEATURE_TYPES = ("Sequence", "List", "LargeList")


def fetch_json(path: str, params: dict[str, Any], token: str | None) -> dict[str, Any]:
    """GET ``BASE_URL + path`` with ``params`` as the query string and decode the JSON body.

    Parameters whose value is ``None`` are left out of the query string. When ``token``
    is set it is sent as a Bearer ``Authorization`` header. The request times out after
    30 seconds.

    Raises:
        RuntimeError: the server answered with an HTTP error status; the message holds
            the status code and the first 500 characters of the response body.
        Other ``urllib`` / ``json`` errors (e.g. ``URLError``) propagate unchanged.
    """
    query = urllib.parse.urlencode({k: v for k, v in params.items() if v is not None})
    request = urllib.request.Request(f"{BASE_URL}{path}?{query}")
    if token:
        request.add_header("Authorization", f"Bearer {token}")
    try:
        with urllib.request.urlopen(request, timeout=30) as response:
            return json.loads(response.read().decode("utf-8"))
    except urllib.error.HTTPError as exc:
        body = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"{path} returned HTTP {exc.code}: {body[:500]}") from exc


def type_name(feature: Any) -> str:
    """Render a ``datasets`` feature description as a short human-readable type.

    ``{"dtype": "string", "_type": "Value"}`` -> ``string``; ClassLabels list their
    names when there are at most five; list-like features render as
    ``Sequence[<inner>]`` / ``List[<inner>]``; a bare mapping of sub-features (a
    struct) renders as ``{field: type, ...}`` capped at ``MAX_TYPE_NAME_LEN`` chars.
    """
    if isinstance(feature, str):
        return feature
    if isinstance(feature, list):
        # Legacy encoding: ``[inner]`` describes a list whose elements are ``inner``.
        inner = type_name(feature[0]) if feature else "unknown"
        return f"List[{inner}]"
    if not isinstance(feature, dict):
        return type(feature).__name__
    feature_type = feature.get("_type")
    # Scalar columns carry ``_type == "Value"``; the useful information is the dtype.
    if feature_type == "Value" and "dtype" in feature:
        return str(feature["dtype"])
    if feature_type == "ClassLabel":
        names = feature.get("names") or []
        if 0 < len(names) <= 5:
            values = ", ".join(f"{name}={idx}" for idx, name in enumerate(names))
            return f"ClassLabel ({values})"
        return f"ClassLabel ({len(names)} classes)"
    if feature_type in LIST_FEATURE_TYPES and "feature" in feature:
        return f"{feature_type}[{type_name(feature['feature'])}]"
    if feature_type:
        return str(feature_type)
    if "dtype" in feature:
        return str(feature["dtype"])
    # No ``_type``: treat the mapping as a struct of named sub-features.
    fields = ", ".join(f"{name}: {type_name(sub)}" for name, sub in feature.items())
    rendered = "{" + fields + "}"
    if len(rendered) > MAX_TYPE_NAME_LEN:
        rendered = rendered[:MAX_TYPE_NAME_LEN] + "..."
    return rendered


def extract_configs(splits_data: dict[str, Any]) -> list[dict[str, Any]]:
    """Group a ``/splits`` response into ``[{"name": config, "splits": [...]}, ...]``.

    Each split entry is ``{"name": split, "rows": count_or_None}``. Configs keep the
    order in which the response first lists them.
    """
    configs: dict[str, dict[str, Any]] = {}
    for item in splits_data.get("splits", []):
        config = item.get("config", "default")
        split = item.get("split", "train")
        # Explicit ``is None`` test: ``or`` would turn a genuine 0-row split into "unknown".
        row_count = item.get("num_rows")
        if row_count is None:
            row_count = item.get("num_examples")
        configs.setdefault(config, {"name": config, "splits": []})
        configs[config]["splits"].append({"name": split, "rows": row_count})
    return list(configs.values())


def fill_row_counts(configs: list[dict[str, Any]], config: str, info: dict[str, Any]) -> None:
    """Fill unknown row counts of ``config`` in place from an ``/info`` response.

    Counts are read from ``dataset_info.splits[<split>].num_examples``. Counts that are
    already known are never overwritten, and malformed entries are ignored.
    """
    split_infos = (info.get("dataset_info") or {}).get("splits") or {}
    if not isinstance(split_infos, dict):
        return
    for entry in configs:
        if entry["name"] != config:
            continue
        for split in entry["splits"]:
            if split["rows"] is not None:
                continue
            split_info = split_infos.get(split["name"])
            if isinstance(split_info, dict) and isinstance(split_info.get("num_examples"), int):
                split["rows"] = split_info["num_examples"]


def resolve_target(
    configs: list[dict[str, Any]], config: str | None, split: str | None
) -> tuple[str, str]:
    """Choose the ``(config, split)`` pair to fetch schema and sample rows for.

    Explicit arguments win. Otherwise the first listed config is used (``"default"``
    when none is known), and the split defaults to the first split *of the chosen
    config*, falling back to ``"train"`` when that config's splits are unknown.
    """
    selected_config = config or (configs[0]["name"] if configs else "default")
    if split:
        return selected_config, split
    for entry in configs:
        if entry["name"] == selected_config and entry["splits"]:
            return selected_config, entry["splits"][0]["name"]
    return selected_config, "train"


def format_status(data: dict[str, Any]) -> str:
    """Render an ``/is-valid`` response as a "## Status" section listing enabled capabilities."""
    available = [
        key for key in ("viewer", "preview", "search", "filter", "statistics") if data.get(key)
    ]
    if available:
        return f"## Status\nValid ({', '.join(available)})"
    return "## Status\nDataset may have Dataset Viewer issues"


def format_structure(configs: list[dict[str, Any]], max_rows: int = 20) -> str:
    """Render configs/splits as a Markdown table, showing at most ``max_rows`` rows.

    Unknown row counts are shown as ``?``; a trailing ellipsis row and a note are added
    when the table is truncated.
    """
    lines = ["## Structure (configs & splits)", "| Config | Split | Rows |", "|---|---|---:|"]
    total = sum(len(config["splits"]) for config in configs)
    shown = 0
    for config in configs:
        for split in config["splits"]:
            if shown >= max_rows:
                break
            rows = split["rows"] if split["rows"] is not None else "?"
            lines.append(f"| {config['name']} | {split['name']} | {rows} |")
            shown += 1
        if shown >= max_rows:
            break
    if total > shown:
        lines.append("| ... | ... | ... |")
        lines.append(f"_Showing {shown} of {total} config/split rows._")
    return "\n".join(lines)


def format_schema(info: dict[str, Any], config: str) -> str:
    """Render the ``dataset_info.features`` of an ``/info`` response as a column/type table."""
    # ``or {}`` also covers keys that are present but null.
    features = (info.get("dataset_info") or {}).get("features") or {}
    lines = [f"## Schema ({config})", "| Column | Type |", "|---|---|"]
    if not features:
        lines.append("| (none found) | unknown |")
    for column, feature in features.items():
        lines.append(f"| {column} | {type_name(feature)} |")
    return "\n".join(lines)


def maybe_json(value: Any) -> Any:
    """Decode ``value`` as JSON if it is a JSON string; otherwise return it unchanged."""
    if isinstance(value, str):
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return value
    return value


def format_messages(messages: Any) -> str | None:
    """Summarise a chat-style ``messages`` value as a Markdown section.

    ``messages`` may be a list of message dicts or a JSON string encoding one; ``None``
    is returned when it is not a non-empty list. The section lists the roles seen, which
    common message keys occur, whether tool calls / tool results appear, and one example
    message (string content cut to 100 characters). The example is the first message
    that has tool calls or comes from the assistant, otherwise the first non-system one.
    """
    messages = maybe_json(messages)
    if not isinstance(messages, list) or not messages:
        return None

    roles: set[str] = set()
    keys: set[str] = set()
    has_tool_calls = False
    has_tool_results = False
    example: dict[str, Any] | None = None
    fallback: dict[str, Any] | None = None
    for message in messages:
        if not isinstance(message, dict):
            continue
        keys.update(message.keys())
        if message.get("role"):
            roles.add(str(message["role"]))
        if message.get("tool_calls") or message.get("function_call"):
            has_tool_calls = True
            example = example or message
        if message.get("role") in {"tool", "function"} or message.get("tool_call_id"):
            has_tool_results = True
        if message.get("role") == "assistant":
            example = example or message
        elif message.get("role") != "system":
            fallback = fallback or message
    example = example or fallback

    lines = ["## Messages Column Format"]
    lines.append(f"Roles: {', '.join(sorted(roles)) if roles else 'unknown'}")
    common = ["role", "content", "tool_calls", "tool_call_id", "name", "function_call"]
    lines.append(
        "Message keys: "
        + ", ".join(f"{key} {'yes' if key in keys else 'no'}" for key in common)
    )
    if has_tool_calls:
        lines.append("Tool calls: present")
    if has_tool_results:
        lines.append("Tool results: present")
    if example:
        cleaned = dict(example)
        content = cleaned.get("content")
        if isinstance(content, str) and len(content) > 100:
            cleaned["content"] = content[:100] + "..."
        lines.append("")
        lines.append("Example message structure:")
        lines.append("```json")
        lines.append(json.dumps(cleaned, indent=2, ensure_ascii=False))
        lines.append("```")
    return "\n".join(lines)


def format_samples(rows_data: dict[str, Any], config: str, split: str, limit: int) -> str:
    """Render up to ``limit`` rows of a ``/first-rows`` response as Markdown bullets.

    Values longer than ``MAX_SAMPLE_VALUE_LEN`` characters are truncated. If a row has a
    ``messages`` column, the first such value is summarised by ``format_messages``.
    A negative ``limit`` is treated as 0.
    """
    # Clamp: a negative slice bound would silently drop rows from the end instead.
    rows = rows_data.get("rows", [])[: max(limit, 0)]
    lines = [f"## Sample Rows ({config}/{split})"]
    first_messages: Any = None
    for idx, row_wrapper in enumerate(rows, 1):
        row = row_wrapper.get("row", {})
        lines.append(f"**Row {idx}:**")
        for key, value in row.items():
            if key.lower() == "messages" and first_messages is None:
                first_messages = value
            text = str(value)
            if len(text) > MAX_SAMPLE_VALUE_LEN:
                text = text[:MAX_SAMPLE_VALUE_LEN] + "..."
            lines.append(f"- {key}: {text}")
    if not rows:
        lines.append("(no rows returned)")
    message_section = format_messages(first_messages) if first_messages is not None else None
    if message_section:
        lines.append("")
        lines.append(message_section)
    return "\n".join(lines)


def format_parquet(data: dict[str, Any], max_rows: int = 20) -> str | None:
    """Summarise a ``/parquet`` response as file count and total size per config/split.

    Returns ``None`` when the response lists no parquet files. At most ``max_rows``
    config/split groups are listed.
    """
    files = data.get("parquet_files", [])
    if not files:
        return None
    groups: dict[str, dict[str, int]] = {}
    for item in files:
        key = f"{item.get('config', 'default')}/{item.get('split', 'train')}"
        groups.setdefault(key, {"count": 0, "size": 0})
        groups[key]["count"] += 1
        size = item.get("size") or 0
        # Non-numeric sizes are counted as 0 bytes rather than failing the report.
        groups[key]["size"] += int(size) if isinstance(size, (int, float)) else 0
    lines = ["## Files (Parquet)"]
    for key, values in list(groups.items())[:max_rows]:
        size_mb = values["size"] / (1024 * 1024)
        lines.append(f"- {key}: {values['count']} file(s), {size_mb:.1f} MB")
    if len(groups) > max_rows:
        lines.append(f"- ... showing {max_rows} of {len(groups)} groups")
    return "\n".join(lines)


def compatibility_notes(features: dict[str, Any]) -> str:
    """Heuristically report which trainer styles the column names look suitable for.

    SFT needs ``messages`` or ``text``, or both ``prompt`` and ``completion``; DPO needs
    ``prompt``, ``chosen`` and ``rejected``; GRPO needs ``prompt``.
    """
    columns = set(features)
    lines = ["## Training Compatibility"]
    checks = {
        "SFT": bool({"messages", "text"} & columns or {"prompt", "completion"} <= columns),
        "DPO": {"prompt", "chosen", "rejected"} <= columns,
        "GRPO": "prompt" in columns,
    }
    for method, ok in checks.items():
        lines.append(f"- {method}: {'looks compatible' if ok else 'columns not sufficient'}")
    if "messages" in columns:
        lines.append("- Chat data: inspect the sample message roles before choosing a trainer template.")
    return "\n".join(lines)


def first_rows_features(rows_data: dict[str, Any]) -> dict[str, Any]:
    """Return ``{column: feature}`` built from the ``features`` list of a ``/first-rows`` response.

    Entries are expected to look like ``{"name": column, "type": feature}``; entries
    without a ``name`` are skipped.
    """
    features: dict[str, Any] = {}
    for item in rows_data.get("features") or []:
        if isinstance(item, dict) and "name" in item:
            features[item["name"]] = item.get("type")
    return features


def fetch_all(
    jobs: dict[str, tuple[str, dict[str, Any]]], token: str | None, warnings: list[str]
) -> dict[str, Any]:
    """Fetch several endpoints concurrently.

    ``jobs`` maps a label to ``(path, params)``. Returns ``{label: decoded_json}`` for
    the requests that succeeded. Each failure is appended to ``warnings`` as
    ``"<label>: <error>"`` in the order the jobs were given, so output is deterministic.
    """
    if not jobs:
        return {}
    with ThreadPoolExecutor(max_workers=len(jobs)) as pool:
        futures = {
            label: pool.submit(fetch_json, path, params, token)
            for label, (path, params) in jobs.items()
        }
    # Leaving the ``with`` block waited for every request; collect results in job order.
    results: dict[str, Any] = {}
    for label, future in futures.items():
        try:
            results[label] = future.result()
        except Exception as exc:  # one failing endpoint must not abort the whole report
            warnings.append(f"{label}: {exc}")
    return results


def inspect_dataset(
    dataset: str, config: str | None, split: str | None, sample_rows: int, token: str | None
) -> str:
    """Build a Markdown report about ``dataset`` from the Dataset Viewer API.

    Phase 1 fetches ``/is-valid``, ``/splits`` and ``/parquet`` concurrently; phase 2
    fetches ``/info`` and ``/first-rows`` for the chosen config/split. Endpoint failures
    are listed in a trailing "Warnings" section instead of aborting.

    Args:
        dataset: dataset repo id, e.g. ``stanfordnlp/imdb``.
        config: config to inspect; defaults to the first config listed by ``/splits``.
        split: split to sample; defaults to the first split of the chosen config.
        sample_rows: maximum number of sample rows to print (negative prints none).
        token: optional HF token sent as a Bearer header.
    """
    warnings: list[str] = []

    phase1 = fetch_all(
        {
            "is-valid": ("/is-valid", {"dataset": dataset}),
            "splits": ("/splits", {"dataset": dataset}),
            "parquet": ("/parquet", {"dataset": dataset}),
        },
        token,
        warnings,
    )
    configs = extract_configs(phase1.get("splits", {}))
    if config and configs and config not in {entry["name"] for entry in configs}:
        warnings.append(f"config '{config}' is not listed by /splits")
    selected_config, selected_split = resolve_target(configs, config, split)

    phase2 = fetch_all(
        {
            "info": ("/info", {"dataset": dataset, "config": selected_config}),
            "first-rows": (
                "/first-rows",
                {"dataset": dataset, "config": selected_config, "split": selected_split},
            ),
        },
        token,
        warnings,
    )
    info = phase2.get("info", {})
    fill_row_counts(configs, selected_config, info)
    features = (info.get("dataset_info") or {}).get("features") or {}
    if not features:
        # /info failed or had no schema: the first-rows response also lists the columns.
        features = first_rows_features(phase2.get("first-rows", {}))

    sections = [f"# {dataset}"]
    if "is-valid" in phase1:
        sections.append(format_status(phase1["is-valid"]))
    if configs:
        sections.append(format_structure(configs))
    if "info" in phase2:
        sections.append(format_schema(phase2["info"], selected_config))
    # Without known columns every check would wrongly read "columns not sufficient".
    if isinstance(features, dict) and features:
        sections.append(compatibility_notes(features))
    if "first-rows" in phase2:
        sections.append(
            format_samples(phase2["first-rows"], selected_config, selected_split, sample_rows)
        )
    parquet = format_parquet(phase1.get("parquet", {}))
    if parquet:
        sections.append(parquet)
    if warnings:
        sections.append("## Warnings\n" + "\n".join(f"- {warning}" for warning in warnings))
    return "\n\n".join(sections)


def main() -> int:
    """Parse CLI arguments, run the inspection, and print the Markdown report."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("dataset", help="Dataset id, for example stanfordnlp/imdb")
    parser.add_argument("--config", help="Config/subset name")
    parser.add_argument("--split", help="Split for sample rows")
    parser.add_argument("--sample-rows", type=int, default=3, help="Number of rows to show, max 10")
    parser.add_argument("--token-env", default="HF_TOKEN", help="Environment variable containing an HF token")
    args = parser.parse_args()
    token = os.environ.get(args.token_env)
    # Clamp to the documented 0..10 range.
    sample_rows = max(0, min(args.sample_rows, 10))
    print(inspect_dataset(args.dataset, args.config, args.split, sample_rows, token))
    return 0


if __name__ == "__main__":
    sys.exit(main())