| |
| """Inspect a Hugging Face dataset using the Dataset Viewer API. |
| |
| This mirrors the useful parts of upstream ml-intern's hf_inspect_dataset tool: |
| status, configs/splits, schema, sample rows, and parquet file availability. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import urllib.error |
| import urllib.parse |
| import urllib.request |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from typing import Any |
|
|
|
|
# Root URL of the Hugging Face Dataset Viewer REST API; endpoint paths are appended to it.
BASE_URL: str = "https://datasets-server.huggingface.co"
# Sample-row values whose string form is longer than this are cut and suffixed with "...".
MAX_SAMPLE_VALUE_LEN: int = 150
|
|
|
|
def fetch_json(path: str, params: dict[str, Any], token: str | None) -> dict[str, Any]:
    """GET ``BASE_URL + path`` with *params* as the query string and decode the JSON reply.

    Parameters whose value is ``None`` are left out of the query string. When
    *token* is truthy it is sent as a bearer ``Authorization`` header. The request
    times out after 30 seconds.

    Raises:
        RuntimeError: the server answered with an HTTP error status; the message
            names *path*, the status code and the first 500 characters of the body.
        Other errors (network failures, invalid JSON) propagate unchanged.
    """
    present = {name: value for name, value in params.items() if value is not None}
    url = f"{BASE_URL}{path}?{urllib.parse.urlencode(present)}"
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    request = urllib.request.Request(url, headers=headers)
    try:
        with urllib.request.urlopen(request, timeout=30) as response:
            payload = response.read()
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"{path} returned HTTP {exc.code}: {detail[:500]}") from exc
    return json.loads(payload.decode("utf-8"))
|
|
|
|
def type_name(feature: Any) -> str:
    """Render one entry of a ``dataset_info.features`` mapping as a short type label.

    Handled forms:

    * a plain string is returned unchanged;
    * a one-element list (a list-of-feature declaration) renders as ``list<inner>``;
      any other list renders as ``list``;
    * ``{"dtype": ..., "_type": "Value"}`` (or a dict with only ``dtype``) renders
      as its dtype, e.g. ``string`` or ``int64``;
    * ``ClassLabel`` lists up to five ``name=index`` pairs, otherwise a class count
      (taken from ``names``, or ``num_classes`` when no names are given);
    * ``Sequence`` / ``List`` / ``LargeList`` with an inner ``feature`` renders as
      ``Sequence<inner>`` (etc.);
    * other dicts render as their ``_type``, else a JSON snippet of at most 120 chars;
    * anything else renders as its Python type name.
    """
    if isinstance(feature, str):
        return feature
    if isinstance(feature, list):
        # A single-element list declares a variable-length list of that element type.
        if len(feature) == 1:
            return f"list<{type_name(feature[0])}>"
        return "list"
    if not isinstance(feature, dict):
        return type(feature).__name__
    feature_type = feature.get("_type")
    # Scalar columns are serialized as {"dtype": ..., "_type": "Value"}; checking
    # ``_type`` first would label every scalar column as just "Value".
    if feature_type == "Value" or (not feature_type and "dtype" in feature):
        return str(feature.get("dtype", "Value"))
    if feature_type == "ClassLabel":
        names = feature.get("names") or []
        if 0 < len(names) <= 5:
            values = ", ".join(f"{name}={idx}" for idx, name in enumerate(names))
            return f"ClassLabel ({values})"
        count = len(names) or feature.get("num_classes") or 0
        return f"ClassLabel ({count} classes)"
    if feature_type in ("Sequence", "List", "LargeList") and "feature" in feature:
        return f"{feature_type}<{type_name(feature['feature'])}>"
    if feature_type:
        return str(feature_type)
    # Untyped dicts (e.g. struct columns): show a truncated JSON preview.
    return json.dumps(feature, ensure_ascii=False)[:120]
|
|
|
|
def extract_configs(splits_data: dict[str, Any]) -> list[dict[str, Any]]:
    """Group a ``/splits`` response into one entry per config.

    Returns ``[{"name": config, "splits": [{"name": split, "rows": count}, ...]}, ...]``
    with configs and splits in first-seen order. ``rows`` comes from ``num_rows``,
    else ``num_examples``, and is ``None`` when neither is present. Missing
    ``config`` / ``split`` fields default to ``"default"`` / ``"train"``; a missing
    or null ``splits`` list yields ``[]``.
    """
    configs: dict[str, dict[str, Any]] = {}
    for item in splits_data.get("splits") or []:
        config = item.get("config", "default")
        split = item.get("split", "train")
        # Compare against None explicitly: an empty split legitimately has 0 rows,
        # and ``or`` would throw that away in favour of the fallback key.
        row_count = item.get("num_rows")
        if row_count is None:
            row_count = item.get("num_examples")
        entry = configs.setdefault(config, {"name": config, "splits": []})
        entry["splits"].append({"name": split, "rows": row_count})
    return list(configs.values())
|
|
|
|
def format_status(data: dict[str, Any]) -> str:
    """Summarize an ``/is-valid`` response as a markdown "Status" section.

    Lists, in a fixed order, which of viewer/preview/search/filter/statistics are
    truthy in *data*; when none are, reports possible Dataset Viewer issues.
    """
    capabilities = ("viewer", "preview", "search", "filter", "statistics")
    enabled = []
    for capability in capabilities:
        if data.get(capability):
            enabled.append(capability)
    header = "## Status\n"
    if not enabled:
        return header + "Dataset may have Dataset Viewer issues"
    return header + "Valid (" + ", ".join(enabled) + ")"
|
|
|
|
def format_structure(configs: list[dict[str, Any]], max_rows: int = 20) -> str:
    """Render configs and their splits as a markdown table.

    At most *max_rows* config/split pairs are listed, in input order. When some
    pairs are left out, an ellipsis row and a "Showing X of Y" footer follow.
    A ``rows`` value of ``None`` is displayed as ``?``.
    """
    pairs = [(entry, split) for entry in configs for split in entry["splits"]]
    # Index filter (rather than a slice) so zero/negative limits simply show nothing.
    visible = [pair for position, pair in enumerate(pairs) if position < max_rows]
    table = [
        "## Structure (configs & splits)",
        "| Config | Split | Rows |",
        "|---|---|---:|",
    ]
    for entry, split in visible:
        count = "?" if split["rows"] is None else split["rows"]
        table.append(f"| {entry['name']} | {split['name']} | {count} |")
    if len(pairs) > len(visible):
        table.append("| ... | ... | ... |")
        table.append(f"_Showing {len(visible)} of {len(pairs)} config/split rows._")
    return "\n".join(table)
|
|
|
|
def format_schema(info: dict[str, Any], config: str) -> str:
    """Render the ``dataset_info.features`` of an ``/info`` response as a markdown table.

    Each column is listed with the label produced by ``type_name``; a placeholder
    row is emitted when no features are present.
    """
    dataset_info = info.get("dataset_info", {})
    features = dataset_info.get("features", {})
    table = [f"## Schema ({config})", "| Column | Type |", "|---|---|"]
    rows = [f"| {column} | {type_name(feature)} |" for column, feature in features.items()]
    if not features:
        rows = ["| (none found) | unknown |"]
    return "\n".join(table + rows)
|
|
|
|
def maybe_json(value: Any) -> Any:
    """Decode *value* as JSON when it is a string holding valid JSON; otherwise return it unchanged."""
    if not isinstance(value, str):
        return value
    try:
        decoded = json.loads(value)
    except json.JSONDecodeError:
        decoded = value
    return decoded
|
|
|
|
def format_messages(messages: Any) -> str | None:
    """Summarize a chat-style ``messages`` value as a markdown section.

    ``messages`` may be a list of message dicts or a string holding one as JSON
    (decoded with ``maybe_json``). Returns ``None`` unless that yields a
    non-empty list. Non-dict entries are skipped.

    The section reports the roles seen, whether each key of a fixed list of
    common message keys appears in any message, whether tool calls and tool
    results are present, and one example message rendered as JSON (string
    ``content`` cut to 100 characters). The example is the first message that
    carries tool calls or has role ``assistant``; if there is none, the first
    message whose role is neither ``assistant`` nor ``system`` is used.
    """
    messages = maybe_json(messages)
    if not isinstance(messages, list) or not messages:
        return None
    roles: set[str] = set()
    keys: set[str] = set()
    has_tool_calls = False
    has_tool_results = False
    # First tool-calling or assistant message encountered.
    example: dict[str, Any] | None = None
    # First non-assistant, non-system message; used only when no example is found.
    fallback: dict[str, Any] | None = None
    for message in messages:
        if not isinstance(message, dict):
            continue
        keys.update(message.keys())
        if message.get("role"):
            roles.add(str(message["role"]))
        if message.get("tool_calls") or message.get("function_call"):
            has_tool_calls = True
            example = example or message
        # Tool outputs are recognized by role or by a tool_call_id back-reference.
        if message.get("role") in {"tool", "function"} or message.get("tool_call_id"):
            has_tool_results = True
        if message.get("role") == "assistant":
            example = example or message
        elif message.get("role") != "system":
            fallback = fallback or message
    example = example or fallback
    lines = ["## Messages Column Format"]
    lines.append(f"Roles: {', '.join(sorted(roles)) if roles else 'unknown'}")
    # Presence is reported only for this fixed key list; other keys seen are not listed.
    common = ["role", "content", "tool_calls", "tool_call_id", "name", "function_call"]
    lines.append("Message keys: " + ", ".join(f"{key} {'yes' if key in keys else 'no'}" for key in common))
    if has_tool_calls:
        lines.append("Tool calls: present")
    if has_tool_results:
        lines.append("Tool results: present")
    if example:
        # Shallow copy so truncating the content does not modify the input message.
        cleaned = dict(example)
        content = cleaned.get("content")
        if isinstance(content, str) and len(content) > 100:
            cleaned["content"] = content[:100] + "..."
        lines.append("")
        lines.append("Example message structure:")
        lines.append("```json")
        lines.append(json.dumps(cleaned, indent=2, ensure_ascii=False))
        lines.append("```")
    return "\n".join(lines)
|
|
|
|
def format_samples(rows_data: dict[str, Any], config: str, split: str, limit: int) -> str:
    """Render up to *limit* rows of a ``/first-rows`` response as markdown.

    Each shown row is listed as ``- column: value`` with values stringified and
    cut to ``MAX_SAMPLE_VALUE_LEN`` characters (plus ``...``). The first non-None
    value of a column named ``messages`` (case-insensitive) among the shown rows
    is summarized by ``format_messages`` in a trailing section. A zero or
    negative *limit* shows no rows; a missing or null ``rows`` list is treated
    as empty.
    """
    # Clamp: a negative limit would otherwise slice rows off the *end* of the list.
    rows = (rows_data.get("rows") or [])[: max(limit, 0)]
    lines = [f"## Sample Rows ({config}/{split})"]
    first_messages: Any = None
    for idx, row_wrapper in enumerate(rows, 1):
        row = row_wrapper.get("row", {})
        lines.append(f"**Row {idx}:**")
        for key, value in row.items():
            if key.lower() == "messages" and first_messages is None:
                first_messages = value
            text = str(value)
            if len(text) > MAX_SAMPLE_VALUE_LEN:
                text = text[:MAX_SAMPLE_VALUE_LEN] + "..."
            lines.append(f"- {key}: {text}")
    if not rows:
        lines.append("(no rows returned)")
    message_section = format_messages(first_messages) if first_messages is not None else None
    if message_section:
        lines.append("")
        lines.append(message_section)
    return "\n".join(lines)
|
|
|
|
def format_parquet(data: dict[str, Any], max_rows: int = 20) -> str | None:
    """Summarize a ``/parquet`` response as file counts and total MB per config/split.

    Returns ``None`` when no parquet files are listed. Groups appear in first-seen
    order; at most *max_rows* are shown, followed by a note when more exist.
    Non-numeric or missing sizes count as 0 bytes.
    """
    files = data.get("parquet_files", [])
    if not files:
        return None
    counts: dict[str, int] = {}
    sizes: dict[str, int] = {}
    for entry in files:
        group = f"{entry.get('config', 'default')}/{entry.get('split', 'train')}"
        counts[group] = counts.get(group, 0) + 1
        raw_size = entry.get("size") or 0
        sizes[group] = sizes.get(group, 0) + (int(raw_size) if isinstance(raw_size, (int, float)) else 0)
    listing = [
        f"- {group}: {counts[group]} file(s), {sizes[group] / (1024 * 1024):.1f} MB"
        for group in list(counts)[:max_rows]
    ]
    if len(counts) > max_rows:
        listing.append(f"- ... showing {max_rows} of {len(counts)} groups")
    return "\n".join(["## Files (Parquet)", *listing])
|
|
|
|
def compatibility_notes(features: dict[str, Any]) -> str:
    """Heuristically report which trainer styles the dataset's column names fit.

    Only column names are inspected, never their contents:

    * SFT: a ``messages`` or ``text`` column, or both ``prompt`` and ``completion``;
    * DPO: ``chosen`` and ``rejected`` (``prompt`` is optional because preference
      data may embed the prompt in each response);
    * GRPO: a ``prompt`` column.

    When *features* is empty or ``None`` (e.g. the schema could not be fetched), no
    verdict is given instead of reporting every trainer as incompatible.
    """
    columns = set(features or {})
    lines = ["## Training Compatibility"]
    if not columns:
        lines.append("- Schema unavailable; cannot assess trainer compatibility.")
        return "\n".join(lines)
    checks = {
        "SFT": bool(columns & {"messages", "text"}) or {"prompt", "completion"} <= columns,
        # Implicit-prompt preference data only has chosen/rejected columns.
        "DPO": {"chosen", "rejected"} <= columns,
        "GRPO": "prompt" in columns,
    }
    for method, ok in checks.items():
        verdict = "looks compatible" if ok else "columns not sufficient"
        lines.append(f"- {method}: {verdict}")
    if "messages" in columns:
        lines.append("- Chat data: inspect the sample message roles before choosing a trainer template.")
    return "\n".join(lines)
|
|
|
|
def _fetch_concurrently(
    requests: dict[str, tuple[str, dict[str, Any]]],
    token: str | None,
    warnings: list[str],
) -> dict[str, Any]:
    """Fetch each ``label -> (path, params)`` in parallel; return ``{label: json}`` for successes.

    Failures are appended to *warnings* as ``"label: error"`` in the order the
    requests were given, so the report does not depend on response timing.
    """
    results: dict[str, Any] = {}
    with ThreadPoolExecutor(max_workers=max(len(requests), 1)) as pool:
        futures = {
            label: pool.submit(fetch_json, path, params, token)
            for label, (path, params) in requests.items()
        }
        # Collect in submission order: the requests still run concurrently, but
        # warnings are no longer shuffled by whichever response arrives first.
        for label, future in futures.items():
            try:
                results[label] = future.result()
            except Exception as exc:  # best effort: record and continue with partial data
                warnings.append(f"{label}: {exc}")
    return results


def _default_split(configs: list[dict[str, Any]], config_name: str) -> str:
    """Return the first split listed for *config_name*, or ``"train"`` if none is known."""
    for entry in configs:
        if entry["name"] == config_name and entry["splits"]:
            return entry["splits"][0]["name"]
    return "train"


def inspect_dataset(dataset: str, config: str | None, split: str | None, sample_rows: int, token: str | None) -> str:
    """Build a markdown report for *dataset* from the Dataset Viewer API.

    Phase 1 fetches ``/is-valid``, ``/splits`` and ``/parquet`` concurrently. The
    config defaults to the first one listed by ``/splits`` (else ``"default"``),
    and the split to the first split of the chosen config (else ``"train"``).
    Phase 2 fetches ``/info`` and ``/first-rows`` for that config/split. Endpoint
    failures are not fatal: the affected sections are omitted and the errors are
    listed in a trailing "Warnings" section.
    """
    warnings: list[str] = []
    phase1 = _fetch_concurrently(
        {
            "is-valid": ("/is-valid", {"dataset": dataset}),
            "splits": ("/splits", {"dataset": dataset}),
            "parquet": ("/parquet", {"dataset": dataset}),
        },
        token,
        warnings,
    )

    configs = extract_configs(phase1.get("splits", {}))
    selected_config = config or (configs[0]["name"] if configs else "default")
    # Take the split from the *selected* config: configs can have different splits,
    # so the first config's first split may not exist in a user-chosen config.
    selected_split = split or _default_split(configs, selected_config)

    phase2 = _fetch_concurrently(
        {
            "info": ("/info", {"dataset": dataset, "config": selected_config}),
            "first-rows": (
                "/first-rows",
                {"dataset": dataset, "config": selected_config, "split": selected_split},
            ),
        },
        token,
        warnings,
    )

    sections = [f"# {dataset}"]
    if "is-valid" in phase1:
        sections.append(format_status(phase1["is-valid"]))
    if configs:
        sections.append(format_structure(configs))
    if "info" in phase2:
        sections.append(format_schema(phase2["info"], selected_config))
        # Only judge trainer compatibility when a schema was actually retrieved;
        # with /info missing, every trainer would be misreported as incompatible.
        features = phase2["info"].get("dataset_info", {}).get("features", {})
        sections.append(compatibility_notes(features))
    if "first-rows" in phase2:
        sections.append(format_samples(phase2["first-rows"], selected_config, selected_split, sample_rows))
    parquet = format_parquet(phase1.get("parquet", {}))
    if parquet:
        sections.append(parquet)
    if warnings:
        sections.append("## Warnings\n" + "\n".join(f"- {warning}" for warning in warnings))
    return "\n\n".join(sections)
|
|
|
|
def main() -> int:
    """Command-line entry point: parse arguments, print the report, and return exit status 0."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("dataset", help="Dataset id, for example stanfordnlp/imdb")
    cli.add_argument("--config", help="Config/subset name")
    cli.add_argument("--split", help="Split for sample rows")
    cli.add_argument("--sample-rows", type=int, default=3, help="Number of rows to show, max 10")
    cli.add_argument("--token-env", default="HF_TOKEN", help="Environment variable containing an HF token")
    options = cli.parse_args()
    report = inspect_dataset(
        options.dataset,
        options.config,
        options.split,
        min(options.sample_rows, 10),  # enforce the "max 10" promised in --help
        os.environ.get(options.token_env),
    )
    print(report)
    return 0
|
|
|
|
if __name__ == "__main__":
    # Raising SystemExit with main()'s return value is exactly what sys.exit(main()) does.
    raise SystemExit(main())
|
|