Spaces:
Sleeping
Sleeping
| """ | |
| Foundation classes for the tool system. | |
| BaseTool | |
| Abstract base class that every tool extends. Provides input validation | |
| against a JSON-Schema ``parameters`` dict, schema serialization for the | |
| LLM prompt, and a standard ``ToolResult`` return type. | |
| ToolRegistry | |
| Central catalogue of available tools. The orchestrator calls | |
| ``registry.execute(name, params)`` and the registry handles lookup, | |
| validation, timeout enforcement, and structured result wrapping. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| import time | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
# Module-scoped logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(name=__name__)
| # ====================================================================== # | |
| # ToolResult | |
| # ====================================================================== # | |
@dataclass
class ToolResult:
    """Uniform envelope for tool execution outcomes.

    Declared as a (mutable) dataclass so it can be built with keyword
    arguments, e.g. ``ToolResult(success=False, error="...")``, and so the
    registry can overwrite ``execution_time_ms`` after a tool returns.
    """

    # Whether the tool call succeeded.
    success: bool
    # Result payload; ``None`` when the tool produced no data.
    data: dict[str, Any] | None = None
    # Human-readable failure description; ``None`` on success.
    error: str | None = None
    # Wall-clock execution time in milliseconds (0.0 until measured).
    execution_time_ms: float = 0.0
    # NOTE(review): presumably set when the result was served from a cache;
    # nothing in this module sets it -- confirm with callers.
    cached: bool = False
| # ====================================================================== # | |
| # BaseTool | |
| # ====================================================================== # | |
class BaseTool(ABC):
    """Abstract base class for agent tools.

    Subclasses set the class attributes below and implement :meth:`execute`.

    Attributes
    ----------
    name:
        Identifier under which the tool is registered and exposed to the LLM.
    description:
        Human-readable description included in :meth:`to_schema`.
    parameters:
        JSON Schema (``"type": "object"``) describing the keyword arguments
        that :meth:`execute` accepts.
    """

    name: str = ""
    description: str = ""
    # JSON Schema (type "object"). This class-level dict is shared by every
    # subclass that does not override it, so treat it as read-only.
    parameters: dict[str, Any] = {}

    @abstractmethod
    async def execute(self, **kwargs: Any) -> ToolResult:
        """Run the tool with validated parameters and return a ToolResult.

        *kwargs* are the cleaned parameters produced by :meth:`validate_input`.
        """
        ...

    # ------------------------------------------------------------------ #
    # Validation
    # ------------------------------------------------------------------ #
    def validate_input(self, params: dict | None) -> dict:
        """Validate *params* against ``self.parameters`` JSON Schema.

        This is a lightweight validator that checks required fields and basic
        types without pulling in a full JSON-Schema library.

        Parameters
        ----------
        params:
            Raw keyword arguments. ``None`` is treated as an empty dict.

        Returns
        -------
        dict
            A new dict containing only the parameters declared in the schema's
            ``properties``, with declared defaults filled in. Defaults are
            deep-copied so callers cannot mutate the schema through them.

        Raises
        ------
        ValueError
            If a parameter listed in ``required`` is missing.
        TypeError
            If a parameter's value does not match its declared ``type``.
        """
        from copy import deepcopy

        if params is None:
            params = {}
        schema_props: dict[str, Any] = self.parameters.get("properties", {})

        # Walk the schema's own list (not a set) so that, when several required
        # parameters are missing, the one reported is deterministic.
        for key in self.parameters.get("required", []):
            if key not in params:
                raise ValueError(f"Missing required parameter: {key}")

        cleaned: dict[str, Any] = {}
        for key, value in params.items():
            prop_schema = schema_props.get(key)
            if prop_schema is None:
                # Extra keys are tolerated but logged, and left out of the result.
                logger.debug("Ignoring unknown parameter '%s' for tool '%s'", key, self.name)
                continue
            expected_type = prop_schema.get("type")
            if expected_type and not _check_type(value, expected_type):
                raise TypeError(
                    f"Parameter '{key}' expected type '{expected_type}', "
                    f"got {type(value).__name__}"
                )
            cleaned[key] = value

        # Fill defaults. Copy them so a tool that mutates a list/dict default
        # cannot corrupt the schema (and hence later calls and to_schema()).
        for key, prop_schema in schema_props.items():
            if key not in cleaned and "default" in prop_schema:
                cleaned[key] = deepcopy(prop_schema["default"])
        return cleaned

    # ------------------------------------------------------------------ #
    # Schema serialization
    # ------------------------------------------------------------------ #
    def to_schema(self) -> dict:
        """Format for inclusion in the LLM tool-calling prompt.

        Returns a function-style tool descriptor. ``parameters`` is the tool's
        schema dict itself, not a copy.
        """
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }
| # ====================================================================== # | |
| # ToolRegistry | |
| # ====================================================================== # | |
class ToolRegistry:
    """Central catalogue of available tools with execution support.

    Tools are keyed by their ``name`` attribute. :meth:`execute` returns a
    :class:`ToolResult` for unknown tools, invalid parameters, timeouts,
    exceptions raised by the tool, and malformed tool return values, instead
    of propagating an error to the caller.
    """

    def __init__(self) -> None:
        self._tools: dict[str, BaseTool] = {}

    def register(self, tool: BaseTool) -> None:
        """Register *tool* under ``tool.name``.

        An already-registered tool with the same name is replaced, and a
        warning is logged.
        """
        if tool.name in self._tools:
            logger.warning("Overwriting existing tool '%s'", tool.name)
        self._tools[tool.name] = tool
        logger.debug("Registered tool: %s", tool.name)

    def get(self, name: str) -> BaseTool | None:
        """Return the tool registered as *name*, or ``None`` if there is none."""
        return self._tools.get(name)

    def list_names(self) -> list[str]:
        """Return the registered tool names in first-registration order."""
        return list(self._tools)

    def get_all_schemas(self) -> list[dict]:
        """Return ``to_schema()`` for every registered tool."""
        return [tool.to_schema() for tool in self._tools.values()]

    async def execute(
        self,
        name: str,
        params: dict | None,
        timeout: float = 10.0,
    ) -> ToolResult:
        """Look up, validate, and execute a tool by name.

        Parameters
        ----------
        name:
            Registered tool name.
        params:
            Raw parameter dict (will be validated). ``None`` is treated as an
            empty dict.
        timeout:
            Maximum seconds to allow for execution.

        Returns
        -------
        ToolResult
            The tool's own result with ``execution_time_ms`` filled in, or a
            ``success=False`` result whose ``error`` explains what went wrong.
        """
        tool = self._tools.get(name)
        if tool is None:
            return ToolResult(
                success=False,
                error=f"Unknown tool: '{name}'. Available: {self.list_names()}",
            )

        # Validate inputs. Reject non-dict payloads up front so they become a
        # structured error instead of an AttributeError escaping the registry.
        if params is None:
            params = {}
        if not isinstance(params, dict):
            return ToolResult(
                success=False,
                error=(
                    "Validation error: parameters must be an object, "
                    f"got {type(params).__name__}"
                ),
            )
        try:
            validated = tool.validate_input(params)
        except (ValueError, TypeError) as exc:
            return ToolResult(success=False, error=f"Validation error: {exc}")

        # Execute with timeout.
        start = time.perf_counter()
        try:
            result = await asyncio.wait_for(
                tool.execute(**validated), timeout=timeout
            )
        except asyncio.TimeoutError:
            elapsed = (time.perf_counter() - start) * 1000
            return ToolResult(
                success=False,
                error=f"Tool '{name}' timed out after {timeout:.1f}s",
                execution_time_ms=elapsed,
            )
        except Exception as exc:
            # Deliberate catch-all boundary: a faulty tool must not crash the
            # caller; the traceback is logged and a failure result returned.
            elapsed = (time.perf_counter() - start) * 1000
            logger.exception("Tool '%s' raised an exception", name)
            return ToolResult(
                success=False,
                error=f"Tool execution error: {exc}",
                execution_time_ms=elapsed,
            )

        elapsed = (time.perf_counter() - start) * 1000
        if not isinstance(result, ToolResult):
            # A tool that forgets to return a ToolResult (e.g. returns None)
            # would otherwise raise AttributeError on the assignment below.
            logger.error(
                "Tool '%s' returned %s instead of ToolResult",
                name,
                type(result).__name__,
            )
            return ToolResult(
                success=False,
                error=(
                    f"Tool '{name}' returned {type(result).__name__}, "
                    "expected ToolResult"
                ),
                execution_time_ms=elapsed,
            )
        result.execution_time_ms = elapsed
        return result
| # ====================================================================== # | |
| # Helpers | |
| # ====================================================================== # | |
# JSON-Schema primitive type name -> Python types accepted for it.
_JSON_TYPE_MAP: dict[str, tuple[type, ...]] = {
    "string": (str,),
    "integer": (int,),
    "number": (int, float),
    "boolean": (bool,),
    "array": (list, tuple),
    "object": (dict,),
    "null": (type(None),),
}


def _check_type(value: Any, json_type: str | list[str]) -> bool:
    """Loose JSON-Schema type check.

    Return ``True`` if *value* matches the schema ``type`` keyword *json_type*.

    *json_type* may be a single type name or a list of names (a union such as
    ``["string", "null"]``); a union matches if any member matches. Unknown
    or non-string type names are accepted permissively.

    ``bool`` is rejected for ``"integer"`` and ``"number"``. Python's ``bool``
    subclasses ``int``, but JSON booleans are not numbers.
    """
    if isinstance(json_type, (list, tuple)):
        # An empty union constrains nothing -- accept, like an unknown type.
        return not json_type or any(_check_type(value, t) for t in json_type)
    if not isinstance(json_type, str):
        return True  # malformed type keyword -- accept anything
    allowed = _JSON_TYPE_MAP.get(json_type)
    if allowed is None:
        return True  # unknown type -- accept anything
    if isinstance(value, bool) and json_type in ("integer", "number"):
        return False
    return isinstance(value, allowed)