| """Tests for model_tools.py — function call dispatch, agent-loop interception, legacy toolsets.""" |
|
|
| import json |
| from unittest.mock import call, patch |
|
|
| import pytest |
|
|
| from model_tools import ( |
| handle_function_call, |
| get_all_tool_names, |
| get_toolset_for_tool, |
| _AGENT_LOOP_TOOLS, |
| _LEGACY_TOOLSET_MAP, |
| TOOL_TO_TOOLSET_MAP, |
| ) |
|
|
|
|
| |
| |
| |
|
|
class TestHandleFunctionCall:
    """Dispatch-level behavior of ``handle_function_call``."""

    def test_agent_loop_tool_returns_error(self):
        # Tools in _AGENT_LOOP_TOOLS must not be dispatched from here; each
        # one is expected to come back as a JSON error mentioning the agent loop.
        for name in _AGENT_LOOP_TOOLS:
            payload = json.loads(handle_function_call(name, {}))
            assert "error" in payload
            assert "agent loop" in payload["error"].lower()

    def test_unknown_tool_returns_error(self):
        bogus = "totally_fake_tool_xyz"
        payload = json.loads(handle_function_call(bogus, {}))
        assert "error" in payload
        # The error message should name the tool that could not be found.
        assert bogus in payload["error"]

    def test_exception_returns_json_error(self):
        # ``None`` instead of an argument mapping is expected to make the call
        # fail; the failure must surface as a JSON error object, not an
        # uncaught exception.
        parsed = json.loads(handle_function_call("web_search", None))
        assert isinstance(parsed, dict)
        assert "error" in parsed
        message = parsed["error"]
        assert len(message) > 0
        lowered = message.lower()
        assert "error" in lowered or "failed" in lowered

    def test_tool_hooks_receive_session_and_tool_call_ids(self):
        canned = '{"ok":true}'
        ids = {
            "task_id": "task-1",
            "session_id": "session-1",
            "tool_call_id": "call-1",
        }

        def expected(hook_name, **extra):
            # Build a fresh ``args`` dict for every expectation so the
            # comparison is independent of the dict passed into the call.
            return call(
                hook_name,
                tool_name="web_search",
                args={"q": "test"},
                **extra,
                **ids,
            )

        with (
            patch("model_tools.registry.dispatch", return_value=canned),
            patch("hermes_cli.plugins.invoke_hook") as hook_spy,
        ):
            outcome = handle_function_call("web_search", {"q": "test"}, **ids)

        assert outcome == canned
        # Every hook must see the same task/session/tool-call identifiers;
        # the post/transform hooks additionally receive the dispatch result.
        assert hook_spy.call_args_list == [
            expected("pre_tool_call"),
            expected("post_tool_call", result=canned),
            expected("transform_tool_result", result=canned),
        ]
|
|
|
|
| |
| |
| |
|
|
class TestAgentLoopTools:
    """Membership checks for the ``_AGENT_LOOP_TOOLS`` interception set."""

    def test_expected_tools_in_set(self):
        # These tools are handled by the agent loop itself.
        for name in ("todo", "memory", "session_search", "delegate_task"):
            assert name in _AGENT_LOOP_TOOLS

    def test_no_regular_tools_in_set(self):
        # Ordinary registry tools must not be intercepted.
        for name in ("web_search", "terminal"):
            assert name not in _AGENT_LOOP_TOOLS
