File size: 8,854 Bytes
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
"""
User-defined tool adapter (HTTP + MCP) for Agno Agent tools.
"""

from __future__ import annotations

import asyncio
import json
import os
import re
from datetime import timedelta
from typing import Any
from urllib.parse import quote, urlencode, urlparse

import httpx
from agno.tools import Toolkit
from agno.tools.function import Function

try:
    from agno.tools.mcp import MCPTools
    from agno.tools.mcp.params import SSEClientParams, StreamableHTTPClientParams
except Exception:  # pragma: no cover - optional dependency
    MCPTools = None
    StreamableHTTPClientParams = None
    SSEClientParams = None


def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment.

    An unset or blank (empty / whitespace-only) variable yields ``default``,
    so a placeholder such as ``REMOTE_MCP_TIMEOUT_SECONDS=`` does not break
    module import.

    Raises:
        ValueError: If the variable is set to something that is not an integer.
    """
    raw = (os.getenv(name) or "").strip()
    return int(raw) if raw else default


# Connected MCP toolkits, keyed by the JSON string that `_get_mcp_tools`
# builds from (url, transport, headers).
_mcp_tools_cache: dict[str, Any] = {}
# Guards both the cache lookup and the creation of new connections.
_mcp_tools_lock = asyncio.Lock()
# Connection timeout in seconds; values below 15 are raised to 15.
_REMOTE_MCP_TIMEOUT_SECONDS = max(15, _env_int("REMOTE_MCP_TIMEOUT_SECONDS", 45))
# SSE read timeout in seconds; never shorter than the connection timeout.
_REMOTE_MCP_SSE_READ_TIMEOUT_SECONDS = max(
    _REMOTE_MCP_TIMEOUT_SECONDS,
    _env_int("REMOTE_MCP_SSE_READ_TIMEOUT_SECONDS", 300),
)


def _normalize_mcp_transport(transport: str | None) -> str:
    """Map a user-supplied transport name onto a canonical one.

    The input is stringified, trimmed and lower-cased. ``"sse"`` and
    ``"stdio"`` are returned as-is; every other value (including the
    ``streamable_http`` / ``streamablehttp`` / ``http`` aliases, unknown
    names, and a missing or empty value) becomes ``"streamable-http"``.
    """
    fallback = "streamable-http"
    token = str(transport if transport else fallback).strip().lower()
    return token if token in ("sse", "stdio") else fallback


def _replace_template(template: Any, args: dict[str, Any]) -> Any:
    """Fill ``{{name}}`` placeholders in ``template`` with percent-encoded values.

    Args:
        template: The text to fill in. Non-string values are returned unchanged.
        args: Mapping of placeholder name to replacement value.

    Returns:
        The template with each known placeholder replaced by
        ``quote(str(value), safe="")``; placeholders with no entry in
        ``args`` are left in the output verbatim.
    """
    if isinstance(template, str):

        def _fill(found):
            name = found.group(1)
            # safe="" also encodes "/", so a value cannot add path segments.
            return quote(str(args[name]), safe="") if name in args else found.group(0)

        return _TEMPLATE_REGEX.sub(_fill, template)
    return template


def _replace_templates(params: dict[str, Any], args: dict[str, Any]) -> dict[str, Any]:
    """Fill ``{{name}}`` placeholders in every string value of ``params``.

    Values are substituted raw (``str(value)``), without percent-encoding.
    The result is later either passed through ``urlencode`` (query string)
    or sent as a JSON body; encoding here as well would double-encode query
    values (``"a b"`` -> ``a%2520b``) and corrupt JSON payloads.

    Args:
        params: Parameter templates; ``None`` or empty yields ``{}``.
        args: Mapping of placeholder name to replacement value.

    Returns:
        A new dict with the same keys. Non-string values are copied through
        unchanged, and placeholders missing from ``args`` stay verbatim.
    """
    values = args or {}

    def _substitute(match):
        key = match.group(1)
        return str(values[key]) if key in values else match.group(0)

    return {
        key: _TEMPLATE_REGEX.sub(_substitute, value) if isinstance(value, str) else value
        for key, value in (params or {}).items()
    }


def _build_url(base_url: str, params: dict[str, Any]) -> str:
    """Append ``params`` to ``base_url`` as a query string.

    Entries whose value equals ``None``, ``""`` or ``[]`` are dropped. List
    values expand into repeated keys (``doseq=True``). The query is joined
    with ``&`` when ``base_url`` already contains a ``?``, otherwise with
    ``?``. If nothing is left to append, ``base_url`` is returned unchanged.
    """
    usable = {}
    for name, value in (params or {}).items():
        if value in (None, "", []):
            continue
        usable[name] = value

    encoded = urlencode(usable, doseq=True) if usable else ""
    if not encoded:
        return base_url

    joiner = "&" if "?" in base_url else "?"
    return f"{base_url}{joiner}{encoded}"


def _validate_domain(url: str, allowed_domains: list[str]) -> None:
    """Check that the host of ``url`` is on the tool's allow-list.

    An entry matches when it equals the hostname, or when it has the form
    ``*.example.com`` and the hostname is ``example.com`` or any subdomain
    of it. Matching ignores case, surrounding whitespace and a trailing
    root dot, because ``urlparse`` already lower-cases the hostname and an
    entry written as ``API.Example.com`` would otherwise never match.

    Args:
        url: The fully resolved request URL.
        allowed_domains: Exact hostnames and/or ``*.`` wildcard patterns.

    Raises:
        ValueError: If the allow-list is empty, the URL has no hostname, or
            the hostname matches no entry.
    """
    if not allowed_domains:
        raise ValueError("No allowed domains configured for this tool")

    hostname = (urlparse(url).hostname or "").rstrip(".")
    if hostname:
        for entry in allowed_domains:
            domain = str(entry).strip().lower().rstrip(".")
            if not domain:
                continue
            if hostname == domain:
                return
            if domain.startswith("*."):
                base = domain[2:]
                # An empty base ("*.") must not turn into a match-anything rule.
                if base and (hostname == base or hostname.endswith(f".{base}")):
                    return

    allowed = ", ".join(map(str, allowed_domains))
    raise ValueError(f"Domain {hostname} is not in the allowed list: {allowed}")


async def _execute_http_tool(tool: dict[str, Any], args: dict[str, Any]) -> dict[str, Any]:
    """Run a user-defined HTTP tool and return its decoded response.

    ``{{name}}`` placeholders in the configured URL and params are filled
    from ``args``. For body-less methods (GET, HEAD) the params become the
    query string; for every other method they are sent as the JSON body.
    The resolved URL must pass the tool's domain allow-list before any
    request is made.

    Args:
        tool: Tool definition; reads ``config.url``, ``config.method``
            (default ``GET``), ``config.params``, ``config.headers`` and
            ``config.security`` (``allowedDomains``, ``maxResponseSize`` in
            bytes, default 1000000; ``timeout`` in milliseconds, default 10000).
        args: Call arguments used to fill the templates.

    Returns:
        The parsed JSON body, or ``{"data": <text>}`` when the body is not
        valid JSON.

    Raises:
        ValueError: If the URL is missing, the host is not allowed, or the
            response body exceeds ``maxResponseSize``.
        httpx.HTTPStatusError: If the response status is an error
            (raised by ``raise_for_status``).
    """
    config = tool.get("config") or {}
    url = config.get("url")
    if not url:
        raise ValueError("HTTP tool missing url")

    method = (config.get("method") or "GET").upper()
    params = config.get("params") or {}
    headers = config.get("headers") or {}
    security = config.get("security") or {}

    allowed_domains = security.get("allowedDomains") or []
    max_response_size = int(security.get("maxResponseSize") or 1000000)
    timeout_ms = int(security.get("timeout") or 10000)

    final_params = _replace_templates(params, args or {})
    processed_url = _replace_template(url, args or {})

    # GET and HEAD carry no body, so their params must travel in the query
    # string; otherwise a HEAD request would silently lose them.
    sends_body = method not in ("GET", "HEAD")
    final_url = processed_url if sends_body else _build_url(processed_url, final_params)

    _validate_domain(final_url, allowed_domains)

    request_headers = {"Content-Type": "application/json"}
    request_headers.update(headers)

    request_kwargs: dict[str, Any] = {"headers": request_headers}
    if sends_body:
        request_kwargs["json"] = final_params

    timeout = httpx.Timeout(timeout_ms / 1000.0)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.request(method, final_url, **request_kwargs)

    response.raise_for_status()

    # Measure the raw payload: the limit and the message are in bytes, and
    # len(text) would under-count multi-byte characters.
    body_size = len(response.content)
    if body_size > max_response_size:
        raise ValueError(
            f"Response size {body_size} bytes exceeds limit of {max_response_size} bytes"
        )

    text = response.text
    try:
        return json.loads(text)
    except (ValueError, RecursionError):
        # Not JSON (or too deeply nested to parse): hand back the raw text.
        return {"data": text}


async def _get_mcp_tools(server_url: str, transport: str, headers: dict[str, Any]) -> Any:
    """Return a connected ``MCPTools`` instance for a remote MCP server.

    Instances are cached per (url, normalized transport, headers). The
    module-level lock is held for the cache lookup and for the connection
    attempt, so callers that miss the cache connect one at a time.

    Args:
        server_url: URL of the MCP server.
        transport: Transport name; normalized before use.
        headers: HTTP headers to send; also part of the cache key, so they
            must be JSON-serializable.

    Raises:
        RuntimeError: If the optional ``MCPTools`` import failed.
    """
    if MCPTools is None:
        raise RuntimeError("`mcp` not installed. Please install using `pip install mcp`.")

    wire = _normalize_mcp_transport(transport)
    cache_key = json.dumps(
        {"url": server_url, "transport": wire, "headers": headers},
        sort_keys=True,
    )

    async with _mcp_tools_lock:
        try:
            return _mcp_tools_cache[cache_key]
        except KeyError:
            pass

        # SSE params take plain seconds; streamable-HTTP params take timedeltas.
        if wire == "sse":
            params_cls = SSEClientParams
            timeouts = {
                "timeout": _REMOTE_MCP_TIMEOUT_SECONDS,
                "sse_read_timeout": _REMOTE_MCP_SSE_READ_TIMEOUT_SECONDS,
            }
        else:
            params_cls = StreamableHTTPClientParams
            timeouts = {
                "timeout": timedelta(seconds=_REMOTE_MCP_TIMEOUT_SECONDS),
                "sse_read_timeout": timedelta(seconds=_REMOTE_MCP_SSE_READ_TIMEOUT_SECONDS),
            }

        server_params = (
            params_cls(url=server_url, headers=headers, **timeouts) if params_cls else None
        )

        connected = MCPTools(
            url=server_url,
            transport=wire,
            server_params=server_params,
            timeout_seconds=_REMOTE_MCP_TIMEOUT_SECONDS,
        )
        await connected.connect()
        # Cache only after a successful connect so a failure can be retried.
        _mcp_tools_cache[cache_key] = connected
        return connected


async def _execute_mcp_tool(tool: dict[str, Any], args: dict[str, Any]) -> Any:
    """Invoke a tool exposed by a remote MCP server.

    Args:
        tool: Tool definition. Reads from ``config``: the server URL
            (``serverUrl`` / ``server_url`` / ``url``), the transport
            (``transport`` / ``serverTransport``), ``headers``, an optional
            bearer token (``bearerToken`` / ``authToken``) and ``toolName``
            (falling back to the tool's own ``name``).
        args: Keyword arguments forwarded to the remote tool.

    Returns:
        Whatever the remote tool's entrypoint returns.

    Raises:
        RuntimeError: If the optional ``MCPTools`` import failed, or the
            server does not expose the requested tool.
        ValueError: If the server URL or the tool name is missing.
    """
    if MCPTools is None:
        raise RuntimeError("`mcp` not installed. Please install using `pip install mcp`.")

    config = tool.get("config") or {}
    server_url = config.get("serverUrl") or config.get("server_url") or config.get("url")
    if not server_url:
        raise ValueError("MCP tool missing serverUrl")
    transport = _normalize_mcp_transport(
        config.get("transport") or config.get("serverTransport") or "streamable-http"
    )

    headers = dict(config.get("headers") or {})
    bearer = config.get("bearerToken") or config.get("authToken")
    # Header names are case-insensitive: an existing "authorization" entry
    # must stop us from adding a second, conflicting "Authorization" header.
    has_auth_header = any(str(name).lower() == "authorization" for name in headers)
    if bearer and not has_auth_header:
        headers["Authorization"] = f"Bearer {bearer}"

    tool_name = config.get("toolName") or tool.get("name")
    if not tool_name:
        raise ValueError("MCP tool missing toolName")

    tools = await _get_mcp_tools(server_url, transport, headers)
    functions = tools.get_async_functions()
    fn = functions.get(tool_name) or tools.get_functions().get(tool_name)
    if not fn or not fn.entrypoint:
        raise RuntimeError(f"MCP tool '{tool_name}' not found in server")

    return await fn.entrypoint(**(args or {}))


def build_user_tools_toolkit(user_tools: list[dict[str, Any]] | None) -> Toolkit | None:
    """Wrap user-defined tool definitions in a single ``user_tools`` Toolkit.

    Each definition with a non-empty ``name`` becomes one ``Function``.
    Definitions whose ``type`` is ``"mcp"`` are routed to the MCP executor;
    all others go to the HTTP executor. The parameter schema is the first
    truthy value among ``input_schema``, ``inputSchema`` and ``parameters``,
    defaulting to an empty object schema.

    Returns:
        The toolkit, or ``None`` when there is no input or no definition
        has a name.
    """
    if not user_tools:
        return None

    schema_keys = ("input_schema", "inputSchema", "parameters")
    wrapped: list[Function] = []

    for tool in user_tools:
        tool_name = tool.get("name")
        if not tool_name:
            continue

        schema = next(
            (tool[key] for key in schema_keys if tool.get(key)),
            {"type": "object", "properties": {}},
        )

        # `_tool=tool` binds the current definition at definition time, so
        # each closure keeps its own tool rather than the loop's last one.
        if tool.get("type") == "mcp":

            async def _mcp_entrypoint(*, _tool=tool, **kwargs):
                return await _execute_mcp_tool(_tool, kwargs)

            entrypoint = _mcp_entrypoint
        else:

            async def _http_entrypoint(*, _tool=tool, **kwargs):
                return await _execute_http_tool(_tool, kwargs)

            entrypoint = _http_entrypoint

        wrapped.append(
            Function(
                name=tool_name,
                description=tool.get("description") or "",
                parameters=schema,
                entrypoint=entrypoint,
                skip_entrypoint_processing=True,
            )
        )

    return Toolkit(name="user_tools", tools=wrapped) if wrapped else None


# Matches `{{name}}` placeholders; group 1 is the name (word characters only).
# Defined after the functions that use it, which works because they look the
# name up at call time, not at definition time.
_TEMPLATE_REGEX = re.compile(r"\{\{(\w+)\}\}")