# water/agent/tools/dataset_tools.py
# (Hugging Face Hub upload-page metadata removed from this header.)
"""
Dataset Inspection Tool - Comprehensive dataset analysis in one call
Combines /is-valid, /splits, /info, /first-rows, and /parquet endpoints
to provide everything needed for ML tasks in a single tool call.
"""
import asyncio
import os
from typing import Any, TypedDict
import httpx
from agent.tools.types import ToolResult
# Root of the Hugging Face datasets-server REST API (serves /is-valid,
# /splits, /info, /first-rows and /parquet endpoints).
BASE_URL = "https://datasets-server.huggingface.co"
# Truncation limit (in characters) for long sample values in the output
MAX_SAMPLE_VALUE_LEN = 150
class SplitConfig(TypedDict):
    """Typed representation of a dataset config and its splits."""

    # Config (subset) name as reported by the /splits endpoint, e.g. "default".
    name: str
    # Split names belonging to this config, e.g. ["train", "test"].
    splits: list[str]
def _get_headers() -> dict:
"""Get auth headers for private/gated datasets"""
token = os.environ.get("HF_TOKEN")
if token:
return {"Authorization": f"Bearer {token}"}
return {}
async def inspect_dataset(
    dataset: str,
    config: str | None = None,
    split: str | None = None,
    sample_rows: int = 3,
) -> ToolResult:
    """
    Get comprehensive dataset info in one call.

    Issues the independent datasets-server API calls in parallel:
    phase 1 fetches /is-valid, /splits and /parquet; phase 2 fetches
    /info and /first-rows once a config/split has been resolved.

    Args:
        dataset: Dataset ID in 'org/name' format.
        config: Config/subset name; auto-detected from /splits when omitted.
        split: Split used for sample rows; auto-detected when omitted.
        sample_rows: Number of sample rows to include in the output.

    Returns:
        ToolResult whose ``formatted`` field is a markdown report;
        ``isError`` is True only when no section could be produced at all.
    """
    headers = _get_headers()
    output_parts = []
    errors = []
    async with httpx.AsyncClient(timeout=15, headers=headers) as client:
        # Phase 1: Parallel calls for structure info (no dependencies)
        is_valid_task = client.get(f"{BASE_URL}/is-valid", params={"dataset": dataset})
        splits_task = client.get(f"{BASE_URL}/splits", params={"dataset": dataset})
        parquet_task = client.get(f"{BASE_URL}/parquet", params={"dataset": dataset})
        results = await asyncio.gather(
            is_valid_task,
            splits_task,
            parquet_task,
            return_exceptions=True,
        )
        # Process is-valid
        if not isinstance(results[0], Exception):
            try:
                output_parts.append(_format_status(results[0].json()))
            except Exception as e:
                errors.append(f"is-valid: {e}")
        # Process splits and auto-detect config/split
        configs = []
        if not isinstance(results[1], Exception):
            try:
                splits_data = results[1].json()
                configs = _extract_configs(splits_data)
                if not config:
                    config = configs[0]["name"] if configs else "default"
                if not split:
                    # Auto-detect the split from the *selected* config: the
                    # caller may have passed `config` explicitly, so taking
                    # configs[0]'s first split unconditionally could return a
                    # split that does not exist for the requested config.
                    selected = next(
                        (c for c in configs if c["name"] == config), None
                    )
                    if selected and selected["splits"]:
                        split = selected["splits"][0]
                    else:
                        split = "train"
                output_parts.append(_format_structure(configs))
            except Exception as e:
                errors.append(f"splits: {e}")
        # Fallbacks for when the /splits call failed entirely
        if not config:
            config = "default"
        if not split:
            split = "train"
        # Process parquet (section is appended at the end of the report)
        parquet_section = None
        if not isinstance(results[2], Exception):
            try:
                parquet_section = _format_parquet_files(results[2].json())
            except Exception:
                pass  # Silently skip if no parquet
        # Phase 2: Parallel calls for content (depend on config/split)
        info_task = client.get(
            f"{BASE_URL}/info", params={"dataset": dataset, "config": config}
        )
        rows_task = client.get(
            f"{BASE_URL}/first-rows",
            params={"dataset": dataset, "config": config, "split": split},
            timeout=30,  # /first-rows can be slow for large datasets
        )
        content_results = await asyncio.gather(
            info_task,
            rows_task,
            return_exceptions=True,
        )
        # Process info (schema)
        if not isinstance(content_results[0], Exception):
            try:
                output_parts.append(_format_schema(content_results[0].json(), config))
            except Exception as e:
                errors.append(f"info: {e}")
        # Process sample rows
        if not isinstance(content_results[1], Exception):
            try:
                output_parts.append(
                    _format_samples(
                        content_results[1].json(), config, split, sample_rows
                    )
                )
            except Exception as e:
                errors.append(f"rows: {e}")
    # Add parquet section at the end if available
    if parquet_section:
        output_parts.append(parquet_section)
    # Combine output; failed sections are reported as warnings, not errors
    formatted = f"# {dataset}\n\n" + "\n\n".join(output_parts)
    if errors:
        formatted += f"\n\n**Warnings:** {'; '.join(errors)}"
    return {
        "formatted": formatted,
        "totalResults": 1,
        "resultsShared": 1,
        "isError": len(output_parts) == 0,
    }
