| from typing import Dict, Type, Any, List, Optional, Callable |
| from agentpress.tool import Tool, SchemaType, ToolSchema |
| from utils.logger import logger |
|
|
|
|
class ToolRegistry:
    """Registry for managing and accessing tools.

    Maintains a collection of tool instances and their schemas, allowing for
    selective registration of tool functions and easy access to tool capabilities.

    Attributes:
        tools (Dict[str, Dict[str, Any]]): OpenAPI-style tools keyed by function
            name; each value holds ``"instance"`` and ``"schema"``.
        xml_tools (Dict[str, Dict[str, Any]]): XML-style tools keyed by XML tag
            name; each value holds ``"instance"``, ``"method"`` and ``"schema"``.

    Methods:
        register_tool: Register a tool with optional function filtering
        get_available_functions: Map function names to bound tool methods
        get_tool: Get a specific tool by name
        get_xml_tool: Get a tool by XML tag name
        get_openapi_schemas: Get OpenAPI schemas for function calling
        get_xml_examples: Get examples of XML tool usage

    Note:
        Function names and XML tags share one namespace across all registered
        tools: registering a tool that exposes an already-registered function
        name or tag replaces the earlier entry, and a warning is logged.
    """

    def __init__(self):
        """Initialize an empty ToolRegistry."""
        self.tools: Dict[str, Dict[str, Any]] = {}
        self.xml_tools: Dict[str, Dict[str, Any]] = {}
        logger.debug("Initialized new ToolRegistry instance")

    def register_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
        """Instantiate ``tool_class`` and register its schemas.

        Args:
            tool_class: The tool class to register.
            function_names: Optional list of function names to register.
                ``None`` registers every function the tool exposes; an empty
                list registers nothing.
            **kwargs: Additional keyword arguments passed to ``tool_class(...)``.

        Notes:
            - OpenAPI schemas are stored in ``self.tools`` keyed by function name.
            - XML schemas that carry an ``xml_schema`` are stored in
              ``self.xml_tools`` keyed by their tag name.
            - An existing entry under the same key is replaced (last wins);
              a warning is logged so the collision is visible.
            - Requested names that the tool does not expose are reported with
              a warning instead of being silently ignored.
        """
        tool_name = tool_class.__name__
        logger.debug(f"Registering tool class: {tool_name}")
        tool_instance = tool_class(**kwargs)
        schemas = tool_instance.get_schemas()

        logger.debug(f"Available schemas for {tool_name}: {list(schemas.keys())}")

        # Build the filter once so each membership test is O(1) rather than a list scan.
        requested = None if function_names is None else set(function_names)
        if requested:
            missing = requested.difference(schemas)
            if missing:
                logger.warning(f"Requested functions not found in {tool_name}: {sorted(missing)}")

        registered_openapi = 0
        registered_xml = 0

        for func_name, schema_list in schemas.items():
            if requested is not None and func_name not in requested:
                continue
            for schema in schema_list:
                if schema.schema_type == SchemaType.OPENAPI:
                    previous = self.tools.get(func_name)
                    if previous is not None:
                        # Another registration already owns this name; make the override visible.
                        logger.warning(
                            f"OpenAPI function {func_name} already registered by "
                            f"{type(previous['instance']).__name__}; replacing with {tool_name}"
                        )
                    self.tools[func_name] = {
                        "instance": tool_instance,
                        "schema": schema,
                    }
                    registered_openapi += 1
                    logger.debug(f"Registered OpenAPI function {func_name} from {tool_name}")
                elif schema.schema_type == SchemaType.XML and schema.xml_schema:
                    tag_name = schema.xml_schema.tag_name
                    previous = self.xml_tools.get(tag_name)
                    if previous is not None:
                        logger.warning(
                            f"XML tag {tag_name} already registered by "
                            f"{type(previous['instance']).__name__}; replacing with {tool_name}"
                        )
                    self.xml_tools[tag_name] = {
                        "instance": tool_instance,
                        "method": func_name,
                        "schema": schema,
                    }
                    registered_xml += 1
                    logger.debug(f"Registered XML tag {tag_name} -> {func_name} from {tool_name}")

        logger.debug(f"Tool registration complete for {tool_name}: {registered_openapi} OpenAPI functions, {registered_xml} XML tags")

    def get_available_functions(self) -> Dict[str, Callable]:
        """Get all available tool functions as bound methods.

        OpenAPI tools contribute ``function name -> method`` entries. XML tools
        then contribute ``method name -> method`` entries; because they are
        added second, an XML entry wins if both tables use the same name.

        Returns:
            Dict mapping function names to their bound implementations.

        Raises:
            AttributeError: If a registered name is not an attribute of its
                tool instance.
        """
        available_functions: Dict[str, Callable] = {}

        for function_name, tool_info in self.tools.items():
            available_functions[function_name] = getattr(tool_info["instance"], function_name)

        for tool_info in self.xml_tools.values():
            method_name = tool_info["method"]
            available_functions[method_name] = getattr(tool_info["instance"], method_name)

        logger.debug(f"Retrieved {len(available_functions)} available functions")
        return available_functions

    def get_tool(self, tool_name: str) -> Dict[str, Any]:
        """Get a specific OpenAPI tool by function name.

        Args:
            tool_name: Name of the tool function.

        Returns:
            Dict containing ``"instance"`` and ``"schema"``, or an empty dict
            (with a logged warning) if not found.
        """
        tool = self.tools.get(tool_name, {})
        if not tool:
            logger.warning(f"Tool not found: {tool_name}")
        return tool

    def get_xml_tool(self, tag_name: str) -> Dict[str, Any]:
        """Get tool info by XML tag name.

        Args:
            tag_name: XML tag name for the tool.

        Returns:
            Dict containing ``"instance"``, ``"method"`` and ``"schema"``, or an
            empty dict (with a logged warning) if no tool uses that tag.
        """
        tool = self.xml_tools.get(tag_name, {})
        if not tool:
            logger.warning(f"XML tool not found for tag: {tag_name}")
        return tool

    def get_openapi_schemas(self) -> List[Dict[str, Any]]:
        """Get OpenAPI schemas for function calling.

        Returns:
            List of the ``.schema`` payloads of all registered OpenAPI tools.
        """
        # register_tool only stores OPENAPI schemas in self.tools; the filter is defensive.
        schemas = [
            tool_info["schema"].schema
            for tool_info in self.tools.values()
            if tool_info["schema"].schema_type == SchemaType.OPENAPI
        ]
        logger.debug(f"Retrieved {len(schemas)} OpenAPI schemas")
        return schemas

    def get_xml_examples(self) -> Dict[str, str]:
        """Get all XML tag examples.

        Returns:
            Dict mapping tag names to their example usage; tags without an
            example are omitted.
        """
        examples = {}
        for tool_info in self.xml_tools.values():
            schema = tool_info["schema"]
            if schema.xml_schema and schema.xml_schema.example:
                examples[schema.xml_schema.tag_name] = schema.xml_schema.example
        logger.debug(f"Retrieved {len(examples)} XML examples")
        return examples
|
|