File size: 4,265 Bytes
7a4edb9
 
 
 
 
 
ca2ba49
e00943e
7a4edb9
85ff578
7a4edb9
85ff578
 
7a4edb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85ff578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a4edb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca2ba49
 
 
 
 
 
 
 
 
7a4edb9
 
 
 
 
 
 
 
 
ca2ba49
 
7a4edb9
 
 
ca2ba49
 
7a4edb9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Base components for the tools package: the LLM factory, the central tool
registry, and helpers that export tool schemas and execute registered tools.
"""

import logging
import json
from langchain_core.tools import Tool, StructuredTool
from langchain_google_genai import ChatGoogleGenerativeAI
try:
    from ..config import LOG_LEVEL, GEMINI_API_KEY, OPENROUTER_API_KEY, DEFAULT_MODEL
except (ImportError, ValueError):
    from config import LOG_LEVEL, GEMINI_API_KEY, OPENROUTER_API_KEY, DEFAULT_MODEL
from langchain_openai import ChatOpenAI

logger = logging.getLogger(__name__)

# Global registry for tools, written by the ``register_tool`` decorator and read by
# get_tool_schemas / execute_tool / get_langchain_tools.
# Format: { "name": { "fn": callable, "description": str, "parameters": list[dict] } }
# Each parameter dict has the form
#   {"name": str, "type": str, "description": str, "required": bool}
# where "description" (defaults to the parameter name) and "required" (defaults to
# True) are optional.
TOOLS: dict[str, dict] = {}

def register_tool(name: str, description: str, parameters: list):
    """Build a decorator that records the decorated function in ``TOOLS``.

    Args:
        name: Key under which the tool is stored; registering the same name
            again replaces the previous entry.
        description: Human-readable description of what the tool does.
        parameters: Parameter specs, a list of dicts shaped like
            ``{"name": "...", "type": "...", "description": "...", "required": True/False}``.

    Returns:
        A decorator that stores ``{"fn", "description", "parameters"}`` under
        ``name`` and hands the original function back unchanged.
    """

    def _register(target_fn):
        TOOLS[name] = dict(fn=target_fn, description=description, parameters=parameters)
        return target_fn

    return _register

def get_llm():
    """Create the chat model, preferring Google Gemini over OpenRouter.

    The provider is chosen by which API key is set: ``GEMINI_API_KEY`` wins;
    otherwise ``OPENROUTER_API_KEY`` is used with the OpenAI-compatible client
    pointed at the OpenRouter endpoint. Both use ``DEFAULT_MODEL`` and a low
    temperature (0.1).

    Raises:
        ValueError: if neither API key is configured.
    """
    shared_options = {"model": DEFAULT_MODEL, "temperature": 0.1}

    if not GEMINI_API_KEY:
        if not OPENROUTER_API_KEY:
            raise ValueError("Neither GEMINI_API_KEY nor OPENROUTER_API_KEY is configured. Check your .env file.")

        logger.info(f"Initializing OpenRouter LLM: {DEFAULT_MODEL}")
        return ChatOpenAI(
            **shared_options,
            openai_api_key=OPENROUTER_API_KEY,
            openai_api_base="https://openrouter.ai/api/v1",
            max_tokens=2048,
            max_retries=3,
            request_timeout=60,
        )

    logger.info(f"Initializing Google Gemini LLM ({DEFAULT_MODEL})...")
    return ChatGoogleGenerativeAI(
        **shared_options,
        google_api_key=GEMINI_API_KEY,
        max_retries=5,
        timeout=60,
    )

def get_tool_schemas() -> list[dict]:
    """Export every registered tool as a schema in Anthropic API format.

    Each parameter spec becomes a property with its ``type`` and a
    ``description`` that falls back to the parameter's name. A parameter is
    listed as required unless its spec sets ``"required"`` to a falsy value.

    Returns:
        One ``{"name", "description", "input_schema"}`` dict per entry in
        ``TOOLS``, in the registry's iteration order.
    """

    def _schema_for(tool_name, spec):
        param_specs = spec["parameters"]
        props = {
            ps["name"]: {"type": ps["type"], "description": ps.get("description", ps["name"])}
            for ps in param_specs
        }
        mandatory = [ps["name"] for ps in param_specs if ps.get("required", True)]
        return {
            "name": tool_name,
            "description": spec["description"],
            "input_schema": {
                "type": "object",
                "properties": props,
                "required": mandatory,
            },
        }

    return [_schema_for(tool_name, spec) for tool_name, spec in TOOLS.items()]

def execute_tool(name: str, args: dict) -> str:
    """Execute a registered tool by name and return its result as a string.

    LLMs do not always send well-formed arguments, so ``args`` is normalised
    before the call:

    * ``None`` is treated as "no arguments" (``{}``).
    * A string is parsed as JSON; if it decodes to a JSON object, that dict is
      used. Otherwise (invalid JSON, or valid JSON that is not an object, such
      as a bare number like ``"42"``), the raw string is passed as the value of
      the tool's first declared parameter. If the tool declares no parameters,
      the string is dropped.

    Args:
        name: Key of the tool in ``TOOLS``.
        args: Keyword arguments for the tool. Normally a dict, but a JSON
            string, a bare string, or ``None`` are also accepted (see above).

    Returns:
        The tool's result: dicts and lists as pretty-printed JSON, anything
        else via ``str()``. Failures are returned as text rather than raised
        (``"Tool '<name>' does not exist"`` or ``"Error: <exception>"``), so
        callers get a message instead of an exception.
    """
    tool = TOOLS.get(name)
    if not tool:
        return f"Tool '{name}' does not exist"

    try:
        if args is None:
            # No argument payload at all: call the tool with no kwargs.
            args = {}
        elif isinstance(args, str):
            # Sometimes LLMs pass a JSON string (or a bare value) instead of a dict.
            raw = args
            try:
                args = json.loads(raw)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                args = None
            if not isinstance(args, dict):
                # Not a JSON object: treat the raw text as a pure string and map
                # it onto the tool's first parameter, if it has one.
                params = tool["parameters"]
                args = {params[0]["name"]: raw} if params else {}

        if not isinstance(args, dict):
            return f"Error: Tool arguments must be a mapping (dict), got {type(args).__name__}"

        result = tool["fn"](**args)
        if isinstance(result, (dict, list)):
            return json.dumps(result, ensure_ascii=False, indent=2)
        return str(result)
    except Exception as e:
        # Boundary handler: report the failure as text, but keep the traceback
        # in the logs for debugging.
        logger.exception("Error executing tool %s: %s", name, e)
        return f"Error: {e}"

def get_langchain_tools() -> list[StructuredTool]:
    """Wrap every registered tool as a LangChain ``StructuredTool``.

    ``StructuredTool.from_function`` is used (rather than ``Tool``) for better
    multi-argument support. Only the function, name and description are
    forwarded; the ``parameters`` specs stored in ``TOOLS`` are not passed
    here. Presumably LangChain infers the argument schema from the function
    itself, which should be confirmed.

    Returns:
        One ``StructuredTool`` per entry in ``TOOLS``, in iteration order.
    """
    return [
        StructuredTool.from_function(
            func=spec["fn"],
            name=tool_name,
            description=spec["description"],
        )
        for tool_name, spec in TOOLS.items()
    ]