File size: 4,647 Bytes
dbb04e4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | import sys
import types
import pytest
from mnemocore.core.config import HAIMConfig, MCPConfig, SecurityConfig
from mnemocore.mcp.adapters.api_adapter import MnemoCoreAPIError
from mnemocore.mcp import server as mcp_server
class FakeFastMCP:
    """Recording stand-in for ``FastMCP`` used by these tests.

    Functions registered through :meth:`tool` are kept in ``tools`` and every
    :meth:`run` call is captured in ``run_calls`` instead of starting a server.
    """

    def __init__(self, name: str):
        self.name = name
        # Registered tool callables keyed by the function's ``__name__``.
        self.tools = {}
        # Keyword arguments of each ``run`` invocation, in call order.
        self.run_calls = []

    def tool(self):
        """Return a decorator that registers the decorated function as a tool.

        The function itself is returned unchanged so it stays directly callable.
        """

        def register(func):
            self.tools.update({func.__name__: func})
            return func

        return register

    def run(self, **kwargs):
        """Record *kwargs* rather than actually running anything."""
        self.run_calls.append(kwargs)
class FakeAdapter:
    """Canned-response stand-in for ``MnemoCoreAPIAdapter``.

    Every method returns a fixed dict; ``store`` and ``query`` echo the payload
    they received so tests can inspect what was forwarded.
    """

    def store(self, payload):
        """Pretend to persist *payload* under the id ``mem_1``."""
        response = {"memory_id": "mem_1"}
        response["payload"] = payload
        return response

    def query(self, payload):
        """Return a single perfect-score hit for ``mem_1`` plus the payload."""
        hit = {"id": "mem_1", "score": 1.0}
        response = {"results": [hit]}
        response["payload"] = payload
        return response

    def get_memory(self, memory_id: str):
        """Return a memory record with fixed content for *memory_id*."""
        record = {"id": memory_id}
        record["content"] = "hello"
        return record

    def delete_memory(self, memory_id: str):
        """Report *memory_id* as deleted."""
        return dict(deleted=memory_id)

    def stats(self):
        """Return a fixed engine-version stats payload."""
        return dict(engine_version="3.5.1")

    def health(self):
        """Report a healthy status."""
        return dict(status="healthy")
def _install_fake_mcp_modules(monkeypatch):
    """Install stub ``mcp``, ``mcp.server`` and ``mcp.server.fastmcp`` modules.

    The stubs are registered in ``sys.modules`` through *monkeypatch*, so they
    are undone after the test. ``mcp.server.fastmcp.FastMCP`` resolves to
    :class:`FakeFastMCP`. Presumably ``build_server`` imports ``FastMCP`` from
    there and so ends up with the fake; confirm against the server module.
    """
    mcp_mod = types.ModuleType("mcp")
    server_mod = types.ModuleType("mcp.server")
    fastmcp_mod = types.ModuleType("mcp.server.fastmcp")
    fastmcp_mod.FastMCP = FakeFastMCP

    # Make the stubs look like a real package tree. The parents get a
    # ``__path__`` (marking them as packages), and each child is bound as an
    # attribute of its parent, which is what the import system normally does
    # when it loads a submodule. Without the attribute links, only
    # ``from mcp.server.fastmcp import FastMCP`` works; ``import mcp.server.fastmcp``
    # followed by ``mcp.server.fastmcp.FastMCP`` raises AttributeError.
    mcp_mod.__path__ = []
    server_mod.__path__ = []
    mcp_mod.server = server_mod
    server_mod.fastmcp = fastmcp_mod

    monkeypatch.setitem(sys.modules, "mcp", mcp_mod)
    monkeypatch.setitem(sys.modules, "mcp.server", server_mod)
    monkeypatch.setitem(sys.modules, "mcp.server.fastmcp", fastmcp_mod)
