Spaces:
Paused
Paused
File size: 4,057 Bytes
a5784e9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 | import asyncio
import json
import os
import time
from typing import Any, Dict, List, Optional, Set
def tool_get_current_time(params: Dict[str, Any]) -> Dict[str, Any]:
    """Report the current UTC time.

    Args:
        params: Tool arguments; accepted for a uniform tool signature and
            never read.

    Returns:
        A single-key dict ``{"current_time": "YYYY-MM-DDTHH:MM:SSZ"}``.
    """
    utc_now = time.gmtime()
    stamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", utc_now)
    return {"current_time": stamp}
def tool_echo(params: Dict[str, Any]) -> Dict[str, Any]:
    """Hand the received arguments straight back to the caller.

    Args:
        params: Tool arguments, returned as-is (same object, not a copy).

    Returns:
        ``{"echo": params}``.
    """
    response = {"echo": params}
    return response
def tool_sum(params: Dict[str, Any]) -> Dict[str, Any]:
    """Add up the entries of ``params["values"]`` after converting each to float.

    Args:
        params: Tool arguments; only the ``"values"`` key is read.

    Returns:
        ``{"sum": total, "count": n}``. When ``"values"`` is missing or is not
        a list the result is ``{"sum": None, "count": 0}``. When it is a list
        but some entry cannot be converted, ``"sum"`` is ``None`` while
        ``"count"`` still reports the list length.
    """
    numbers = params.get("values")
    if not isinstance(numbers, list):
        return {"sum": None, "count": 0}

    try:
        acc = sum(map(float, numbers))
    except Exception:
        # Any entry that float() rejects invalidates the whole total.
        acc = None
    return {"sum": acc, "count": len(numbers)}
# Built-in tools, keyed by the name a caller uses to invoke them.
FUNCTION_REGISTRY = dict(
    get_current_time=tool_get_current_time,
    echo=tool_echo,
    sum=tool_sum,
)
# Per-request tool state. Both names below are reset by
# register_runtime_tools() on every call and read by execute_tool_call().
#
# Names of the tools declared in the current request (OpenAI-style ``tools``
# array). Only these names may be delegated to MCP.
_ALLOWED_RUNTIME_TOOLS: Set[str] = set()
# MCP endpoint captured for the current request, either passed explicitly or
# taken from a per-tool extension key; None when none was supplied.
_runtime_mcp_endpoint: Optional[str] = None


def register_runtime_tools(
    tools: Optional[List[Dict[str, Any]]], mcp_endpoint: Optional[str] = None
) -> None:
    """Register the tool names declared in a request as allowed.

    The previous registration is always discarded first so that nothing leaks
    from one request into the next.

    Args:
        tools: Tool declarations. Each entry is expected to be a dict that
            carries its name either under ``entry["function"]["name"]`` or
            directly under ``entry["name"]`` (the nested name wins when both
            are present). Entries that are not dicts are skipped
            individually; they do not prevent later entries from being
            registered.
        mcp_endpoint: Explicit MCP endpoint for this request. When omitted,
            the first endpoint found on a tool is used instead. A tool may
            declare one under ``"x-mcp-endpoint"`` or ``"x_mcp_endpoint"`` at
            its top level, or under ``"x-mcp-endpoint"`` inside its
            ``"function"`` dict (only the hyphenated key is checked there).

    Returns:
        None. Results are stored in the module-level
        ``_ALLOWED_RUNTIME_TOOLS`` and ``_runtime_mcp_endpoint``.
    """
    global _runtime_mcp_endpoint
    # Mutate the set in place: other code holds a reference to this object.
    _ALLOWED_RUNTIME_TOOLS.clear()
    _runtime_mcp_endpoint = None
    if not tools:
        return

    try:
        entries = list(tools)
    except TypeError:
        # ``tools`` is not iterable at all; there is nothing to register.
        return

    for tool in entries:
        if not isinstance(tool, dict):
            # Be forgiving: ignore just this malformed entry.
            continue

        nested = tool.get("function")
        if not isinstance(nested, dict):
            nested = None

        name = (nested.get("name") if nested is not None else None) or tool.get(
            "name"
        )
        if name:
            _ALLOWED_RUNTIME_TOOLS.add(str(name))

        if mcp_endpoint:
            # An explicit endpoint, or one found on an earlier tool, wins.
            continue
        ext_ep = (
            tool.get("x-mcp-endpoint")
            or tool.get("x_mcp_endpoint")
            or (nested.get("x-mcp-endpoint") if nested is not None else None)
        )
        if ext_ep:
            mcp_endpoint = ext_ep

    # Captured even when some entries were skipped as malformed.
    if mcp_endpoint:
        _runtime_mcp_endpoint = mcp_endpoint
async def execute_tool_call(name: str, arguments_json: str) -> str:
    """Run the tool called ``name`` and return its result as a JSON string.

    Built-in tools from ``FUNCTION_REGISTRY`` are called directly. A name that
    is not built-in but was declared for the current request is delegated to
    the MCP adapter, using the per-request endpoint when one was captured and
    otherwise the ``MCP_HTTP_ENDPOINT`` environment variable. Failures are
    reported as a JSON object with an ``"error"`` key rather than raised.

    Args:
        name: Tool name requested by the caller.
        arguments_json: JSON-encoded arguments. Empty or unparsable input is
            treated as ``{}``.

    Returns:
        The JSON-encoded tool result or error object. On the MCP path the
        adapter's return value is passed through unchanged.

    Raises:
        asyncio.CancelledError: Re-raised when the MCP call is cancelled.
    """
    try:
        params = json.loads(arguments_json or "{}")
    except Exception:
        params = {}

    handler = FUNCTION_REGISTRY.get(name)
    if handler:
        # Built-in tool: run it in-process.
        try:
            return json.dumps(handler(params), ensure_ascii=False)
        except Exception as e:
            return json.dumps({"error": f"Execution failed: {e}"}, ensure_ascii=False)

    if name in _ALLOWED_RUNTIME_TOOLS:
        # Declared by the request but not built-in: try the MCP adapter.
        try:
            from api_utils.mcp_adapter import (
                execute_mcp_tool,
                execute_mcp_tool_with_endpoint,
            )

            endpoint = _runtime_mcp_endpoint
            if endpoint:
                return await execute_mcp_tool_with_endpoint(endpoint, name, params)
            if os.environ.get("MCP_HTTP_ENDPOINT"):
                return await execute_mcp_tool(name, params)
        except asyncio.CancelledError:
            # Cancellation must propagate, never be turned into an error payload.
            raise
        except Exception as e:
            return json.dumps(
                {"error": f"MCP execution failed: {e}"}, ensure_ascii=False
            )

    # Neither built-in nor reachable through MCP.
    return json.dumps(
        {"error": f"Unknown tool: {name}", "arguments": params}, ensure_ascii=False
    )
|