|
|
|
|
|
import inspect |
|
|
from typing import Dict, List, Optional, Any |
|
|
|
|
|
from ..core.module import BaseModule |
|
|
|
|
|
ALLOWED_TYPES = ["string", "number", "integer", "boolean", "object", "array"] |
|
|
|
|
|
|
|
|
class Tool(BaseModule): |
|
|
name: str |
|
|
description: str |
|
|
inputs: Dict[str, Dict[str, Any]] |
|
|
required: Optional[List[str]] = None |
|
|
|
|
|
""" |
|
|
inputs: {"input_name": {"type": "string", "description": "input description"}, ...} |
|
|
""" |
|
|
|
|
|
def __init_subclass__(cls): |
|
|
super().__init_subclass__() |
|
|
cls.validate_attributes() |
|
|
|
|
|
def get_tool_schema(self) -> Dict: |
|
|
return { |
|
|
"type": "function", |
|
|
"function": { |
|
|
"name": self.name, |
|
|
"description": self.description, |
|
|
"parameters": { |
|
|
"type": "object", |
|
|
"properties": self.inputs, |
|
|
"required": self.required |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
@classmethod |
|
|
def validate_attributes(cls): |
|
|
required_attributes = { |
|
|
"name": str, |
|
|
"description": str, |
|
|
"inputs": dict |
|
|
} |
|
|
|
|
|
json_to_python = { |
|
|
"string": str, |
|
|
"integer": int, |
|
|
"number": float, |
|
|
"boolean": bool, |
|
|
"object": dict, |
|
|
"array": list, |
|
|
} |
|
|
|
|
|
for attr, attr_type in required_attributes.items(): |
|
|
if not hasattr(cls, attr): |
|
|
raise ValueError(f"Attribute {attr} is required") |
|
|
if not isinstance(getattr(cls, attr), attr_type): |
|
|
raise ValueError(f"Attribute {attr} must be of type {attr_type}") |
|
|
|
|
|
for input_name, input_content in cls.inputs.items(): |
|
|
if not isinstance(input_content, dict): |
|
|
raise ValueError(f"Input '{input_name}' must be a dictionary") |
|
|
if "type" not in input_content or "description" not in input_content: |
|
|
raise ValueError(f"Input '{input_name}' must have 'type' and 'description'") |
|
|
if input_content["type"] not in ALLOWED_TYPES: |
|
|
raise ValueError(f"Input '{input_name}' must have a valid type, should be one of {ALLOWED_TYPES}") |
|
|
|
|
|
call_signature = inspect.signature(cls.__call__) |
|
|
if input_name not in call_signature.parameters: |
|
|
raise ValueError(f"Input '{input_name}' is not found in __call__") |
|
|
if call_signature.parameters[input_name].annotation != json_to_python[input_content["type"]]: |
|
|
raise ValueError(f"Input '{input_name}' has a type mismatch in __call__") |
|
|
|
|
|
if cls.required: |
|
|
for required_input in cls.required: |
|
|
if required_input not in cls.inputs: |
|
|
raise ValueError(f"Required input '{required_input}' is not found in inputs") |
|
|
|
|
|
def __call__(self, **kwargs): |
|
|
raise NotImplementedError("All tools must implement __call__") |
|
|
|
|
|
class Toolkit(BaseModule): |
|
|
name: str |
|
|
tools: List[Tool] |
|
|
|
|
|
def get_tool_names(self) -> List[str]: |
|
|
return [tool.name for tool in self.tools] |
|
|
|
|
|
def get_tool_descriptions(self) -> List[str]: |
|
|
return [tool.description for tool in self.tools] |
|
|
|
|
|
def get_tool_inputs(self) -> List[Dict]: |
|
|
return [tool.inputs for tool in self.tools] |
|
|
|
|
|
def add_tool(self, tool: Tool): |
|
|
self.tools.append(tool) |
|
|
|
|
|
def remove_tool(self, tool_name: str): |
|
|
self.tools = [tool for tool in self.tools if tool.name != tool_name] |
|
|
|
|
|
def get_tool(self, tool_name: str) -> Tool: |
|
|
for tool in self.tools: |
|
|
if tool.name == tool_name: |
|
|
return tool |
|
|
raise ValueError(f"Tool '{tool_name}' not found") |
|
|
|
|
|
def get_tools(self) -> List[Tool]: |
|
|
return self.tools |
|
|
|
|
|
def get_tool_schemas(self) -> List[Dict]: |
|
|
return [tool.get_tool_schema() for tool in self.tools] |
|
|
|