""" Dataset Inspection Tool - Comprehensive dataset analysis in one call Combines /is-valid, /splits, /info, /first-rows, and /parquet endpoints to provide everything needed for ML tasks in a single tool call. """ import asyncio import os from typing import Any, TypedDict import httpx from agent.tools.types import ToolResult BASE_URL = "https://datasets-server.huggingface.co" # Truncation limit for long sample values in the output MAX_SAMPLE_VALUE_LEN = 150 class SplitConfig(TypedDict): """Typed representation of a dataset config and its splits.""" name: str splits: list[str] def _get_headers() -> dict: """Get auth headers for private/gated datasets""" token = os.environ.get("HF_TOKEN") if token: return {"Authorization": f"Bearer {token}"} return {} async def inspect_dataset( dataset: str, config: str | None = None, split: str | None = None, sample_rows: int = 3, ) -> ToolResult: """ Get comprehensive dataset info in one call. All API calls made in parallel for speed. """ headers = _get_headers() output_parts = [] errors = [] async with httpx.AsyncClient(timeout=15, headers=headers) as client: # Phase 1: Parallel calls for structure info (no dependencies) is_valid_task = client.get(f"{BASE_URL}/is-valid", params={"dataset": dataset}) splits_task = client.get(f"{BASE_URL}/splits", params={"dataset": dataset}) parquet_task = client.get(f"{BASE_URL}/parquet", params={"dataset": dataset}) results = await asyncio.gather( is_valid_task, splits_task, parquet_task, return_exceptions=True, ) # Process is-valid if not isinstance(results[0], Exception): try: output_parts.append(_format_status(results[0].json())) except Exception as e: errors.append(f"is-valid: {e}") # Process splits and auto-detect config/split configs = [] if not isinstance(results[1], Exception): try: splits_data = results[1].json() configs = _extract_configs(splits_data) if not config: config = configs[0]["name"] if configs else "default" if not split: split = configs[0]["splits"][0] if configs else "train" output_parts.append(_format_structure(configs)) except Exception as e: errors.append(f"splits: {e}") if not config: config = "default" if not split: split = "train" # Process parquet (will be added at the end) parquet_section = None if not isinstance(results[2], Exception): try: parquet_section = _format_parquet_files(results[2].json()) except Exception: pass # Silently skip if no parquet # Phase 2: Parallel calls for content (depend on config/split) info_task = client.get( f"{BASE_URL}/info", params={"dataset": dataset, "config": config} ) rows_task = client.get( f"{BASE_URL}/first-rows", params={"dataset": dataset, "config": config, "split": split}, timeout=30, ) content_results = await asyncio.gather( info_task, rows_task, return_exceptions=True, ) # Process info (schema) if not isinstance(content_results[0], Exception): try: output_parts.append(_format_schema(content_results[0].json(), config)) except Exception as e: errors.append(f"info: {e}") # Process sample rows if not isinstance(content_results[1], Exception): try: output_parts.append( _format_samples( content_results[1].json(), config, split, sample_rows ) ) except Exception as e: errors.append(f"rows: {e}") # Add parquet section at the end if available if parquet_section: output_parts.append(parquet_section) # Combine output formatted = f"# {dataset}\n\n" + "\n\n".join(output_parts) if errors: formatted += f"\n\n**Warnings:** {'; '.join(errors)}" return { "formatted": formatted, "totalResults": 1, "resultsShared": 1, "isError": len(output_parts) == 0, } def _format_status(data: dict) -> str: """Format /is-valid response as status line""" available = [ k for k in ["viewer", "preview", "search", "filter", "statistics"] if data.get(k) ] if available: return f"## Status\n✓ Valid ({', '.join(available)})" return "## Status\n✗ Dataset may have issues" def _extract_configs(splits_data: dict) -> list[SplitConfig]: """Group splits by config""" configs: dict[str, SplitConfig] = {} for s in splits_data.get("splits", []): cfg = s.get("config", "default") if cfg not in configs: configs[cfg] = {"name": cfg, "splits": []} configs[cfg]["splits"].append(s.get("split")) return list(configs.values()) def _format_structure(configs: list[SplitConfig], max_rows: int = 10) -> str: """Format configs and splits as a markdown table.""" lines = [ "## Structure (configs & splits)", "| Config | Split |", "|--------|-------|", ] total_splits = sum(len(cfg["splits"]) for cfg in configs) added_rows = 0 for cfg in configs: for split_name in cfg["splits"]: if added_rows >= max_rows: break lines.append(f"| {cfg['name']} | {split_name} |") added_rows += 1 if added_rows >= max_rows: break if total_splits > added_rows: lines.append( f"| ... | ... | (_showing {added_rows} of {total_splits} config/split rows_) |" ) return "\n".join(lines) def _format_schema(info: dict, config: str) -> str: """Extract features and format as table""" features = info.get("dataset_info", {}).get("features", {}) lines = [f"## Schema ({config})", "| Column | Type |", "|--------|------|"] for col_name, col_info in features.items(): col_type = _get_type_str(col_info) lines.append(f"| {col_name} | {col_type} |") return "\n".join(lines) def _get_type_str(col_info: dict) -> str: """Convert feature info to readable type string""" dtype = col_info.get("dtype") or col_info.get("_type", "unknown") if col_info.get("_type") == "ClassLabel": names = col_info.get("names", []) if names and len(names) <= 5: return f"ClassLabel ({', '.join(f'{n}={i}' for i, n in enumerate(names))})" return f"ClassLabel ({len(names)} classes)" return str(dtype) def _format_samples(rows_data: dict, config: str, split: str, limit: int) -> str: """Format sample rows, truncate long values""" rows = rows_data.get("rows", [])[:limit] lines = [f"## Sample Rows ({config}/{split})"] messages_col_data = None for i, row_wrapper in enumerate(rows, 1): row = row_wrapper.get("row", {}) lines.append(f"**Row {i}:**") for key, val in row.items(): # Check for messages column and capture first one for format analysis if key.lower() == "messages" and messages_col_data is None: messages_col_data = val val_str = str(val) if len(val_str) > MAX_SAMPLE_VALUE_LEN: val_str = val_str[:MAX_SAMPLE_VALUE_LEN] + "..." lines.append(f"- {key}: {val_str}") # If we found a messages column, add format analysis if messages_col_data is not None: messages_format = _format_messages_structure(messages_col_data) if messages_format: lines.append("") lines.append(messages_format) return "\n".join(lines) def _format_messages_structure(messages_data: Any) -> str | None: """ Analyze and format the structure of a messages column. Common in chat/instruction datasets. """ import json # Parse if string if isinstance(messages_data, str): try: messages_data = json.loads(messages_data) except json.JSONDecodeError: return None if not isinstance(messages_data, list) or not messages_data: return None lines = ["## Messages Column Format"] # Analyze message structure roles_seen = set() has_tool_calls = False has_tool_results = False message_keys = set() for msg in messages_data: if not isinstance(msg, dict): continue message_keys.update(msg.keys()) role = msg.get("role", "") if role: roles_seen.add(role) if "tool_calls" in msg or "function_call" in msg: has_tool_calls = True if role in ("tool", "function") or msg.get("tool_call_id"): has_tool_results = True # Format the analysis lines.append( f"**Roles:** {', '.join(sorted(roles_seen)) if roles_seen else 'unknown'}" ) # Show common message keys with presence indicators common_keys = [ "role", "content", "tool_calls", "tool_call_id", "name", "function_call", ] key_status = [] for key in common_keys: if key in message_keys: key_status.append(f"{key} ✓") else: key_status.append(f"{key} ✗") lines.append(f"**Message keys:** {', '.join(key_status)}") if has_tool_calls: lines.append("**Tool calls:** ✓ Present") if has_tool_results: lines.append("**Tool results:** ✓ Present") # Show example message structure # Priority: 1) message with tool_calls, 2) first assistant message, 3) first non-system message example = None fallback = None for msg in messages_data: if not isinstance(msg, dict): continue role = msg.get("role", "") # Check for actual tool_calls/function_call values (not None) if msg.get("tool_calls") or msg.get("function_call"): example = msg break if role == "assistant" and example is None: example = msg elif role != "system" and fallback is None: fallback = msg if example is None: example = fallback if example: lines.append("") lines.append("**Example message structure:**") # Build a copy with truncated content but keep all keys example_clean = {} for key, val in example.items(): if key == "content" and isinstance(val, str) and len(val) > 100: example_clean[key] = val[:100] + "..." else: example_clean[key] = val lines.append("```json") lines.append(json.dumps(example_clean, indent=2, ensure_ascii=False)) lines.append("```") return "\n".join(lines) def _format_parquet_files(data: dict, max_rows: int = 10) -> str | None: """Format parquet file info, return None if no files.""" files = data.get("parquet_files", []) if not files: return None # Group by config/split groups: dict[str, dict] = {} for f in files: key = f"{f.get('config', 'default')}/{f.get('split', 'train')}" if key not in groups: groups[key] = {"count": 0, "size": 0} size = f.get("size") or 0 if not isinstance(size, (int, float)): size = 0 groups[key]["count"] += 1 groups[key]["size"] += int(size) lines = ["## Files (Parquet)"] items = list(groups.items()) total_groups = len(items) shown = 0 for key, info in items[:max_rows]: size_mb = info["size"] / (1024 * 1024) lines.append(f"- {key}: {info['count']} file(s) ({size_mb:.1f} MB)") shown += 1 if total_groups > shown: lines.append(f"- ... (_showing {shown} of {total_groups} parquet groups_)") return "\n".join(lines) # Tool specification HF_INSPECT_DATASET_TOOL_SPEC = { "name": "hf_inspect_dataset", "description": ( "Inspect a Hugging Face dataset comprehensively in one call.\n\n" "## What you get\n" "- Status check (validates dataset works without errors)\n" "- All configs and splits (row counts/shares may be '?' when metadata is missing)\n" "- Column names and types (schema)\n" "- Sample rows to understand data format\n" "- Parquet file structure and sizes\n\n" "## CRITICAL\n" "**Always inspect datasets before writing training code** to understand:\n" "- Column names for your dataloader\n" "- Data types and format\n" "- Available splits (train/test/validation)\n\n" "Supports private/gated datasets when HF_TOKEN is set.\n\n" "## Examples\n" '{"dataset": "stanfordnlp/imdb"}\n' '{"dataset": "nyu-mll/glue", "config": "mrpc", "sample_rows": 5}\n' ), "parameters": { "type": "object", "properties": { "dataset": { "type": "string", "description": "Dataset ID in 'org/name' format (e.g., 'stanfordnlp/imdb')", }, "config": { "type": "string", "description": "Config/subset name. Auto-detected if not specified.", }, "split": { "type": "string", "description": "Split for sample rows. Auto-detected if not specified.", }, "sample_rows": { "type": "integer", "description": "Number of sample rows to show (default: 3, max: 10)", "default": 3, }, }, "required": ["dataset"], }, } async def hf_inspect_dataset_handler(arguments: dict[str, Any]) -> tuple[str, bool]: """Handler for agent tool router""" try: result = await inspect_dataset( dataset=arguments["dataset"], config=arguments.get("config"), split=arguments.get("split"), sample_rows=min(arguments.get("sample_rows", 3), 10), ) return result["formatted"], not result.get("isError", False) except Exception as e: return f"Error inspecting dataset: {str(e)}", False