|
|
|
|
| |
| |
| |
|
|
class TestPreToolCallBlocking:
    """Verify that pre_tool_call hooks can block tool execution."""

    def test_blocked_tool_returns_error_and_skips_dispatch(self, monkeypatch):
        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [{"action": "block", "message": "Blocked by policy"}]
            return []

        dispatch_called = False

        def fake_dispatch(*args, **kwargs):
            # Record the call before raising: handle_function_call may turn
            # exceptions into JSON errors, so the flag is the reliable signal.
            nonlocal dispatch_called
            dispatch_called = True
            raise AssertionError("dispatch should not run when blocked")

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch)

        result = json.loads(
            handle_function_call("read_file", {"path": "test.txt"}, task_id="t1")
        )
        assert result == {"error": "Blocked by policy"}
        assert not dispatch_called

    def test_blocked_tool_skips_read_loop_notification(self, monkeypatch):
        notifications = []

        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [{"action": "block", "message": "Blocked"}]
            return []

        def failing_dispatch(*args, **kwargs):
            raise AssertionError("should not run")

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch", failing_dispatch)
        monkeypatch.setattr(
            "tools.file_tools.notify_other_tool_call",
            lambda task_id: notifications.append(task_id),
        )

        result = json.loads(
            handle_function_call("web_search", {"q": "test"}, task_id="t1")
        )
        assert result == {"error": "Blocked"}
        # A blocked call must not be reported to the read-loop tracker.
        assert notifications == []

    def test_invalid_hook_returns_do_not_block(self, monkeypatch):
        """Malformed hook returns should be ignored — tool executes normally."""

        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [
                    "block",  # not a dict
                    {"action": "block"},  # missing message
                    {"action": "deny", "message": "nope"},  # unknown action
                ]
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr(
            "model_tools.registry.dispatch",
            lambda *a, **kw: json.dumps({"ok": True}),
        )

        result = json.loads(
            handle_function_call("read_file", {"path": "test.txt"}, task_id="t1")
        )
        assert result == {"ok": True}

    def test_skip_flag_prevents_double_block_check(self, monkeypatch):
        """When skip_pre_tool_call_hook=True, blocking is not checked (caller did it).

        The pre_tool_call hook returns a block directive here, so the tool
        only runs if handle_function_call really skips the block check. The
        pre_tool_call and post_tool_call hooks are still expected to fire.
        """
        hook_calls = []

        def fake_invoke_hook(hook_name, **kwargs):
            hook_calls.append(hook_name)
            if hook_name == "pre_tool_call":
                # Would stop execution if the block check were not skipped.
                return [{"action": "block", "message": "Blocked"}]
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr(
            "model_tools.registry.dispatch",
            lambda *a, **kw: json.dumps({"ok": True}),
        )

        result = json.loads(
            handle_function_call(
                "web_search",
                {"q": "test"},
                task_id="t1",
                skip_pre_tool_call_hook=True,
            )
        )

        assert result == {"ok": True}
        assert "pre_tool_call" in hook_calls
        assert "post_tool_call" in hook_calls
|
|
|
|
| |
| |
| |
|
|
class TestLegacyToolsetMap:
    """Shape of the ``_LEGACY_TOOLSET_MAP`` backward-compatibility table."""

    def test_expected_legacy_names(self):
        legacy_names = (
            "web_tools", "terminal_tools", "vision_tools", "moa_tools",
            "image_tools", "skills_tools", "browser_tools", "cronjob_tools",
            "rl_tools", "file_tools", "tts_tools",
        )
        for legacy in legacy_names:
            assert legacy in _LEGACY_TOOLSET_MAP, f"Missing legacy toolset: {legacy}"

    def test_values_are_lists_of_strings(self):
        # Each legacy toolset maps to a plain list of tool-name strings.
        for toolset, members in _LEGACY_TOOLSET_MAP.items():
            assert isinstance(members, list), f"{toolset} is not a list"
            for member in members:
                assert isinstance(member, str), (
                    f"{toolset} contains non-string: {member}"
                )
|
|
|
|
| |
| |
| |
|
|
class TestBackwardCompat:
    """Legacy helpers and maps exported by ``model_tools`` stay usable."""

    def test_get_all_tool_names_returns_list(self):
        all_names = get_all_tool_names()
        assert isinstance(all_names, list)
        assert all_names, "expected at least one registered tool"
        # Core tools must always be registered.
        for required in ("web_search", "terminal"):
            assert required in all_names

    def test_get_toolset_for_tool(self):
        toolset = get_toolset_for_tool("web_search")
        assert toolset is not None
        assert isinstance(toolset, str)

    def test_get_toolset_for_unknown_tool(self):
        assert get_toolset_for_tool("totally_nonexistent_tool") is None

    def test_tool_to_toolset_map(self):
        assert isinstance(TOOL_TO_TOOLSET_MAP, dict)
        assert TOOL_TO_TOOLSET_MAP, "expected a non-empty tool-to-toolset map"
|
|