File size: 4,057 Bytes
a5784e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import asyncio
import json
import os
import time
from typing import Any, Dict, List, Optional, Set


def tool_get_current_time(params: Dict[str, Any]) -> Dict[str, Any]:
    """Report the current UTC time as an ISO-8601 string ending in ``Z``.

    ``params`` is accepted for registry-signature compatibility and ignored.
    """
    now_utc = time.gmtime()
    stamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", now_utc)
    return {"current_time": stamp}


def tool_echo(params: Dict[str, Any]) -> Dict[str, Any]:
    """Return the received parameters, unchanged and uncopied, under ``"echo"``."""
    response: Dict[str, Any] = dict(echo=params)
    return response


def tool_sum(params: Dict[str, Any]) -> Dict[str, Any]:
    """Add up ``params["values"]`` after coercing every item with ``float()``.

    Returns ``{"sum": total, "count": n}``. ``n`` is the list length whenever
    ``values`` is a list, and 0 otherwise. ``total`` is ``None`` when ``values``
    is not a list or when any item cannot be converted. An empty list sums to
    the integer ``0``.
    """
    values = params.get("values")
    if not isinstance(values, list):
        return {"sum": None, "count": 0}
    try:
        total = sum(float(item) for item in values)
    except Exception:
        # Any unconvertible item invalidates the whole sum, but not the count.
        total = None
    return {"sum": total, "count": len(values)}


# Built-in tools keyed by the name a model uses to invoke them. Each entry is
# called with the decoded argument dict and returns a JSON-serialisable dict.
FUNCTION_REGISTRY = dict(
    get_current_time=tool_get_current_time,
    echo=tool_echo,
    sum=tool_sum,
)

# Tool names declared by the current request's OpenAI-style ``tools`` array.
# NOTE(review): this is process-wide state that every call to
# register_runtime_tools resets; if requests are handled concurrently they can
# overwrite each other's registrations — confirm callers serialize requests.
_ALLOWED_RUNTIME_TOOLS: Set[str] = set()
# MCP endpoint supplied by the current request (explicitly or via a per-tool
# extension); None when the request provided none.
_runtime_mcp_endpoint: Optional[str] = None


def _tool_endpoint_hint(
    tool: Dict[str, Any], fn: Optional[Dict[str, Any]]
) -> Optional[str]:
    """Return the first non-empty string MCP endpoint extension a tool declares."""
    candidates = (
        tool.get("x-mcp-endpoint"),
        tool.get("x_mcp_endpoint"),
        fn.get("x-mcp-endpoint") if fn is not None else None,
    )
    for candidate in candidates:
        # Ignore non-string junk so it cannot become the request's endpoint.
        if isinstance(candidate, str) and candidate:
            return candidate
    return None


def register_runtime_tools(
    tools: Optional[List[Dict[str, Any]]], mcp_endpoint: Optional[str] = None
) -> None:
    """Register the tool names declared in a request as allowed at runtime.

    The per-request registry is cleared first. Then every tool name found in
    ``tools`` is recorded. Entries may be OpenAI-style
    (``{"type": "function", "function": {"name": ...}}``) or flat
    (``{"name": ...}``). A non-built-in tool registered here may later be
    delegated to MCP. For that, the explicit ``mcp_endpoint`` argument wins;
    otherwise the first string ``x-mcp-endpoint`` / ``x_mcp_endpoint``
    extension found on a tool is used.

    Malformed entries (non-dicts, entries without a name) are skipped one by
    one, so a single bad tool does not discard the rest of the request.

    Args:
        tools: The request's ``tools`` array. ``None`` or empty leaves both
            the registry and the endpoint cleared.
        mcp_endpoint: Optional MCP endpoint to use for this request.
    """
    global _runtime_mcp_endpoint
    # Reset per-request registry to avoid leakage across requests.
    _ALLOWED_RUNTIME_TOOLS.clear()
    _runtime_mcp_endpoint = None
    if not tools:
        return
    try:
        entries = list(tools)
    except TypeError:
        # Not iterable at all: be forgiving, as for any other malformed input.
        return

    endpoint = mcp_endpoint
    for tool in entries:
        if not isinstance(tool, dict):
            continue
        fn = tool.get("function")
        if not isinstance(fn, dict):
            fn = None
        # Prefer the nested OpenAI "function.name"; fall back to a flat "name".
        name = (fn.get("name") if fn is not None else None) or tool.get("name")
        if name:
            _ALLOWED_RUNTIME_TOOLS.add(str(name))
        if not endpoint:
            endpoint = _tool_endpoint_hint(tool, fn)

    if endpoint:
        _runtime_mcp_endpoint = endpoint


async def execute_tool_call(name: str, arguments_json: str) -> str:
    """Execute a tool call and return its result as a string (normally JSON).

    Built-in tools from ``FUNCTION_REGISTRY`` run directly in-process. A tool
    that is not built in can be delegated to MCP through
    ``api_utils.mcp_adapter``. This requires that the tool was declared via
    ``register_runtime_tools`` and that an endpoint is configured. The
    per-request endpoint takes precedence over the ``MCP_HTTP_ENDPOINT``
    environment variable. Failures are reported as a JSON ``{"error": ...}``
    payload rather than raised.

    Args:
        name: Tool name requested by the model.
        arguments_json: JSON-encoded arguments. An already-decoded ``dict`` is
            also accepted. Unparseable input is treated as ``{}``.

    Returns:
        The JSON-serialised built-in result, the MCP adapter's return value
        as-is, or a JSON error object (unknown tool / execution failure).
    """
    if isinstance(arguments_json, dict):
        # Some clients send arguments as an object instead of a JSON string.
        params: Any = arguments_json
    else:
        try:
            params = json.loads(arguments_json or "{}")
        except (TypeError, ValueError, RecursionError):
            # Best effort: malformed arguments degrade to an empty parameter set.
            params = {}

    func = FUNCTION_REGISTRY.get(name)
    if func is None:
        # Snapshot the per-request endpoint once, before any await.
        endpoint = _runtime_mcp_endpoint
        # Only touch the MCP adapter when the tool was declared AND there is an
        # endpoint to send it to; otherwise the adapter is never imported, so a
        # missing adapter cannot mask the "Unknown tool" error below.
        if name in _ALLOWED_RUNTIME_TOOLS and (
            endpoint or os.environ.get("MCP_HTTP_ENDPOINT")
        ):
            try:
                from api_utils.mcp_adapter import (
                    execute_mcp_tool,
                    execute_mcp_tool_with_endpoint,
                )

                if endpoint:
                    return await execute_mcp_tool_with_endpoint(
                        endpoint, name, params
                    )
                return await execute_mcp_tool(name, params)
            except asyncio.CancelledError:
                # Propagate cancellation; before Python 3.8 CancelledError was
                # an Exception subclass and would otherwise be swallowed below.
                raise
            except Exception as e:
                return json.dumps(
                    {"error": f"MCP execution failed: {e}"}, ensure_ascii=False
                )
        return json.dumps(
            {"error": f"Unknown tool: {name}", "arguments": params}, ensure_ascii=False
        )

    try:
        result = func(params)
        # Serialisation stays inside the try so non-JSON results report an error.
        return json.dumps(result, ensure_ascii=False)
    except Exception as e:
        return json.dumps({"error": f"Execution failed: {e}"}, ensure_ascii=False)