"""
MnemoCore MCP Server
====================
MCP bridge exposing MnemoCore API tools for agent clients.
"""

from typing import Any, Callable, Dict

from loguru import logger

from mnemocore.core.config import get_config, HAIMConfig
from mnemocore.mcp.adapters.api_adapter import MnemoCoreAPIAdapter, MnemoCoreAPIError
from mnemocore.mcp.schemas import (
    StoreToolInput,
    QueryToolInput,
    MemoryIdInput,
    ObserveToolInput,
    ContextToolInput,
    EpisodeToolInput,
)
from mnemocore.core.exceptions import (
    DependencyMissingError,
    UnsupportedTransportError,
)


def _result_ok(data: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap a successful tool result in the ``{"ok": True, "data": ...}`` envelope."""
    return {"ok": True, "data": data}


def _result_error(message: str) -> Dict[str, Any]:
    """Wrap a failure message in the ``{"ok": False, "error": ...}`` envelope."""
    return {"ok": False, "error": message}


def build_server(config: HAIMConfig | None = None):
    """Build a FastMCP server exposing the allow-listed MnemoCore API tools.

    Args:
        config: Configuration to use; ``get_config()`` is used when omitted.

    Returns:
        A ``FastMCP`` instance with every tool named in ``cfg.mcp.allow_tools``
        registered. Host/port for the SSE transport are taken from
        ``cfg.mcp.host`` / ``cfg.mcp.port``.

    Raises:
        DependencyMissingError: if the optional ``mcp`` package is not installed.
    """
    cfg = config or get_config()

    # Imported lazily so the rest of MnemoCore works without the optional "mcp" package.
    try:
        from mcp.server.fastmcp import FastMCP
    except ImportError as exc:
        raise DependencyMissingError(
            dependency="mcp",
            message="Install package 'mcp' to run the MCP server.",
        ) from exc

    adapter = MnemoCoreAPIAdapter(
        base_url=cfg.mcp.api_base_url,
        # An MCP-specific key takes precedence over the general security key.
        api_key=cfg.mcp.api_key or cfg.security.api_key,
        timeout_seconds=cfg.mcp.timeout_seconds,
    )
    # FastMCP reads host/port from its constructor settings; FastMCP.run() does
    # not accept them as keyword arguments.
    server = FastMCP("MnemoCore MCP", host=cfg.mcp.host, port=cfg.mcp.port)
    allow_tools = set(cfg.mcp.allow_tools)

    def register_tool(name: str, fn: Callable[[], None]) -> None:
        """Call the registration function ``fn`` only if ``name`` is allow-listed."""
        if name in allow_tools:
            fn()
        else:
            # loguru formats with str.format-style "{}" placeholders, not "%s".
            logger.info("Skipping disabled MCP tool: {}", name)

    def with_error_handling(call: Callable[[], Dict[str, Any]]) -> Dict[str, Any]:
        """Invoke ``call`` and wrap its outcome in the ok/error result envelope."""
        try:
            return _result_ok(call())
        except MnemoCoreAPIError as exc:
            return _result_error(str(exc))
        except Exception as exc:
            # Tool boundary: report the failure to the client, but keep the
            # traceback in the server logs instead of discarding it.
            logger.exception("Unexpected error while executing MCP tool")
            return _result_error(f"Unexpected error: {exc}")

    def with_validated_input(
        build_input: Callable[[], Any],
        call: Callable[[Any], Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Validate tool arguments, then run ``call`` on the validated input.

        Validation failures are returned in the same error envelope as API
        failures instead of escaping the tool as a raw exception.
        """
        try:
            validated = build_input()
        except ValueError as exc:
            # pydantic v2's ValidationError subclasses ValueError.
            return _result_error(f"Invalid input: {exc}")
        return with_error_handling(lambda: call(validated))

    def register_memory_store() -> None:
        """Register ``memory_store``; forwards a validated payload to ``adapter.store``."""

        @server.tool()
        def memory_store(
            content: str,
            metadata: Dict[str, Any] | None = None,
            agent_id: str | None = None,
            ttl: int | None = None,
        ) -> Dict[str, Any]:
            return with_validated_input(
                lambda: StoreToolInput(
                    content=content,
                    metadata=metadata,
                    agent_id=agent_id,
                    ttl=ttl,
                ).model_dump(exclude_none=True),
                adapter.store,
            )

    def register_memory_query() -> None:
        """Register ``memory_query``; forwards a validated payload to ``adapter.query``."""

        @server.tool()
        def memory_query(
            query: str,
            top_k: int = 5,
            agent_id: str | None = None,
        ) -> Dict[str, Any]:
            return with_validated_input(
                lambda: QueryToolInput(
                    query=query,
                    top_k=top_k,
                    agent_id=agent_id,
                ).model_dump(exclude_none=True),
                adapter.query,
            )

    def register_memory_get() -> None:
        """Register ``memory_get``; calls ``adapter.get_memory`` with a validated id."""

        @server.tool()
        def memory_get(memory_id: str) -> Dict[str, Any]:
            return with_validated_input(
                lambda: MemoryIdInput(memory_id=memory_id),
                lambda data: adapter.get_memory(data.memory_id),
            )

    def register_memory_delete() -> None:
        """Register ``memory_delete``; calls ``adapter.delete_memory`` with a validated id."""

        @server.tool()
        def memory_delete(memory_id: str) -> Dict[str, Any]:
            return with_validated_input(
                lambda: MemoryIdInput(memory_id=memory_id),
                lambda data: adapter.delete_memory(data.memory_id),
            )

    def register_memory_stats() -> None:
        """Register ``memory_stats``; wraps ``adapter.stats``."""

        @server.tool()
        def memory_stats() -> Dict[str, Any]:
            return with_error_handling(adapter.stats)

    def register_memory_health() -> None:
        """Register ``memory_health``; wraps ``adapter.health``."""

        @server.tool()
        def memory_health() -> Dict[str, Any]:
            return with_error_handling(adapter.health)

    # --- Phase 5: Cognitive Client Tools ---

    def register_store_observation() -> None:
        """Register ``store_observation``; forwards to ``adapter.observe_context``."""

        @server.tool()
        def store_observation(
            agent_id: str,
            content: str,
            kind: str = "observation",
            importance: float = 0.5,
            tags: list[str] | None = None,
        ) -> Dict[str, Any]:
            return with_validated_input(
                lambda: ObserveToolInput(
                    agent_id=agent_id,
                    content=content,
                    kind=kind,
                    importance=importance,
                    tags=tags,
                ).model_dump(exclude_none=True),
                adapter.observe_context,
            )

    def register_recall_context() -> None:
        """Register ``recall_context``; calls ``adapter.get_working_context``."""

        @server.tool()
        def recall_context(agent_id: str, limit: int = 16) -> Dict[str, Any]:
            return with_validated_input(
                lambda: ContextToolInput(agent_id=agent_id, limit=limit),
                lambda data: adapter.get_working_context(data.agent_id, data.limit),
            )

    def register_start_episode() -> None:
        """Register ``start_episode``; forwards a validated payload to ``adapter.start_episode``."""

        @server.tool()
        def start_episode(
            agent_id: str,
            goal: str,
            context: str | None = None,
        ) -> Dict[str, Any]:
            return with_validated_input(
                lambda: EpisodeToolInput(
                    agent_id=agent_id,
                    goal=goal,
                    context=context,
                ).model_dump(exclude_none=True),
                adapter.start_episode,
            )

    # Registration order matches the order tools are exposed to clients.
    registrations: Dict[str, Callable[[], None]] = {
        "memory_store": register_memory_store,
        "memory_query": register_memory_query,
        "memory_get": register_memory_get,
        "memory_delete": register_memory_delete,
        "memory_stats": register_memory_stats,
        "memory_health": register_memory_health,
        "store_observation": register_store_observation,
        "recall_context": register_recall_context,
        "start_episode": register_start_episode,
    }
    for tool_name, register in registrations.items():
        register_tool(tool_name, register)

    # Surface allow-list entries that match no tool (e.g. typos in config).
    unknown_tools = allow_tools.difference(registrations)
    if unknown_tools:
        logger.warning(
            "Ignoring unknown MCP tools in allow_tools: {}", sorted(unknown_tools)
        )

    return server


def main() -> None:
    """Build the MCP server from the global config and run it.

    Raises:
        UnsupportedTransportError: if ``cfg.mcp.transport`` is not "stdio" or "sse".
        DependencyMissingError: propagated from ``build_server`` when ``mcp`` is missing.
    """
    cfg = get_config()
    if not cfg.mcp.enabled:
        # NOTE(review): this only warns; the server is still started. Confirm intended.
        logger.warning("MCP is disabled in config (haim.mcp.enabled=false)")

    transport = cfg.mcp.transport
    supported_transports = ["stdio", "sse"]
    # Validate before building so a bad transport fails fast.
    if transport not in supported_transports:
        raise UnsupportedTransportError(
            transport=transport, supported_transports=supported_transports
        )

    server = build_server(cfg)
    # SSE host/port are already applied via the FastMCP constructor settings.
    server.run(transport=transport)


if __name__ == "__main__":
    main()