| | """
|
| | Dataset Inspection Tool - Comprehensive dataset analysis in one call
|
| |
|
| | Combines /is-valid, /splits, /info, /first-rows, and /parquet endpoints
|
| | to provide everything needed for ML tasks in a single tool call.
|
| | """
|
| |
|
| | import asyncio
|
| | import os
|
| | from typing import Any, TypedDict
|
| |
|
| | import httpx
|
| |
|
| | from agent.tools.types import ToolResult
|
| |
|
# Root URL of the Hugging Face datasets-server API; endpoint paths such as
# /is-valid, /splits, /info, /first-rows and /parquet are appended to it.
BASE_URL = "https://datasets-server.huggingface.co"

# Longest stringified cell value shown in the sample-rows section; anything
# longer is cut at this length and suffixed with "...".
MAX_SAMPLE_VALUE_LEN = 150
|
| |
|
| |
|
class SplitConfig(TypedDict):
    """One dataset config (subset) together with the names of its splits."""

    # Config/subset name.
    name: str
    # Names of the splits that belong to this config, in insertion order.
    splits: list[str]
|
| |
|
| |
|
def _get_headers() -> dict:
    """Build the HTTP headers sent with every datasets-server request.

    Returns a bearer ``Authorization`` header when the ``HF_TOKEN``
    environment variable holds a non-empty value (used for private/gated
    datasets); otherwise returns an empty dict.
    """
    hf_token = os.getenv("HF_TOKEN")
    return {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
|
| |
|
| |
|
async def inspect_dataset(
    dataset: str,
    config: str | None = None,
    split: str | None = None,
    sample_rows: int = 3,
) -> ToolResult:
    """
    Build a markdown report about a dataset from the datasets-server API.

    Requests run in two concurrent phases: first /is-valid, /splits and
    /parquet (they only need the dataset id), then /info and /first-rows
    (they need the config and split resolved from the /splits answer).

    Args:
        dataset: Dataset id, sent as the ``dataset`` query parameter.
        config: Config to describe. When omitted, the first config returned
            by /splits is used, falling back to "default".
        split: Split to sample rows from. When omitted, the first split of
            the selected config is used, falling back to "train".
        sample_rows: Maximum number of sample rows to include.

    Returns:
        A result dict with the report under "formatted". Sections that could
        not be fetched or parsed are listed in a trailing "Warnings" line,
        and "isError" is true only when no section could be produced.
    """
    headers = _get_headers()
    output_parts: list[str] = []
    errors: list[str] = []

    async with httpx.AsyncClient(timeout=15, headers=headers) as client:
        # Phase 1: everything that only depends on the dataset id.
        dataset_params = {"dataset": dataset}
        is_valid_res, splits_res, parquet_res = await asyncio.gather(
            client.get(f"{BASE_URL}/is-valid", params=dataset_params),
            client.get(f"{BASE_URL}/splits", params=dataset_params),
            client.get(f"{BASE_URL}/parquet", params=dataset_params),
            return_exceptions=True,
        )

        try:
            output_parts.append(_format_status(_decode_api_result(is_valid_res)))
        except Exception as e:
            errors.append(f"is-valid: {e}")

        try:
            configs = _extract_configs(_decode_api_result(splits_res))
            if not config and configs:
                config = configs[0]["name"]
            if not split:
                # Use the splits of the config that will actually be queried,
                # not blindly those of the first config.
                split = _first_split_for(configs, config)
            output_parts.append(_format_structure(configs))
        except Exception as e:
            errors.append(f"splits: {e}")

        # Last-resort defaults when /splits failed or returned nothing usable.
        config = config or "default"
        split = split or "train"

        parquet_section = None
        try:
            parquet_section = _format_parquet_files(_decode_api_result(parquet_res))
        except Exception:
            # The parquet listing is an optional extra: it is deliberately
            # best-effort and its absence is not reported as a warning.
            parquet_section = None

        # Phase 2: endpoints that need the resolved config/split.
        info_res, rows_res = await asyncio.gather(
            client.get(
                f"{BASE_URL}/info", params={"dataset": dataset, "config": config}
            ),
            client.get(
                f"{BASE_URL}/first-rows",
                params={"dataset": dataset, "config": config, "split": split},
                timeout=30,
            ),
            return_exceptions=True,
        )

        try:
            output_parts.append(_format_schema(_decode_api_result(info_res), config))
        except Exception as e:
            errors.append(f"info: {e}")

        try:
            output_parts.append(
                _format_samples(_decode_api_result(rows_res), config, split, sample_rows)
            )
        except Exception as e:
            errors.append(f"rows: {e}")

        if parquet_section:
            output_parts.append(parquet_section)

    formatted = f"# {dataset}\n\n" + "\n\n".join(output_parts)
    if errors:
        formatted += f"\n\n**Warnings:** {'; '.join(errors)}"

    return {
        "formatted": formatted,
        "totalResults": 1,
        "resultsShared": 1,
        "isError": not output_parts,
    }


def _decode_api_result(result: Any) -> Any:
    """Return the JSON body of one gathered request, or raise with the reason.

    ``result`` is one item of ``asyncio.gather(..., return_exceptions=True)``:
    either a response object or the exception the request raised. Failed
    requests and HTTP error statuses are turned into a ``RuntimeError`` so the
    caller can report them instead of formatting an error payload as data.
    """
    if isinstance(result, BaseException):
        detail = str(result)
        reason = type(result).__name__ + (f": {detail}" if detail else "")
        raise RuntimeError(f"request failed ({reason})")

    if result.status_code >= 400:
        try:
            detail = result.json().get("error")
        except Exception:
            # Error bodies are not guaranteed to be JSON objects.
            detail = None
        raise RuntimeError(
            f"HTTP {result.status_code}" + (f" ({detail})" if detail else "")
        )

    return result.json()


def _first_split_for(configs: list, config: str | None) -> str | None:
    """Return the first split of ``config``, or of the first config if unknown."""
    if not configs:
        return None
    chosen = next((c for c in configs if c["name"] == config), configs[0])
    return chosen["splits"][0] if chosen["splits"] else None
|
| |
|
| |
|
def _format_status(data: dict) -> str:
    """Render the /is-valid payload as a "## Status" section.

    Lists the known capability flags that are truthy in ``data``; when none
    of them is, reports that the dataset may have issues.
    """
    capabilities = ("viewer", "preview", "search", "filter", "statistics")
    enabled = [flag for flag in capabilities if data.get(flag)]
    if not enabled:
        return "## Status\n✗ Dataset may have issues"
    return f"## Status\n✓ Valid ({', '.join(enabled)})"
|
| |
|
| |
|
def _extract_configs(splits_data: dict) -> list[SplitConfig]:
    """Collapse the flat "splits" listing into one entry per config.

    Each item's "config" (defaulting to "default") selects the entry its
    "split" value is appended to; entries keep first-seen order.
    """
    by_config: dict[str, SplitConfig] = {}
    for entry in splits_data.get("splits", []):
        config_name = entry.get("config", "default")
        bucket = by_config.setdefault(
            config_name, {"name": config_name, "splits": []}
        )
        bucket["splits"].append(entry.get("split"))
    return [*by_config.values()]
|
| |
|
| |
|
def _format_structure(configs: list[SplitConfig], max_rows: int = 10) -> str:
    """Format configs and splits as a two-column markdown table.

    Args:
        configs: Config entries, each with a "name" and a list of "splits".
        max_rows: Maximum number of config/split rows to list; negative
            values are treated as 0.

    Returns:
        The "## Structure" section. When rows were left out, the table ends
        with an ellipsis row and is followed by a note giving the shown and
        total counts.
    """
    lines = [
        "## Structure (configs & splits)",
        "| Config | Split |",
        "|--------|-------|",
    ]

    pairs = [
        (cfg["name"], split_name) for cfg in configs for split_name in cfg["splits"]
    ]
    shown = pairs[: max(max_rows, 0)]
    lines.extend(f"| {name} | {split_name} |" for name, split_name in shown)

    if len(pairs) > len(shown):
        # Keep the ellipsis row at two cells so the table stays well-formed;
        # the count goes on its own line below the table.
        lines.append("| ... | ... |")
        lines.append("")
        lines.append(f"_Showing {len(shown)} of {len(pairs)} config/split rows_")

    return "\n".join(lines)
|
| |
|
| |
|
def _format_schema(info: dict, config: str) -> str:
    """Render the features found under "dataset_info" as a markdown table.

    One row is emitted per feature, pairing the column name with the type
    string produced by ``_get_type_str``.
    """
    columns = info.get("dataset_info", {}).get("features", {})
    table = [f"## Schema ({config})", "| Column | Type |", "|--------|------|"]
    table.extend(
        f"| {column} | {_get_type_str(spec)} |" for column, spec in columns.items()
    )
    return "\n".join(table)
|
| |
|
| |
|
def _get_type_str(col_info: dict) -> str:
    """Convert one feature description into a short readable type string.

    Handles the shapes a feature entry can take:
    - a dict with "dtype" and/or "_type" (e.g. Value, Image): the dtype, or
      the "_type" name when there is no dtype;
    - ClassLabel: the label names with their indices (up to 5 names) or the
      class count;
    - Sequence/List/LargeList with a nested "feature": ``Kind[inner]``;
    - a list wrapping a single feature: ``list[inner]``;
    - a plain dict of named sub-features: ``{name: type, ...}``.

    Anything unrecognised yields "unknown" rather than raising, so one odd
    column cannot discard the whole schema table.
    """
    if isinstance(col_info, (list, tuple)):
        inner = _get_type_str(col_info[0]) if col_info else "unknown"
        return f"list[{inner}]"
    if not isinstance(col_info, dict):
        return "unknown"

    kind = col_info.get("_type")
    if kind == "ClassLabel":
        # "names" may be present but null; treat that as no names.
        names = col_info.get("names") or []
        if names and len(names) <= 5:
            return f"ClassLabel ({', '.join(f'{n}={i}' for i, n in enumerate(names))})"
        return f"ClassLabel ({len(names)} classes)"

    dtype = col_info.get("dtype")
    if dtype:
        return str(dtype)

    if kind in ("Sequence", "List", "LargeList") and "feature" in col_info:
        return f"{kind}[{_get_type_str(col_info['feature'])}]"

    if kind:
        return str(kind)

    # No dtype and no _type: a nested mapping of sub-features, if every value
    # itself looks like a feature description.
    if col_info and all(isinstance(v, (dict, list)) for v in col_info.values()):
        fields = ", ".join(f"{k}: {_get_type_str(v)}" for k, v in col_info.items())
        return "{" + fields + "}"

    return "unknown"
|
| |
|
| |
|
def _format_samples(rows_data: dict, config: str, split: str, limit: int) -> str:
    """Render up to ``limit`` rows as a "## Sample Rows" markdown section.

    Each cell is stringified and cut to ``MAX_SAMPLE_VALUE_LEN`` characters
    (with a trailing "..."). The first non-None cell of a column named
    "messages" (case-insensitive) is additionally summarised through
    ``_format_messages_structure`` and appended after the rows.
    """
    out = [f"## Sample Rows ({config}/{split})"]
    first_messages = None

    selected = rows_data.get("rows", [])[:limit]
    for index, wrapper in enumerate(selected, start=1):
        cells = wrapper.get("row", {})
        out.append(f"**Row {index}:**")
        for column, cell in cells.items():
            # Remember only the first messages cell that actually has a value.
            if column.lower() == "messages" and first_messages is None:
                first_messages = cell

            text = str(cell)
            if len(text) > MAX_SAMPLE_VALUE_LEN:
                text = text[:MAX_SAMPLE_VALUE_LEN] + "..."
            out.append(f"- {column}: {text}")

    if first_messages is not None:
        summary = _format_messages_structure(first_messages)
        if summary:
            out.extend(["", summary])

    return "\n".join(out)
|
| |
|
| |
|
def _format_messages_structure(messages_data: Any) -> str | None:
    """
    Analyze and format the structure of a messages column.
    Common in chat/instruction datasets.

    Args:
        messages_data: One cell of a messages column: a list of message
            dicts, or a JSON string that decodes to such a list.

    Returns:
        A "## Messages Column Format" markdown section listing the roles
        seen, which common message keys exist, whether tool calls/results
        actually occur, and one example message; ``None`` when the cell is
        not a non-empty list (or is a string that is not valid JSON).
    """
    import json

    if isinstance(messages_data, str):
        try:
            messages_data = json.loads(messages_data)
        except json.JSONDecodeError:
            return None

    if not isinstance(messages_data, list) or not messages_data:
        return None

    # Non-dict items carry no structure to report and are ignored throughout.
    messages = [msg for msg in messages_data if isinstance(msg, dict)]

    lines = ["## Messages Column Format"]

    roles_seen: set[str] = set()
    message_keys: set[str] = set()
    has_tool_calls = False
    has_tool_results = False

    for msg in messages:
        message_keys.update(msg.keys())

        role = msg.get("role", "")
        if role:
            roles_seen.add(str(role))

        # Test the value, not the key: rows with a uniform schema carry a
        # "tool_calls" key on every message, set to None when unused.
        if msg.get("tool_calls") or msg.get("function_call"):
            has_tool_calls = True
        if role in ("tool", "function") or msg.get("tool_call_id"):
            has_tool_results = True

    lines.append(
        f"**Roles:** {', '.join(sorted(roles_seen)) if roles_seen else 'unknown'}"
    )

    common_keys = [
        "role",
        "content",
        "tool_calls",
        "tool_call_id",
        "name",
        "function_call",
    ]
    key_status = [
        f"{key} ✓" if key in message_keys else f"{key} ✗" for key in common_keys
    ]
    lines.append(f"**Message keys:** {', '.join(key_status)}")

    if has_tool_calls:
        lines.append("**Tool calls:** ✓ Present")
    if has_tool_results:
        lines.append("**Tool results:** ✓ Present")

    # Example preference: a message with tool calls, else the first assistant
    # message, else the first non-system message.
    example = None
    fallback = None
    for msg in messages:
        if msg.get("tool_calls") or msg.get("function_call"):
            example = msg
            break
        role = msg.get("role", "")
        if role == "assistant":
            if example is None:
                example = msg
        elif role != "system" and fallback is None:
            fallback = msg
    if example is None:
        example = fallback

    if example:
        lines.append("")
        lines.append("**Example message structure:**")

        # Shorten long text content so the example stays readable.
        example_clean = {}
        for key, val in example.items():
            if key == "content" and isinstance(val, str) and len(val) > 100:
                example_clean[key] = val[:100] + "..."
            else:
                example_clean[key] = val
        lines.append("```json")
        lines.append(json.dumps(example_clean, indent=2, ensure_ascii=False))
        lines.append("```")

    return "\n".join(lines)
|
| |
|
| |
|
def _format_parquet_files(data: dict, max_rows: int = 10) -> str | None:
    """Summarise the "parquet_files" listing per config/split.

    Files are grouped under a "config/split" label (missing values default
    to "default" and "train"); each group reports its file count and summed
    size in MB, where non-numeric or missing sizes count as 0. At most
    ``max_rows`` groups are listed, followed by a note when some were left
    out. Returns None when there are no files.
    """
    parquet_files = data.get("parquet_files", [])
    if not parquet_files:
        return None

    totals: dict[str, dict] = {}
    for entry in parquet_files:
        label = f"{entry.get('config', 'default')}/{entry.get('split', 'train')}"
        bucket = totals.setdefault(label, {"count": 0, "size": 0})
        raw_size = entry.get("size") or 0
        bucket["count"] += 1
        bucket["size"] += int(raw_size) if isinstance(raw_size, (int, float)) else 0

    visible = list(totals.items())[:max_rows]
    out = ["## Files (Parquet)"]
    out.extend(
        f"- {label}: {stats['count']} file(s) ({stats['size'] / (1024 * 1024):.1f} MB)"
        for label, stats in visible
    )

    if len(totals) > len(visible):
        out.append(
            f"- ... (_showing {len(visible)} of {len(totals)} parquet groups_)"
        )
    return "\n".join(out)
|
| |
|
| |
|
| |
|
# Tool specification exposed to the agent: the tool name, the description the
# model reads, and a JSON-schema style "parameters" object. Only "dataset" is
# required.
HF_INSPECT_DATASET_TOOL_SPEC = {
    "name": "hf_inspect_dataset",
    "description": (
        "Inspect a Hugging Face dataset comprehensively in one call.\n\n"
        "## What you get\n"
        "- Status check (validates dataset works without errors)\n"
        # The structure section lists config/split pairs only, so no row
        # counts are promised here.
        "- All configs and their splits\n"
        "- Column names and types (schema)\n"
        "- Sample rows to understand data format\n"
        "- Parquet file structure and sizes\n\n"
        "## CRITICAL\n"
        "**Always inspect datasets before writing training code** to understand:\n"
        "- Column names for your dataloader\n"
        "- Data types and format\n"
        "- Available splits (train/test/validation)\n\n"
        "Supports private/gated datasets when HF_TOKEN is set.\n\n"
        "## Examples\n"
        '{"dataset": "stanfordnlp/imdb"}\n'
        '{"dataset": "nyu-mll/glue", "config": "mrpc", "sample_rows": 5}\n'
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "dataset": {
                "type": "string",
                "description": "Dataset ID in 'org/name' format (e.g., 'stanfordnlp/imdb')",
            },
            "config": {
                "type": "string",
                "description": "Config/subset name. Auto-detected if not specified.",
            },
            "split": {
                "type": "string",
                "description": "Split for sample rows. Auto-detected if not specified.",
            },
            "sample_rows": {
                "type": "integer",
                "description": "Number of sample rows to show (default: 3, max: 10)",
                "default": 3,
            },
        },
        "required": ["dataset"],
    },
}
|
| |
|
| |
|
async def hf_inspect_dataset_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """Handler for agent tool router.

    Args:
        arguments: Tool-call arguments. "dataset" is required; "config",
            "split" and "sample_rows" are optional.

    Returns:
        ``(text, success)``: the formatted report and whether it is usable,
        or an error message with ``False``. This function does not raise;
        every failure is reported through the returned tuple.
    """
    dataset = arguments.get("dataset")
    if not isinstance(dataset, str) or not dataset.strip():
        return "Error inspecting dataset: 'dataset' argument is required", False

    # Tolerate a null or non-numeric sample_rows by falling back to the
    # default, then clamp to the advertised 0..10 range.
    requested = arguments.get("sample_rows")
    try:
        sample_rows = 3 if requested is None else int(requested)
    except (TypeError, ValueError):
        sample_rows = 3
    sample_rows = max(0, min(sample_rows, 10))

    try:
        result = await inspect_dataset(
            dataset=dataset.strip(),
            config=arguments.get("config"),
            split=arguments.get("split"),
            sample_rows=sample_rows,
        )
        return result["formatted"], not result.get("isError", False)
    except Exception as e:
        return f"Error inspecting dataset: {str(e)}", False
|
| |
|