Spaces:
Paused
Paused
File size: 7,903 Bytes
a5784e9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 | import json
import os
from typing import Any, Dict, List, cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import api_utils.tools_registry
from api_utils.tools_registry import (
execute_tool_call,
register_runtime_tools,
tool_echo,
tool_get_current_time,
tool_sum,
)
@pytest.fixture(autouse=True)
def cleanup_registry():
    """Give every test a pristine runtime-tool registry.

    Empties the set of allowed runtime tools and clears the runtime MCP
    endpoint both before the test body runs and again during teardown.
    """

    def _wipe():
        registry = api_utils.tools_registry
        registry._ALLOWED_RUNTIME_TOOLS.clear()
        registry._runtime_mcp_endpoint = None

    _wipe()
    yield
    _wipe()
def test_tool_get_current_time():
    """The current-time tool reports a non-empty ``current_time`` entry."""
    payload = tool_get_current_time({})
    assert "current_time" in payload
    # Only truthiness is checked; the exact timestamp format is not pinned here.
    timestamp = payload["current_time"]
    assert timestamp
def test_tool_echo():
    """The echo tool hands back exactly the parameters it was called with."""
    expected = {"key": "value"}
    assert tool_echo(expected)["echo"] == expected
def test_tool_sum():
    """The sum tool totals numeric lists and degrades gracefully otherwise."""
    cases = [
        # (params, expected sum, expected count)
        ({"values": [1, 2, 3]}, 6.0, 3),  # valid numeric list
        ({"values": ["a", "b"]}, None, 2),  # non-numeric entries
        ({"values": "not a list"}, None, 0),  # wrong container type
        ({}, None, 0),  # "values" key absent
    ]
    for params, want_sum, want_count in cases:
        outcome = tool_sum(params)
        if want_sum is None:
            assert outcome["sum"] is None
        else:
            assert outcome["sum"] == want_sum
        assert outcome["count"] == want_count
def test_register_runtime_tools_basic():
    """Tools named via nested ``function.name`` or flat ``name`` both register."""
    register_runtime_tools([{"function": {"name": "tool1"}}, {"name": "tool2"}])
    allowed = api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS
    for expected_name in ("tool1", "tool2"):
        assert expected_name in allowed
def test_register_runtime_tools_empty():
    """Registering an empty list, then ``None``, leaves the registry empty."""
    for empty_input in ([], None):
        register_runtime_tools(empty_input)
        assert len(api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS) == 0
def test_register_runtime_tools_malformed():
    """Malformed tool definitions are tolerated without raising."""
    malformed_batches = (
        cast(List[Dict[str, Any]], ["not a dict"]),  # entry is not a mapping
        [{"no_name": "foo"}],  # mapping without any name field
    )
    for batch in malformed_batches:
        # The test passes as long as no exception escapes the call.
        register_runtime_tools(batch)
def test_register_runtime_tools_mcp_endpoint():
    """Test MCP endpoint registration via argument and tool extensions.

    Covers three ways the endpoint can be supplied:

    * the ``mcp_endpoint`` keyword argument;
    * an ``x-mcp-endpoint`` key nested inside the ``function`` entry;
    * an ``x_mcp_endpoint`` key at the top level of the tool definition.

    Also checks that registering an empty tool list clears a previously set
    endpoint.
    """
    registry = api_utils.tools_registry

    # Via argument - needs at least one tool to process
    register_runtime_tools([{"name": "dummy"}], mcp_endpoint="http://mcp")
    assert registry._runtime_mcp_endpoint == "http://mcp"

    # Re-registering with no tools resets the endpoint.
    register_runtime_tools([])
    assert registry._runtime_mcp_endpoint is None

    # Via hyphenated extension key nested under "function".
    tools = [{"function": {"name": "mcp_tool", "x-mcp-endpoint": "http://tool-mcp"}}]
    register_runtime_tools(tools)
    assert "mcp_tool" in registry._ALLOWED_RUNTIME_TOOLS
    assert registry._runtime_mcp_endpoint == "http://tool-mcp"

    # Via underscored extension key (x_mcp_endpoint) at the top level.
    tools = [{"name": "mcp_tool_2", "x_mcp_endpoint": "http://tool-mcp-2"}]
    register_runtime_tools(tools)
    assert "mcp_tool_2" in registry._ALLOWED_RUNTIME_TOOLS
    assert registry._runtime_mcp_endpoint == "http://tool-mcp-2"
def test_register_runtime_tools_exceptions():
    """Test that odd tool shapes and iteration errors do not escape registration."""
    # A non-dict "function" entry still lets the top-level "name" register.
    tools = [{"function": "not_a_dict", "name": "tool_weird"}]
    register_runtime_tools(tools)
    assert "tool_weird" in api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS

    # A truthy but non-iterable argument: iterating it raises TypeError, which
    # register_runtime_tools is expected to catch rather than propagate.
    register_runtime_tools(cast(List[Dict[str, Any]], 123))

    # An iterable whose iteration itself raises must also be absorbed.
    class BadTools:
        def __iter__(self):
            raise ValueError("Bad")

    register_runtime_tools(cast(List[Dict[str, Any]], BadTools()))
