# Hugging Face Hub page header (not part of the module), kept as comments:
#   WolfDavid -- "Upload folder using huggingface_hub" -- commit 75418e4 (verified)
"""
Foundation classes for the tool system.

BaseTool
    Abstract base class that every tool extends. Provides input validation
    against a JSON-Schema ``parameters`` dict, schema serialization for the
    LLM prompt, and a standard ``ToolResult`` return type.

ToolRegistry
    Central catalogue of available tools. The orchestrator calls
    ``registry.execute(name, params)`` and the registry handles lookup,
    validation, timeout enforcement, and structured result wrapping.
"""
from __future__ import annotations

import asyncio
import copy
import logging
import time
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
# ====================================================================== #
# ToolResult
# ====================================================================== #
@dataclass
class ToolResult:
    """Standard envelope describing the outcome of one tool execution.

    Only ``success`` is mandatory; every other field has a default so that
    callers can construct a result with just the information they have.
    """

    # True when the tool call is considered successful.
    success: bool
    # Result payload; ``None`` when there is nothing to return.
    data: dict[str, Any] | None = None
    # Error description; ``None`` when no error is reported.
    error: str | None = None
    # Duration of the execution, in milliseconds.
    execution_time_ms: float = 0.0
    # Cache flag; nothing in this module sets it -- presumably raised by
    # tools or callers that serve a stored result (confirm against callers).
    cached: bool = False