def test_build_server_registers_only_allowlisted_tools(monkeypatch):
    """Only tools listed in ``MCPConfig.allow_tools`` end up on the server."""
    _install_fake_mcp_modules(monkeypatch)

    def make_adapter(*args, **kwargs):
        return FakeAdapter()

    monkeypatch.setattr(mcp_server, "MnemoCoreAPIAdapter", make_adapter)

    security_cfg = SecurityConfig(api_key="test-key")
    mcp_cfg = MCPConfig(
        enabled=True,
        allow_tools=["memory_health", "memory_stats"],
        api_key="test-key",
    )
    config = HAIMConfig(security=security_cfg, mcp=mcp_cfg)

    server = mcp_server.build_server(config)
    assert sorted(server.tools) == ["memory_health", "memory_stats"]

    # The registered health tool wraps the adapter response in an ok/data envelope.
    outcome = server.tools["memory_health"]()
    assert outcome["ok"] is True
    assert outcome["data"]["status"] == "healthy"
def test_tool_error_handling(monkeypatch):
    """An adapter ``MnemoCoreAPIError`` surfaces as an ``ok=False`` tool result."""

    class FailingHealthAdapter(FakeAdapter):
        def health(self):
            raise MnemoCoreAPIError("boom", status_code=503)

    _install_fake_mcp_modules(monkeypatch)

    def make_adapter(*args, **kwargs):
        return FailingHealthAdapter()

    monkeypatch.setattr(mcp_server, "MnemoCoreAPIAdapter", make_adapter)

    security_cfg = SecurityConfig(api_key="test-key")
    mcp_cfg = MCPConfig(enabled=True, allow_tools=["memory_health"], api_key="test-key")
    server = mcp_server.build_server(HAIMConfig(security=security_cfg, mcp=mcp_cfg))

    outcome = server.tools["memory_health"]()
    assert outcome["ok"] is False
    # The adapter's error message must be carried through to the tool result.
    assert "boom" in outcome["error"]
def test_main_runs_with_stdio_transport(monkeypatch):
    """With ``transport="stdio"``, ``main`` calls ``run(transport="stdio")`` only."""
    fake_server = FakeFastMCP("x")

    def fake_get_config():
        return HAIMConfig(
            security=SecurityConfig(api_key="k"),
            mcp=MCPConfig(enabled=True, transport="stdio", api_key="k"),
        )

    def fake_build_server(cfg):
        return fake_server

    monkeypatch.setattr(mcp_server, "get_config", fake_get_config)
    monkeypatch.setattr(mcp_server, "build_server", fake_build_server)

    mcp_server.main()

    expected_call = {"transport": "stdio"}
    assert fake_server.run_calls == [expected_call]
def test_main_runs_with_sse_transport(monkeypatch):
    """With ``transport="sse"``, ``main`` forwards the configured host and port."""
    fake_server = FakeFastMCP("x")

    def fake_get_config():
        sse_cfg = MCPConfig(
            enabled=True,
            transport="sse",
            host="127.0.0.1",
            port=8222,
            api_key="k",
        )
        return HAIMConfig(security=SecurityConfig(api_key="k"), mcp=sse_cfg)

    def fake_build_server(cfg):
        return fake_server

    monkeypatch.setattr(mcp_server, "get_config", fake_get_config)
    monkeypatch.setattr(mcp_server, "build_server", fake_build_server)

    mcp_server.main()

    expected_call = {"transport": "sse", "host": "127.0.0.1", "port": 8222}
    assert fake_server.run_calls == [expected_call]
def test_main_rejects_unknown_transport(monkeypatch):
    """``main`` raises ``ValueError`` for an unrecognised transport name.

    The server built for the config must never be started.
    """
    fake_server = FakeFastMCP("x")
    monkeypatch.setattr(
        mcp_server,
        "get_config",
        lambda: HAIMConfig(
            security=SecurityConfig(api_key="k"),
            mcp=MCPConfig(enabled=True, transport="unknown", api_key="k"),
        ),
    )
    monkeypatch.setattr(mcp_server, "build_server", lambda cfg: fake_server)

    # The previous ``(ValueError, Exception)`` tuple collapsed to a bare
    # ``Exception``, so any error whose message matched would pass. Pin the
    # expected type so unrelated failures are not mistaken for this check.
    with pytest.raises(ValueError, match="Unsupported transport"):
        mcp_server.main()

    # The rejection must happen before the server is started.
    assert fake_server.run_calls == []
|