File size: 7,903 Bytes
a5784e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import json
import os
from typing import Any, Dict, List, cast
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import api_utils.tools_registry
from api_utils.tools_registry import (
    execute_tool_call,
    register_runtime_tools,
    tool_echo,
    tool_get_current_time,
    tool_sum,
)


@pytest.fixture(autouse=True)
def cleanup_registry():
    """Give every test an empty runtime-tool registry and restore it afterwards.

    Clears the runtime allow-list and the runtime MCP endpoint both before the
    test body runs and again during teardown, so no test leaks state into the next.
    """
    registry = api_utils.tools_registry

    def _reset():
        # Same two pieces of module-level state, wiped in the same order each time.
        registry._ALLOWED_RUNTIME_TOOLS.clear()
        registry._runtime_mcp_endpoint = None

    _reset()
    yield
    _reset()


def test_tool_get_current_time():
    """The current-time tool reports a non-empty ``current_time`` value."""
    payload = tool_get_current_time({})
    assert "current_time" in payload
    # Only presence / non-emptiness is checked; the exact timestamp format is not pinned.
    assert payload["current_time"]


def test_tool_echo():
    """The echo tool hands back the parameters it was given under ``echo``."""
    sent = {"key": "value"}
    assert tool_echo(sent)["echo"] == sent


def test_tool_sum():
    """The sum tool totals numeric lists and reports ``sum=None`` otherwise."""
    # Each case: (tool arguments, expected "sum", expected "count").
    cases = [
        ({"values": [1, 2, 3]}, 6.0, 3),  # valid numeric list
        ({"values": ["a", "b"]}, None, 2),  # non-numeric items are still counted
        ({"values": "not a list"}, None, 0),  # wrong container type
        ({}, None, 0),  # "values" key absent
    ]
    for arguments, want_sum, want_count in cases:
        outcome = tool_sum(arguments)
        if want_sum is None:
            assert outcome["sum"] is None
        else:
            assert outcome["sum"] == want_sum
        assert outcome["count"] == want_count


def test_register_runtime_tools_basic():
    """Names come from both the nested ``function.name`` and top-level ``name`` forms."""
    register_runtime_tools([{"function": {"name": "tool1"}}, {"name": "tool2"}])
    # Read the allow-list only after registering, in case the module rebinds it.
    allowed = api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS
    for expected in ("tool1", "tool2"):
        assert expected in allowed


def test_register_runtime_tools_empty():
    """Registering an empty list, then ``None``, leaves the allow-list empty."""
    for no_tools in ([], None):
        register_runtime_tools(no_tools)
        assert len(api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS) == 0


def test_register_runtime_tools_malformed():
    """Malformed tool definitions are tolerated instead of raising."""
    malformed_inputs = (
        # A bare string where a tool dict is expected.
        cast(List[Dict[str, Any]], ["not a dict"]),
        # A dict carrying neither "name" nor "function".
        [{"no_name": "foo"}],
    )
    for tools in malformed_inputs:
        register_runtime_tools(tools)


def test_register_runtime_tools_mcp_endpoint():
    """Test MCP endpoint registration via argument and tool extensions.

    Covers three sources for the runtime MCP endpoint, plus the reset path:
    - the explicit ``mcp_endpoint`` argument;
    - a nested ``function["x-mcp-endpoint"]`` (hyphenated) extension;
    - a top-level ``x_mcp_endpoint`` (underscored) key.
    """
    # Via argument - at least one tool is needed for the endpoint to be processed.
    register_runtime_tools([{"name": "dummy"}], mcp_endpoint="http://mcp")
    assert api_utils.tools_registry._runtime_mcp_endpoint == "http://mcp"

    # Registering an empty tool list resets the endpoint.
    register_runtime_tools([])
    assert api_utils.tools_registry._runtime_mcp_endpoint is None

    # Via the nested, hyphenated "x-mcp-endpoint" extension inside "function".
    tools = [{"function": {"name": "mcp_tool", "x-mcp-endpoint": "http://tool-mcp"}}]
    register_runtime_tools(tools)
    assert "mcp_tool" in api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS
    assert api_utils.tools_registry._runtime_mcp_endpoint == "http://tool-mcp"

    # Via the top-level, underscored "x_mcp_endpoint" key.
    tools = [{"name": "mcp_tool_2", "x_mcp_endpoint": "http://tool-mcp-2"}]
    register_runtime_tools(tools)
    # The tool itself must be registered, not only its endpoint.
    assert "mcp_tool_2" in api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS
    assert api_utils.tools_registry._runtime_mcp_endpoint == "http://tool-mcp-2"


def test_register_runtime_tools_exceptions():
    """Test exception handling during tool registration.

    register_runtime_tools must not propagate errors from odd inputs: a
    non-dict "function" field, a non-iterable tools argument, or an iterable
    whose iteration raises.
    """
    # A non-dict "function" field: the top-level "name" is still used.
    tools = [{"function": "not_a_dict", "name": "tool_weird"}]
    register_runtime_tools(tools)
    assert "tool_weird" in api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS

    # A truthy but non-iterable argument: iterating an int raises TypeError,
    # which register_runtime_tools is expected to swallow.
    register_runtime_tools(
        cast(List[Dict[str, Any]], 123)
    )
    # Reaching this line means no exception escaped.

    # An iterable whose iteration itself raises.
    class BadTools:
        def __iter__(self):
            raise ValueError("Bad")

    register_runtime_tools(cast(List[Dict[str, Any]], BadTools()))