# ====================================================================== #
# BaseTool
# ====================================================================== #
class BaseTool(ABC):
    """Abstract base class for agent tools.

    Subclasses set the ``name``, ``description`` and ``parameters`` class
    attributes and implement :meth:`execute`.
    """

    name: str = ""
    description: str = ""
    # JSON Schema (type "object").  This default dict is shared by every
    # subclass that does not override it, so treat it as read-only.
    parameters: dict = {}

    @abstractmethod
    async def execute(self, **kwargs: Any) -> ToolResult:
        """Run the tool with validated parameters and return a ToolResult."""
        ...

    # ------------------------------------------------------------------ #
    # Validation
    # ------------------------------------------------------------------ #
    def validate_input(self, params: dict) -> dict:
        """Validate *params* against the ``self.parameters`` JSON Schema.

        This is a lightweight validator that checks required fields and basic
        types without pulling in a full JSON-Schema library.

        Parameters
        ----------
        params:
            Raw parameter mapping.  ``None`` is treated as "no parameters".

        Returns
        -------
        dict
            A new dict holding the recognised parameters, with schema
            defaults filled in for keys that were not supplied.  Keys that
            are not declared in the schema are dropped (and logged at debug
            level).

        Raises
        ------
        ValueError
            If a required parameter is missing.
        TypeError
            If *params* is not a mapping, or a value does not match the
            ``type`` declared for its key.
        """
        if params is None:
            params = {}
        elif not isinstance(params, Mapping):
            # Fail with an exception type the registry reports as a
            # validation error, instead of an AttributeError on ``.items()``.
            raise TypeError(
                f"Parameters must be an object, got {type(params).__name__}"
            )

        schema_props = self.parameters.get("properties", {})

        # Walk ``required`` in schema order (not via a set) so that the key
        # named in the error is the same on every run.
        for key in self.parameters.get("required", []):
            if key not in params:
                raise ValueError(f"Missing required parameter: {key}")

        cleaned: dict[str, Any] = {}
        for key, value in params.items():
            prop_schema = schema_props.get(key)
            if prop_schema is None:
                # Extra keys are tolerated but logged.
                logger.debug(
                    "Ignoring unknown parameter '%s' for tool '%s'", key, self.name
                )
                continue
            expected_type = prop_schema.get("type")
            if expected_type and not _check_type(value, expected_type):
                raise TypeError(
                    f"Parameter '{key}' expected type '{expected_type}', "
                    f"got {type(value).__name__}"
                )
            cleaned[key] = value

        # Fill defaults.  Deep-copy them so a tool that mutates a list/dict
        # default cannot alter the schema seen by later calls.
        for key, prop_schema in schema_props.items():
            if key not in cleaned and "default" in prop_schema:
                cleaned[key] = copy.deepcopy(prop_schema["default"])
        return cleaned

    # ------------------------------------------------------------------ #
    # Schema serialization
    # ------------------------------------------------------------------ #
    def to_schema(self) -> dict:
        """Format for inclusion in the LLM tool-calling prompt.

        The returned dict references ``self.parameters`` directly (no copy).
        """
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }
# ====================================================================== #
# ToolRegistry
# ====================================================================== #
class ToolRegistry:
    """Central catalogue of available tools with execution support."""

    def __init__(self) -> None:
        # Maps tool name -> tool instance.
        self._tools: dict[str, BaseTool] = {}

    def register(self, tool: BaseTool) -> None:
        """Store *tool* under ``tool.name``.

        An existing tool with the same name is replaced; a warning is logged
        when that happens.
        """
        if tool.name in self._tools:
            logger.warning("Overwriting existing tool '%s'", tool.name)
        self._tools[tool.name] = tool
        logger.debug("Registered tool: %s", tool.name)

    def get(self, name: str) -> BaseTool | None:
        """Return the tool registered as *name*, or ``None`` if absent."""
        return self._tools.get(name)

    def list_names(self) -> list[str]:
        """Return the registered tool names."""
        return list(self._tools)

    def get_all_schemas(self) -> list[dict]:
        """Return ``to_schema()`` of every registered tool."""
        return [tool.to_schema() for tool in self._tools.values()]

    async def execute(
        self,
        name: str,
        params: dict,
        timeout: float = 10.0,
    ) -> ToolResult:
        """Look up, validate, and execute a tool by name.

        Failures (unknown tool, validation error, timeout, an exception
        raised by the tool, or a tool returning something other than a
        ``ToolResult``) are reported as ``ToolResult(success=False, ...)``
        rather than raised.

        Parameters
        ----------
        name:
            Registered tool name.
        params:
            Raw parameter dict (will be validated).
        timeout:
            Maximum seconds to allow for execution.

        Returns
        -------
        ToolResult
            The tool's own result with ``execution_time_ms`` overwritten by
            the time measured here, or a failure result describing what went
            wrong.
        """
        tool = self._tools.get(name)
        if tool is None:
            return ToolResult(
                success=False,
                error=f"Unknown tool: '{name}'. Available: {self.list_names()}",
            )

        # Validate inputs.
        try:
            validated = tool.validate_input(params)
        except (ValueError, TypeError) as exc:
            return ToolResult(success=False, error=f"Validation error: {exc}")

        # Execute with timeout.
        start = time.perf_counter()
        try:
            result = await asyncio.wait_for(
                tool.execute(**validated), timeout=timeout
            )
        except asyncio.TimeoutError:
            elapsed = (time.perf_counter() - start) * 1000
            return ToolResult(
                success=False,
                error=f"Tool '{name}' timed out after {timeout:.1f}s",
                execution_time_ms=elapsed,
            )
        except Exception as exc:
            elapsed = (time.perf_counter() - start) * 1000
            logger.exception("Tool '%s' raised an exception", name)
            return ToolResult(
                success=False,
                error=f"Tool execution error: {exc}",
                execution_time_ms=elapsed,
            )

        elapsed = (time.perf_counter() - start) * 1000
        if not isinstance(result, ToolResult):
            # A tool that breaks the ``execute`` contract (e.g. returns None
            # or a bare dict) must not crash the caller with an
            # AttributeError; report it like any other tool failure.
            logger.error(
                "Tool '%s' returned %s instead of ToolResult",
                name,
                type(result).__name__,
            )
            return ToolResult(
                success=False,
                error=(
                    f"Tool '{name}' returned {type(result).__name__}, "
                    "expected ToolResult"
                ),
                execution_time_ms=elapsed,
            )
        result.execution_time_ms = elapsed
        return result
# ====================================================================== #
# Helpers
# ====================================================================== #
def _check_type(value: Any, json_type: str) -> bool:
"""Loose JSON-Schema type check."""
type_map: dict[str, tuple[type, ...]] = {
"string": (str,),
"integer": (int,),
"number": (int, float),
"boolean": (bool,),
"array": (list, tuple),
"object": (dict,),
}
allowed = type_map.get(json_type)
if allowed is None:
return True # unknown type -- accept anything
return isinstance(value, allowed)