|
|
""" |
|
|
Tool Manager for CodeAct Agent. |
|
|
Unified management system for all types of tools: local functions, decorated tools, and MCP tools. |
|
|
""" |
|
|
|
|
|
from typing import Dict, List, Optional, Callable, Any, Union |
|
|
from enum import Enum |
|
|
from dataclasses import dataclass |
|
|
import re |
|
|
|
|
|
|
|
|
from .tool_registry import ToolRegistry, create_module2api_from_functions |
|
|
from .mcp_manager import MCPManager |
|
|
|
|
|
|
|
|
class ToolSource(Enum):
    """Origin of a tool in the ToolManager catalog.

    ``ALL`` is a filter wildcard (used by ``ToolManager.list_tools``) and is
    excluded from per-source statistics; it is not meant to label real tools.
    """

    # Plain callables / schemas added through ToolManager.add_tool (the default)
    # or ToolManager.add_legacy_functions.
    LOCAL = "local"
    # Functions found by ToolManager._discover_decorated_tools (@tool decorator).
    DECORATED = "decorated"
    # Tools provided by Model Context Protocol servers.
    MCP = "mcp"
    # Wildcard meaning "any source" when filtering.
    ALL = "all"
|
|
|
|
|
|
|
|
@dataclass
class ToolInfo:
    """Catalog record describing a single tool known to the ToolManager.

    ``required_parameters`` and ``optional_parameters`` hold one dict per
    parameter. They default to ``None`` and are normalised to a fresh empty
    list in ``__post_init__``, so after construction they are always lists.
    """

    name: str
    description: str
    source: ToolSource
    # Callable implementing the tool, when one is available.
    function: Optional[Callable] = None
    # Schema dict the tool was registered with / derived from, if any.
    schema: Optional[Dict] = None
    # Originating server name (populated for MCP tools).
    server: Optional[str] = None
    module: Optional[str] = None
    # Annotated Optional because None is the accepted default; __post_init__
    # replaces it with [].
    required_parameters: Optional[List[Dict]] = None
    optional_parameters: Optional[List[Dict]] = None

    def __post_init__(self):
        # Give each instance its own empty list instead of None so callers can
        # iterate / len() these fields unconditionally.
        if self.required_parameters is None:
            self.required_parameters = []
        if self.optional_parameters is None:
            self.optional_parameters = []