@pytest.mark.asyncio
async def test_execute_tool_call_builtin():
    """Built-in tools are dispatched by name and their results come back as JSON."""
    echo_raw = await execute_tool_call("echo", json.dumps({"msg": "hello"}))
    assert json.loads(echo_raw)["echo"] == {"msg": "hello"}

    sum_raw = await execute_tool_call("sum", json.dumps({"values": [10, 20]}))
    assert json.loads(sum_raw)["sum"] == 30.0


@pytest.mark.asyncio
async def test_execute_tool_call_invalid_json():
    """Unparseable argument JSON falls back to an empty argument dict."""
    raw = await execute_tool_call("echo", "{invalid")
    assert json.loads(raw)["echo"] == {}


@pytest.mark.asyncio
async def test_execute_tool_call_unknown():
    """An unregistered tool name yields an ``error`` payload mentioning 'Unknown tool'."""
    payload = json.loads(await execute_tool_call("unknown_tool", "{}"))
    assert "error" in payload
    assert "Unknown tool" in payload["error"]


@pytest.mark.asyncio
async def test_execute_tool_call_exception():
    """A built-in tool that raises is reported as an 'Execution failed' error."""
    exploding = MagicMock(side_effect=Exception("Boom"))
    builtin_registry = api_utils.tools_registry.FUNCTION_REGISTRY
    # Temporarily register a failing tool; patch.dict restores the registry on exit.
    with patch.dict(builtin_registry, {"fail": exploding}):
        payload = json.loads(await execute_tool_call("fail", "{}"))
        assert "error" in payload
        assert "Execution failed" in payload["error"]


@pytest.mark.asyncio
async def test_execute_tool_call_mcp_runtime():
    """An allowed runtime tool is forwarded to the runtime MCP endpoint exactly once."""
    # Setup runtime tool
    api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS.add("mcp_tool")
    api_utils.tools_registry._runtime_mcp_endpoint = "http://runtime-mcp"

    mock_mcp = AsyncMock(return_value=json.dumps({"result": "mcp_ok"}))
    mcp_adapter_mock = MagicMock()
    mcp_adapter_mock.execute_mcp_tool_with_endpoint = mock_mcp

    # Swap in a fake adapter module so the MCP call is intercepted.
    with patch.dict("sys.modules", {"api_utils.mcp_adapter": mcp_adapter_mock}):
        result = await execute_tool_call("mcp_tool", '{"a": 1}')
        assert result == json.dumps({"result": "mcp_ok"})
        # "once" also catches accidental duplicate MCP calls.
        mock_mcp.assert_awaited_once_with("http://runtime-mcp", "mcp_tool", {"a": 1})


@pytest.mark.asyncio
async def test_execute_tool_call_mcp_env():
    """With no runtime endpoint set, an allowed tool falls back to the env-configured MCP."""
    # Setup runtime tool allowed, but no runtime endpoint, fallback to env
    api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS.add("mcp_env_tool")
    api_utils.tools_registry._runtime_mcp_endpoint = None

    with patch.dict(os.environ, {"MCP_HTTP_ENDPOINT": "http://env-mcp"}):
        mock_mcp = AsyncMock(return_value=json.dumps({"result": "env_ok"}))
        endpoint_mock = AsyncMock()
        mcp_adapter_mock = MagicMock()
        mcp_adapter_mock.execute_mcp_tool = mock_mcp
        # Explicit AsyncMock for the endpoint-specific entry point, so we can
        # assert below that the fallback path does not use it.
        mcp_adapter_mock.execute_mcp_tool_with_endpoint = endpoint_mock

        with patch.dict("sys.modules", {"api_utils.mcp_adapter": mcp_adapter_mock}):
            result = await execute_tool_call("mcp_env_tool", '{"b": 2}')
            assert result == json.dumps({"result": "env_ok"})
            mock_mcp.assert_awaited_once_with("mcp_env_tool", {"b": 2})
            endpoint_mock.assert_not_awaited()


@pytest.mark.asyncio
async def test_execute_tool_call_mcp_fail():
    """An exception from the MCP adapter is reported as an 'MCP execution failed' error."""
    api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS.add("fail_tool")
    api_utils.tools_registry._runtime_mcp_endpoint = "http://fail"

    failing_call = AsyncMock(side_effect=Exception("MCP Down"))
    mcp_adapter_mock = MagicMock()
    mcp_adapter_mock.execute_mcp_tool_with_endpoint = failing_call

    with patch.dict("sys.modules", {"api_utils.mcp_adapter": mcp_adapter_mock}):
        result = await execute_tool_call("fail_tool", "{}")
        data = json.loads(result)
        assert "error" in data
        assert "MCP execution failed" in data["error"]
        # Prove the error really came from the MCP adapter, not an earlier path.
        failing_call.assert_awaited_once_with("http://fail", "fail_tool", {})