@pytest.mark.asyncio
async def test_execute_tool_call_builtin():
    """Built-in tools are dispatched by name with JSON-encoded arguments."""
    echo_raw = await execute_tool_call("echo", json.dumps({"msg": "hello"}))
    assert json.loads(echo_raw)["echo"] == {"msg": "hello"}

    sum_raw = await execute_tool_call("sum", json.dumps({"values": [10, 20]}))
    assert json.loads(sum_raw)["sum"] == 30.0
@pytest.mark.asyncio
async def test_execute_tool_call_invalid_json():
    """Unparseable arguments fall back to an empty dict before dispatch."""
    raw = await execute_tool_call("echo", "{invalid")
    echoed = json.loads(raw)["echo"]
    assert echoed == {}
@pytest.mark.asyncio
async def test_execute_tool_call_unknown():
    """An unregistered tool name yields a JSON error payload instead of raising."""
    payload = json.loads(await execute_tool_call("unknown_tool", "{}"))
    assert "error" in payload
    message = payload["error"]
    assert "Unknown tool" in message
@pytest.mark.asyncio
async def test_execute_tool_call_exception():
    """An exception raised inside a registered tool is reported as an error payload."""
    exploding_tool = MagicMock(side_effect=Exception("Boom"))
    overrides = {"fail": exploding_tool}
    with patch.dict(api_utils.tools_registry.FUNCTION_REGISTRY, overrides):
        raw = await execute_tool_call("fail", "{}")
    payload = json.loads(raw)
    assert "error" in payload
    assert "Execution failed" in payload["error"]
@pytest.mark.asyncio
async def test_execute_tool_call_mcp_runtime():
    """Allowed runtime tools are sent to the MCP adapter with the registered endpoint."""
    registry = api_utils.tools_registry
    registry._ALLOWED_RUNTIME_TOOLS.add("mcp_tool")
    registry._runtime_mcp_endpoint = "http://runtime-mcp"

    expected = json.dumps({"result": "mcp_ok"})
    endpoint_call = AsyncMock(return_value=expected)
    # Stand-in for api_utils.mcp_adapter, injected through sys.modules.
    fake_adapter = MagicMock()
    fake_adapter.execute_mcp_tool_with_endpoint = endpoint_call

    with patch.dict("sys.modules", {"api_utils.mcp_adapter": fake_adapter}):
        outcome = await execute_tool_call("mcp_tool", '{"a": 1}')

    assert outcome == expected
    endpoint_call.assert_awaited_with("http://runtime-mcp", "mcp_tool", {"a": 1})
@pytest.mark.asyncio
async def test_execute_tool_call_mcp_env():
    """Without a runtime endpoint, MCP calls use the env-configured adapter entry point.

    The tool is allowed at runtime, but ``_runtime_mcp_endpoint`` is unset and
    ``MCP_HTTP_ENDPOINT`` is provided via the environment. The call is expected
    to go through ``execute_mcp_tool``, which presumably resolves the endpoint
    from the environment (verify against api_utils.mcp_adapter). It should not
    go through ``execute_mcp_tool_with_endpoint``.
    """
    api_utils.tools_registry._ALLOWED_RUNTIME_TOOLS.add("mcp_env_tool")
    api_utils.tools_registry._runtime_mcp_endpoint = None
    with patch.dict(os.environ, {"MCP_HTTP_ENDPOINT": "http://env-mcp"}):
        mock_mcp = AsyncMock(return_value=json.dumps({"result": "env_ok"}))
        mock_with_endpoint = AsyncMock()
        mcp_adapter_mock = MagicMock()
        mcp_adapter_mock.execute_mcp_tool = mock_mcp
        # MagicMock would auto-create this attribute anyway. Making it an
        # explicit AsyncMock lets the test prove the endpoint-specific path is
        # skipped when no runtime endpoint is registered.
        mcp_adapter_mock.execute_mcp_tool_with_endpoint = mock_with_endpoint
        with patch.dict("sys.modules", {"api_utils.mcp_adapter": mcp_adapter_mock}):
            result = await execute_tool_call("mcp_env_tool", '{"b": 2}')
    assert result == json.dumps({"result": "env_ok"})
    mock_mcp.assert_awaited_with("mcp_env_tool", {"b": 2})
    mock_with_endpoint.assert_not_awaited()
@pytest.mark.asyncio
async def test_execute_tool_call_mcp_fail():
    """A failing MCP adapter call is reported as an error payload, not raised."""
    registry = api_utils.tools_registry
    registry._ALLOWED_RUNTIME_TOOLS.add("fail_tool")
    registry._runtime_mcp_endpoint = "http://fail"

    fake_adapter = MagicMock()
    fake_adapter.execute_mcp_tool_with_endpoint = AsyncMock(
        side_effect=Exception("MCP Down")
    )
    with patch.dict("sys.modules", {"api_utils.mcp_adapter": fake_adapter}):
        raw = await execute_tool_call("fail_tool", "{}")

    payload = json.loads(raw)
    assert "error" in payload
    assert "MCP execution failed" in payload["error"]
|