def _format_status(data: dict) -> str:
"""Format /is-valid response as status line"""
available = [
k
for k in ["viewer", "preview", "search", "filter", "statistics"]
if data.get(k)
]
if available:
return f"## Status\n✓ Valid ({', '.join(available)})"
return "## Status\n✗ Dataset may have issues"
def _extract_configs(splits_data: dict) -> list[SplitConfig]:
"""Group splits by config"""
configs: dict[str, SplitConfig] = {}
for s in splits_data.get("splits", []):
cfg = s.get("config", "default")
if cfg not in configs:
configs[cfg] = {"name": cfg, "splits": []}
configs[cfg]["splits"].append(s.get("split"))
return list(configs.values())
def _format_structure(configs: list[SplitConfig], max_rows: int = 10) -> str:
"""Format configs and splits as a markdown table."""
lines = [
"## Structure (configs & splits)",
"| Config | Split |",
"|--------|-------|",
]
total_splits = sum(len(cfg["splits"]) for cfg in configs)
added_rows = 0
for cfg in configs:
for split_name in cfg["splits"]:
if added_rows >= max_rows:
break
lines.append(f"| {cfg['name']} | {split_name} |")
added_rows += 1
if added_rows >= max_rows:
break
if total_splits > added_rows:
lines.append(
f"| ... | ... | (_showing {added_rows} of {total_splits} config/split rows_) |"
)
return "\n".join(lines)
def _format_schema(info: dict, config: str) -> str:
"""Extract features and format as table"""
features = info.get("dataset_info", {}).get("features", {})
lines = [f"## Schema ({config})", "| Column | Type |", "|--------|------|"]
for col_name, col_info in features.items():
col_type = _get_type_str(col_info)
lines.append(f"| {col_name} | {col_type} |")
return "\n".join(lines)
def _get_type_str(col_info: dict) -> str:
"""Convert feature info to readable type string"""
dtype = col_info.get("dtype") or col_info.get("_type", "unknown")
if col_info.get("_type") == "ClassLabel":
names = col_info.get("names", [])
if names and len(names) <= 5:
return f"ClassLabel ({', '.join(f'{n}={i}' for i, n in enumerate(names))})"
return f"ClassLabel ({len(names)} classes)"
return str(dtype)
def _format_samples(rows_data: dict, config: str, split: str, limit: int) -> str:
    """Render up to *limit* sample rows as markdown, truncating long values.

    When a "messages" column is present (chat/instruction datasets), the
    first encountered value is additionally analyzed for structure.
    """
    out = [f"## Sample Rows ({config}/{split})"]
    first_messages = None
    for idx, wrapper in enumerate(rows_data.get("rows", [])[:limit], start=1):
        out.append(f"**Row {idx}:**")
        for col, value in wrapper.get("row", {}).items():
            # Remember the first messages column for format analysis below.
            if first_messages is None and col.lower() == "messages":
                first_messages = value
            text = str(value)
            if len(text) > MAX_SAMPLE_VALUE_LEN:
                text = text[:MAX_SAMPLE_VALUE_LEN] + "..."
            out.append(f"- {col}: {text}")
    if first_messages is not None:
        analysis = _format_messages_structure(first_messages)
        if analysis:
            out.append("")
            out.append(analysis)
    return "\n".join(out)