|
|
|
|
|
|
|
|
class ToolManager:
    """
    Unified tool management system for CodeAct Agent.

    Manages all types of tools:
    - Local functions (legacy function registry)
    - Decorated tools (@tool decorator)
    - MCP tools (Model Context Protocol)

    Provides a single, consistent interface for tool operations. Every tool is
    described by a ``ToolInfo`` entry in ``_tool_catalog`` (keyed by tool
    name); callables are additionally registered with the underlying
    ``ToolRegistry`` / ``MCPManager``.
    """

    def __init__(self, console_display=None):
        """
        Initialize the ToolManager.

        Decorated (@tool) functions are discovered and registered immediately.

        Args:
            console_display: Optional console display for MCP status output
        """
        self.tool_registry = ToolRegistry()
        self.mcp_manager = MCPManager(console_display)

        # Name -> ToolInfo for every known tool, regardless of source.
        self._tool_catalog: Dict[str, ToolInfo] = {}

        # Name -> callable for tools added with ToolSource.LOCAL.
        self._legacy_functions: Dict[str, Callable] = {}

        self._discover_decorated_tools()

    # ------------------------------------------------------------------
    # Adding / removing / looking up tools
    # ------------------------------------------------------------------

    def add_tool(self, tool: Union[Callable, Dict], name: Optional[str] = None,
                 description: Optional[str] = None,
                 source: ToolSource = ToolSource.LOCAL) -> bool:
        """
        Add a tool to the manager.

        Args:
            tool: Either a callable function or a tool schema dict
            name: Optional custom name (defaults to function.__name__; for a
                schema dict the schema's own "name" takes precedence)
            description: Optional description (defaults to function.__doc__;
                for a schema dict the schema's "description" takes precedence)
            source: Source type of the tool

        Returns:
            True if successfully added, False otherwise (errors are printed,
            not raised).
        """
        try:
            if callable(tool):
                return self._add_callable_tool(tool, name, description, source)
            if isinstance(tool, dict):
                return self._add_schema_tool(tool, name, description, source)
            return False
        except Exception as e:
            # `name` alone is None for the common add_tool(func) call, so report
            # the best identifier available.
            label = name or getattr(tool, "__name__", None)
            if label is None and isinstance(tool, dict):
                label = tool.get("name")
            print(f"Error adding tool {label}: {e}")
            return False

    def _add_callable_tool(self, tool: Callable, name: Optional[str],
                           description: Optional[str], source: ToolSource) -> bool:
        """Register a Python callable in the registry and catalog (add_tool helper)."""
        func_name = getattr(tool, "__name__", None)
        tool_name = name or func_name
        if not tool_name:
            # e.g. functools.partial or callable instances without __name__.
            print("Warning: Callable tool has no __name__; pass name= explicitly")
            return False
        tool_desc = description or tool.__doc__ or f"Function {func_name or tool_name}"

        if not self.tool_registry.add_function_directly(tool_name, tool, tool_desc):
            return False

        schema = self._create_schema_from_function(tool_name, tool, tool_desc)
        required, optional = self._parameters_from_schema(schema)
        self._tool_catalog[tool_name] = ToolInfo(
            name=tool_name,
            description=tool_desc,
            source=source,
            function=tool,
            schema=schema,
            required_parameters=required,
            optional_parameters=optional,
        )

        if source == ToolSource.LOCAL:
            self._legacy_functions[tool_name] = tool
        return True

    def _add_schema_tool(self, tool: Dict, name: Optional[str],
                         description: Optional[str], source: ToolSource) -> bool:
        """Register a tool described by a schema dict (add_tool helper)."""
        tool_name = tool.get("name") or name
        tool_desc = tool.get("description") or description

        if not tool_name:
            print("Warning: Tool schema must have a name")
            return False

        if not self.tool_registry.register_tool(tool):
            return False

        required, optional = self._parameters_from_schema(tool)
        self._tool_catalog[tool_name] = ToolInfo(
            name=tool_name,
            description=tool_desc,
            source=source,
            schema=tool,
            required_parameters=required,
            optional_parameters=optional,
        )
        return True

    @staticmethod
    def _parameters_from_schema(schema: Optional[Dict]) -> tuple:
        """Return ``(required, optional)`` parameter lists from a schema dict.

        Reads the ``required_parameters`` / ``optional_parameters`` keys (the
        same keys add_tool reads from schema dicts); missing or None values
        yield empty lists.
        """
        # NOTE(review): assumes ToolRegistry._create_schema_from_function emits
        # these same keys; if it does not, this simply yields empty lists.
        if not isinstance(schema, dict):
            return [], []
        return (schema.get("required_parameters") or [],
                schema.get("optional_parameters") or [])

    def remove_tool(self, name: str) -> bool:
        """
        Remove a tool by name from all registries.

        Returns:
            True if the tool was removed from at least one place (tool
            registry, MCP manager, legacy functions, or the catalog).
        """
        try:
            removed = False

            if self.tool_registry.remove_tool_by_name(name):
                removed = True

            if name in self.mcp_manager.mcp_functions:
                if self.mcp_manager.remove_mcp_tool(name, self.tool_registry):
                    removed = True

            # Stored values are callables / ToolInfo objects, never None.
            if self._legacy_functions.pop(name, None) is not None:
                removed = True

            if self._tool_catalog.pop(name, None) is not None:
                removed = True

            return removed

        except Exception as e:
            print(f"Error removing tool {name}: {e}")
            return False

    def get_tool(self, name: str) -> Optional[ToolInfo]:
        """Get comprehensive tool information by name."""
        return self._tool_catalog.get(name)

    def get_tool_function(self, name: str) -> Optional[Callable]:
        """Get the actual function object by name.

        Prefers the callable stored in the catalog and falls back to the tool
        registry.
        """
        tool_info = self.get_tool(name)
        if tool_info and tool_info.function:
            return tool_info.function
        return self.tool_registry.get_function_by_name(name)

    # ------------------------------------------------------------------
    # Listing / searching
    # ------------------------------------------------------------------

    def list_tools(self, source: ToolSource = ToolSource.ALL,
                   include_details: bool = False) -> List[Dict]:
        """
        List tools with optional filtering by source.

        Args:
            source: Filter by tool source (LOCAL, DECORATED, MCP, ALL)
            include_details: Whether to include detailed information

        Returns:
            List of tool dictionaries, sorted by name
        """
        tools = []
        for tool_info in self._tool_catalog.values():
            if source != ToolSource.ALL and tool_info.source != source:
                continue

            entry = {
                "name": tool_info.name,
                "description": tool_info.description,
                "source": tool_info.source.value,
            }
            if include_details:
                entry.update({
                    "server": tool_info.server,
                    "module": tool_info.module,
                    "has_function": tool_info.function is not None,
                    "required_params": len(tool_info.required_parameters),
                    "optional_params": len(tool_info.optional_parameters),
                })
            tools.append(entry)

        return sorted(tools, key=lambda x: x["name"])

    def search_tools(self, query: str, search_descriptions: bool = True) -> List[Dict]:
        """
        Search tools by name and optionally description (case-insensitive).

        Args:
            query: Search query. Treated as a regular expression; if it is not
                a valid regex (e.g. "foo(") it is matched literally instead.
            search_descriptions: Whether to also search in descriptions

        Returns:
            List of matching tools, sorted by name
        """
        try:
            pattern = re.compile(query, re.IGNORECASE)
        except re.error:
            # Fall back to a literal match rather than crashing on user input.
            pattern = re.compile(re.escape(query), re.IGNORECASE)

        matching_tools = []
        for tool_name, tool_info in self._tool_catalog.items():
            if pattern.search(tool_name) or (
                    search_descriptions and pattern.search(tool_info.description or "")):
                matching_tools.append({
                    "name": tool_info.name,
                    "description": tool_info.description,
                    "source": tool_info.source.value,
                    "server": tool_info.server,
                })

        return sorted(matching_tools, key=lambda x: x["name"])

    def get_tools_by_source(self, source: ToolSource) -> Dict[str, ToolInfo]:
        """Get all tools from a specific source."""
        return {
            name: tool_info
            for name, tool_info in self._tool_catalog.items()
            if tool_info.source == source
        }

    # ------------------------------------------------------------------
    # MCP integration
    # ------------------------------------------------------------------

    def add_mcp_server(self, config_path: str = "./mcp_config.yaml") -> None:
        """Add MCP tools from configuration file.

        Errors are printed, not raised.
        """
        try:
            self.mcp_manager.add_mcp(config_path, self.tool_registry)
            self._sync_mcp_tools()
        except Exception as e:
            print(f"Error adding MCP server: {e}")

    def _sync_mcp_tools(self) -> None:
        """Copy every tool currently reported by the MCP manager into the catalog."""
        for tool_name, tool_data in self.mcp_manager.list_mcp_tools().items():
            self._tool_catalog[tool_name] = ToolInfo(
                name=tool_name,
                description=tool_data.get("description", "MCP tool"),
                source=ToolSource.MCP,
                function=tool_data.get("function"),
                server=tool_data.get("server"),
                module=tool_data.get("module"),
                required_parameters=tool_data.get("required_parameters", []),
                optional_parameters=tool_data.get("optional_parameters", []),
            )

    def list_mcp_servers(self) -> Dict[str, List[str]]:
        """List all MCP servers and their tools (server "unknown" if unset)."""
        servers: Dict[str, List[str]] = {}
        for tool_name, tool_info in self.get_tools_by_source(ToolSource.MCP).items():
            servers.setdefault(tool_info.server or "unknown", []).append(tool_name)
        return servers

    def show_mcp_status(self) -> None:
        """Display detailed MCP status."""
        self.mcp_manager.show_mcp_status()

    def get_mcp_summary(self) -> Dict[str, Any]:
        """Get MCP tools summary."""
        return self.mcp_manager.get_mcp_summary()

    # ------------------------------------------------------------------
    # Functions and schemas
    # ------------------------------------------------------------------

    def get_all_functions(self) -> Dict[str, Callable]:
        """Get all available functions as a dictionary.

        On name collisions later sources win: tool registry, then legacy
        functions, then MCP functions.
        """
        functions: Dict[str, Callable] = {}
        functions.update(self.tool_registry.get_all_functions())
        functions.update(self._legacy_functions)
        for tool_name, tool_data in self.mcp_manager.list_mcp_tools().items():
            if tool_data.get("function"):
                functions[tool_name] = tool_data["function"]
        return functions

    @staticmethod
    def _parameter_name(param: Any) -> Optional[str]:
        """Return a parameter entry's name, or None if the entry is malformed."""
        if isinstance(param, dict) and param.get("name"):
            return param["name"]
        return None

    @staticmethod
    def _parameter_json_schema(param: Dict, include_default: bool) -> Dict:
        """Convert one parameter dict into a JSON-schema property dict."""
        prop = {
            "type": param.get("type", "string"),
            "description": param.get("description", ""),
        }
        if "enum" in param:
            prop["enum"] = param["enum"]
        if include_default and "default" in param:
            prop["default"] = param["default"]
        return prop

    def get_tool_schemas(self, openai_format: bool = True) -> List[Dict]:
        """
        Get tool schemas for all tools.

        Args:
            openai_format: Whether to format as OpenAI function schemas. When
                False, the stored ``ToolInfo.schema`` of each tool is returned
                as-is and tools without a stored schema are omitted.

        Returns:
            List of tool schemas. Parameter entries without a "name" are
            skipped (see validate_tools).
        """
        schemas = []
        for tool_info in self._tool_catalog.values():
            if not openai_format:
                if tool_info.schema:
                    schemas.append(tool_info.schema)
                continue

            properties: Dict[str, Dict] = {}
            required: List[str] = []

            for param in tool_info.required_parameters:
                param_name = self._parameter_name(param)
                if param_name is None:
                    continue
                properties[param_name] = self._parameter_json_schema(param, include_default=False)
                required.append(param_name)

            for param in tool_info.optional_parameters:
                param_name = self._parameter_name(param)
                if param_name is None:
                    continue
                properties[param_name] = self._parameter_json_schema(param, include_default=True)

            schemas.append({
                "type": "function",
                "function": {
                    "name": tool_info.name,
                    "description": tool_info.description,
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": required,
                    },
                },
            })

        return schemas

    # ------------------------------------------------------------------
    # Diagnostics
    # ------------------------------------------------------------------

    def get_tool_statistics(self) -> Dict[str, Any]:
        """Get comprehensive tool statistics."""
        by_source = {s.value: 0 for s in ToolSource if s != ToolSource.ALL}
        with_functions = 0
        for tool_info in self._tool_catalog.values():
            key = tool_info.source.value
            # .get() so a tool added with an unexpected source (e.g. ALL) is
            # counted instead of raising KeyError.
            by_source[key] = by_source.get(key, 0) + 1
            if tool_info.function:
                with_functions += 1

        return {
            "total_tools": len(self._tool_catalog),
            "by_source": by_source,
            "with_functions": with_functions,
            "mcp_servers": len(self.list_mcp_servers()),
            "tool_registry_size": len(self.tool_registry.tools),
            "legacy_functions": len(self._legacy_functions),
        }

    def validate_tools(self) -> Dict[str, List[str]]:
        """Validate all tools and return any issues found.

        Returns:
            Dict with lists of tool names under "missing_functions",
            "missing_descriptions", "duplicate_names" and "invalid_schemas"
            (tools with a parameter entry that is not a dict or lacks a name).
        """
        issues: Dict[str, List[str]] = {
            "missing_functions": [],
            "missing_descriptions": [],
            "duplicate_names": [],
            "invalid_schemas": [],
        }

        seen_names = set()
        for tool_name, tool_info in self._tool_catalog.items():
            # Catalog keys are unique by construction, so compare the names
            # stored inside the ToolInfo entries instead.
            if tool_info.name in seen_names:
                issues["duplicate_names"].append(tool_name)
            seen_names.add(tool_info.name)

            # MCP tools may legitimately have no local callable.
            if not tool_info.function and tool_info.source != ToolSource.MCP:
                issues["missing_functions"].append(tool_name)

            if not tool_info.description or not str(tool_info.description).strip():
                issues["missing_descriptions"].append(tool_name)

            params = [*tool_info.required_parameters, *tool_info.optional_parameters]
            if any(self._parameter_name(p) is None for p in params):
                issues["invalid_schemas"].append(tool_name)

        return issues

    # ------------------------------------------------------------------
    # Internal discovery / catalog maintenance
    # ------------------------------------------------------------------

    def _discover_decorated_tools(self) -> None:
        """Discover and register tools marked with @tool decorator.

        Prints a warning (instead of raising) if builtin_tools cannot be
        imported.
        """
        try:
            from .builtin_tools import get_all_tool_functions

            for func in get_all_tool_functions():
                name = getattr(func, "_tool_name", None) or func.__name__
                # Fall back to the docstring when the decorator stored no
                # (or an empty) description, so the catalog never holds None.
                description = (getattr(func, "_tool_description", None)
                               or func.__doc__
                               or f"Function {func.__name__}")

                schema = self._create_schema_from_function(name, func, description)
                required, optional = self._parameters_from_schema(schema)
                self._tool_catalog[name] = ToolInfo(
                    name=name,
                    description=description,
                    source=ToolSource.DECORATED,
                    function=func,
                    schema=schema,
                    required_parameters=required,
                    optional_parameters=optional,
                )

                self.tool_registry.add_function_directly(name, func, description)

        except ImportError:
            print("Warning: Could not import builtin_tools module for decorated tool discovery")

    def _create_schema_from_function(self, name: str, function: Callable, description: str) -> Dict:
        """Create a tool schema from a function object (delegates to the registry)."""
        return self.tool_registry._create_schema_from_function(name, function, description)

    def _refresh_catalog(self) -> None:
        """Rebuild the tool catalog from all sources.

        Decorated and MCP entries are re-derived from their sources. Entries
        added through add_tool() have no other record, so they are carried
        over instead of being silently dropped.
        """
        preserved = {
            tool_name: tool_info
            for tool_name, tool_info in self._tool_catalog.items()
            if tool_info.source not in (ToolSource.DECORATED, ToolSource.MCP)
        }

        self._tool_catalog.clear()
        self._discover_decorated_tools()
        # add_tool() entries are normally added after decorated discovery, so
        # let them keep precedence over same-named decorated tools.
        self._tool_catalog.update(preserved)
        self._sync_mcp_tools()

    def add_legacy_functions(self, functions: Dict[str, Callable]) -> int:
        """Add legacy functions for backward compatibility.

        Returns:
            Number of functions successfully added.
        """
        return sum(
            1 for name, func in functions.items()
            if self.add_tool(func, name, source=ToolSource.LOCAL)
        )