""" Foundation classes for the tool system. BaseTool Abstract base class that every tool extends. Provides input validation against a JSON-Schema ``parameters`` dict, schema serialization for the LLM prompt, and a standard ``ToolResult`` return type. ToolRegistry Central catalogue of available tools. The orchestrator calls ``registry.execute(name, params)`` and the registry handles lookup, validation, timeout enforcement, and structured result wrapping. """ from __future__ import annotations import asyncio import logging import time from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any logger = logging.getLogger(__name__) # ====================================================================== # # ToolResult # ====================================================================== # @dataclass class ToolResult: """Uniform envelope for tool execution outcomes.""" success: bool data: dict[str, Any] | None = None error: str | None = None execution_time_ms: float = 0.0 cached: bool = False # ====================================================================== # # BaseTool # ====================================================================== # class BaseTool(ABC): """Abstract base class for agent tools.""" name: str = "" description: str = "" parameters: dict = {} # JSON Schema (type "object") @abstractmethod async def execute(self, **kwargs: Any) -> ToolResult: """Run the tool with validated parameters and return a ToolResult.""" ... # ------------------------------------------------------------------ # # Validation # ------------------------------------------------------------------ # def validate_input(self, params: dict) -> dict: """Validate *params* against ``self.parameters`` JSON Schema. This is a lightweight validator that checks required fields and basic types without pulling in a full JSON-Schema library. Returns a cleaned copy of *params* with defaults filled in. """ schema_props = self.parameters.get("properties", {}) required = set(self.parameters.get("required", [])) cleaned: dict[str, Any] = {} # Check required fields. for key in required: if key not in params: raise ValueError(f"Missing required parameter: {key}") # Validate and copy. for key, value in params.items(): prop_schema = schema_props.get(key) if prop_schema is None: # Extra keys are tolerated but logged. logger.debug("Ignoring unknown parameter '%s' for tool '%s'", key, self.name) continue expected_type = prop_schema.get("type") if expected_type and not _check_type(value, expected_type): raise TypeError( f"Parameter '{key}' expected type '{expected_type}', " f"got {type(value).__name__}" ) cleaned[key] = value # Fill defaults. for key, prop_schema in schema_props.items(): if key not in cleaned and "default" in prop_schema: cleaned[key] = prop_schema["default"] return cleaned # ------------------------------------------------------------------ # # Schema serialization # ------------------------------------------------------------------ # def to_schema(self) -> dict: """Format for inclusion in the LLM tool-calling prompt.""" return { "type": "function", "function": { "name": self.name, "description": self.description, "parameters": self.parameters, }, } # ====================================================================== # # ToolRegistry # ====================================================================== # class ToolRegistry: """Central catalogue of available tools with execution support.""" def __init__(self) -> None: self._tools: dict[str, BaseTool] = {} def register(self, tool: BaseTool) -> None: if tool.name in self._tools: logger.warning("Overwriting existing tool '%s'", tool.name) self._tools[tool.name] = tool logger.debug("Registered tool: %s", tool.name) def get(self, name: str) -> BaseTool | None: return self._tools.get(name) def list_names(self) -> list[str]: return list(self._tools.keys()) def get_all_schemas(self) -> list[dict]: return [t.to_schema() for t in self._tools.values()] async def execute( self, name: str, params: dict, timeout: float = 10.0, ) -> ToolResult: """Look up, validate, and execute a tool by name. Parameters ---------- name: Registered tool name. params: Raw parameter dict (will be validated). timeout: Maximum seconds to allow for execution. Returns ------- ToolResult """ tool = self._tools.get(name) if tool is None: return ToolResult( success=False, error=f"Unknown tool: '{name}'. Available: {self.list_names()}", ) # Validate inputs. try: validated = tool.validate_input(params) except (ValueError, TypeError) as exc: return ToolResult(success=False, error=f"Validation error: {exc}") # Execute with timeout. start = time.perf_counter() try: result = await asyncio.wait_for( tool.execute(**validated), timeout=timeout ) except asyncio.TimeoutError: elapsed = (time.perf_counter() - start) * 1000 return ToolResult( success=False, error=f"Tool '{name}' timed out after {timeout:.1f}s", execution_time_ms=elapsed, ) except Exception as exc: elapsed = (time.perf_counter() - start) * 1000 logger.exception("Tool '%s' raised an exception", name) return ToolResult( success=False, error=f"Tool execution error: {exc}", execution_time_ms=elapsed, ) result.execution_time_ms = (time.perf_counter() - start) * 1000 return result # ====================================================================== # # Helpers # ====================================================================== # def _check_type(value: Any, json_type: str) -> bool: """Loose JSON-Schema type check.""" type_map: dict[str, tuple[type, ...]] = { "string": (str,), "integer": (int,), "number": (int, float), "boolean": (bool,), "array": (list, tuple), "object": (dict,), } allowed = type_map.get(json_type) if allowed is None: return True # unknown type -- accept anything return isinstance(value, allowed)