def _format_messages_structure(messages_data: Any) -> str | None:
"""
Analyze and format the structure of a messages column.
Common in chat/instruction datasets.
"""
import json
# Parse if string
if isinstance(messages_data, str):
try:
messages_data = json.loads(messages_data)
except json.JSONDecodeError:
return None
if not isinstance(messages_data, list) or not messages_data:
return None
lines = ["## Messages Column Format"]
# Analyze message structure
roles_seen = set()
has_tool_calls = False
has_tool_results = False
message_keys = set()
for msg in messages_data:
if not isinstance(msg, dict):
continue
message_keys.update(msg.keys())
role = msg.get("role", "")
if role:
roles_seen.add(role)
if "tool_calls" in msg or "function_call" in msg:
has_tool_calls = True
if role in ("tool", "function") or msg.get("tool_call_id"):
has_tool_results = True
# Format the analysis
lines.append(
f"**Roles:** {', '.join(sorted(roles_seen)) if roles_seen else 'unknown'}"
)
# Show common message keys with presence indicators
common_keys = [
"role",
"content",
"tool_calls",
"tool_call_id",
"name",
"function_call",
]
key_status = []
for key in common_keys:
if key in message_keys:
key_status.append(f"{key} ✓")
else:
key_status.append(f"{key} ✗")
lines.append(f"**Message keys:** {', '.join(key_status)}")
if has_tool_calls:
lines.append("**Tool calls:** ✓ Present")
if has_tool_results:
lines.append("**Tool results:** ✓ Present")
# Show example message structure
# Priority: 1) message with tool_calls, 2) first assistant message, 3) first non-system message
example = None
fallback = None
for msg in messages_data:
if not isinstance(msg, dict):
continue
role = msg.get("role", "")
# Check for actual tool_calls/function_call values (not None)
if msg.get("tool_calls") or msg.get("function_call"):
example = msg
break
if role == "assistant" and example is None:
example = msg
elif role != "system" and fallback is None:
fallback = msg
if example is None:
example = fallback
if example:
lines.append("")
lines.append("**Example message structure:**")
# Build a copy with truncated content but keep all keys
example_clean = {}
for key, val in example.items():
if key == "content" and isinstance(val, str) and len(val) > 100:
example_clean[key] = val[:100] + "..."
else:
example_clean[key] = val
lines.append("```json")
lines.append(json.dumps(example_clean, indent=2, ensure_ascii=False))
lines.append("```")
return "\n".join(lines)
def _format_parquet_files(data: dict, max_rows: int = 10) -> str | None:
"""Format parquet file info, return None if no files."""
files = data.get("parquet_files", [])
if not files:
return None
# Group by config/split
groups: dict[str, dict] = {}
for f in files:
key = f"{f.get('config', 'default')}/{f.get('split', 'train')}"
if key not in groups:
groups[key] = {"count": 0, "size": 0}
size = f.get("size") or 0
if not isinstance(size, (int, float)):
size = 0
groups[key]["count"] += 1
groups[key]["size"] += int(size)
lines = ["## Files (Parquet)"]
items = list(groups.items())
total_groups = len(items)
shown = 0
for key, info in items[:max_rows]:
size_mb = info["size"] / (1024 * 1024)
lines.append(f"- {key}: {info['count']} file(s) ({size_mb:.1f} MB)")
shown += 1
if total_groups > shown:
lines.append(f"- ... (_showing {shown} of {total_groups} parquet groups_)")
return "\n".join(lines)
# Tool specification: JSON-schema-style definition consumed by the agent tool
# router. Keep the "parameters" schema in sync with inspect_dataset()'s
# signature (dataset is required; config/split/sample_rows are optional).
HF_INSPECT_DATASET_TOOL_SPEC = {
    "name": "hf_inspect_dataset",
    "description": (
        "Inspect a Hugging Face dataset comprehensively in one call.\n\n"
        "## What you get\n"
        "- Status check (validates dataset works without errors)\n"
        "- All configs and splits (row counts/shares may be '?' when metadata is missing)\n"
        "- Column names and types (schema)\n"
        "- Sample rows to understand data format\n"
        "- Parquet file structure and sizes\n\n"
        "## CRITICAL\n"
        "**Always inspect datasets before writing training code** to understand:\n"
        "- Column names for your dataloader\n"
        "- Data types and format\n"
        "- Available splits (train/test/validation)\n\n"
        "Supports private/gated datasets when HF_TOKEN is set.\n\n"
        "## Examples\n"
        '{"dataset": "stanfordnlp/imdb"}\n'
        '{"dataset": "nyu-mll/glue", "config": "mrpc", "sample_rows": 5}\n'
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "dataset": {
                "type": "string",
                "description": "Dataset ID in 'org/name' format (e.g., 'stanfordnlp/imdb')",
            },
            "config": {
                "type": "string",
                "description": "Config/subset name. Auto-detected if not specified.",
            },
            "split": {
                "type": "string",
                "description": "Split for sample rows. Auto-detected if not specified.",
            },
            "sample_rows": {
                "type": "integer",
                "description": "Number of sample rows to show (default: 3, max: 10)",
                "default": 3,
            },
        },
        "required": ["dataset"],
    },
}
async def hf_inspect_dataset_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """Handler for the agent tool router.

    Args:
        arguments: Parsed tool-call arguments matching
            HF_INSPECT_DATASET_TOOL_SPEC["parameters"].

    Returns:
        (formatted_markdown, success) tuple. Exceptions are reported in-band
        as an error string so the router can relay them to the model.
    """
    try:
        # Clamp sample_rows to [1, 10]: the spec documents a max of 10, and a
        # zero/negative value would previously slice rows[:n] into an empty or
        # wrongly-truncated sample list.
        requested = arguments.get("sample_rows", 3)
        result = await inspect_dataset(
            dataset=arguments["dataset"],
            config=arguments.get("config"),
            split=arguments.get("split"),
            sample_rows=max(1, min(requested, 10)),
        )
        return result["formatted"], not result.get("isError", False)
    except Exception as e:
        return f"Error inspecting dataset: {str(e)}", False