MnemoCore / src /mnemocore /mcp /server.py
Granis87's picture
Upload folder using huggingface_hub
c3a3710 verified
"""
MnemoCore MCP Server
====================
MCP bridge exposing MnemoCore API tools for agent clients.
"""
from typing import Any, Callable, Dict
from loguru import logger
from mnemocore.core.config import get_config, HAIMConfig
from mnemocore.mcp.adapters.api_adapter import MnemoCoreAPIAdapter, MnemoCoreAPIError
from mnemocore.mcp.schemas import (
StoreToolInput, QueryToolInput, MemoryIdInput,
ObserveToolInput, ContextToolInput, EpisodeToolInput
)
from mnemocore.core.exceptions import (
DependencyMissingError,
UnsupportedTransportError,
)
def _result_ok(data: Dict[str, Any]) -> Dict[str, Any]:
return {"ok": True, "data": data}
def _result_error(message: str) -> Dict[str, Any]:
return {"ok": False, "error": message}
def build_server(config: HAIMConfig | None = None):
    """Build a FastMCP server that exposes MnemoCore API operations as MCP tools.

    Every tool forwards to a ``MnemoCoreAPIAdapter`` call and returns either
    ``{"ok": True, "data": ...}`` or ``{"ok": False, "error": ...}``.
    Only tools whose names appear in ``cfg.mcp.allow_tools`` are registered.

    Args:
        config: Configuration to use; ``get_config()`` is used when None.

    Returns:
        The configured ``FastMCP`` server instance (not yet running).

    Raises:
        DependencyMissingError: if the optional ``mcp`` package is not installed.
    """
    cfg = config if config is not None else get_config()
    try:
        from mcp.server.fastmcp import FastMCP
    except ImportError as exc:
        raise DependencyMissingError(
            dependency="mcp",
            message="Install package 'mcp' to run the MCP server."
        ) from exc

    adapter = MnemoCoreAPIAdapter(
        base_url=cfg.mcp.api_base_url,
        # An MCP-specific key, when set, takes precedence over the general API key.
        api_key=cfg.mcp.api_key or cfg.security.api_key,
        timeout_seconds=cfg.mcp.timeout_seconds,
    )
    server = FastMCP("MnemoCore MCP")
    allow_tools = set(cfg.mcp.allow_tools)

    def register_tool(name: str, fn: Callable[[], None]) -> None:
        """Call registrar ``fn`` only when ``name`` is enabled in ``allow_tools``."""
        if name in allow_tools:
            fn()
        else:
            # loguru uses str.format-style ``{}`` placeholders, not ``%s``.
            logger.info("Skipping disabled MCP tool: {}", name)

    def with_error_handling(call: Callable[[], Dict[str, Any]]) -> Dict[str, Any]:
        """Run ``call`` and wrap its outcome in the ok/error result envelope."""
        try:
            return _result_ok(call())
        except MnemoCoreAPIError as exc:
            return _result_error(str(exc))
        except Exception as exc:
            # Report to the client, but keep the traceback in the server log
            # rather than discarding it.
            logger.exception("Unexpected error in MCP tool call")
            return _result_error(f"Unexpected error: {exc}")

    def register_memory_store() -> None:
        @server.tool()
        def memory_store(
            content: str,
            metadata: Dict[str, Any] | None = None,
            agent_id: str | None = None,
            ttl: int | None = None,
        ) -> Dict[str, Any]:
            """Store a new memory through the MnemoCore API."""
            payload = StoreToolInput(
                content=content,
                metadata=metadata,
                agent_id=agent_id,
                ttl=ttl,
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.store(payload))

    def register_memory_query() -> None:
        @server.tool()
        def memory_query(
            query: str,
            top_k: int = 5,
            agent_id: str | None = None,
        ) -> Dict[str, Any]:
            """Query stored memories through the MnemoCore API."""
            payload = QueryToolInput(
                query=query,
                top_k=top_k,
                agent_id=agent_id,
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.query(payload))

    def register_memory_get() -> None:
        @server.tool()
        def memory_get(memory_id: str) -> Dict[str, Any]:
            """Fetch a single memory by its ID."""
            data = MemoryIdInput(memory_id=memory_id)
            return with_error_handling(lambda: adapter.get_memory(data.memory_id))

    def register_memory_delete() -> None:
        @server.tool()
        def memory_delete(memory_id: str) -> Dict[str, Any]:
            """Delete a single memory by its ID."""
            data = MemoryIdInput(memory_id=memory_id)
            return with_error_handling(lambda: adapter.delete_memory(data.memory_id))

    def register_memory_stats() -> None:
        @server.tool()
        def memory_stats() -> Dict[str, Any]:
            """Return statistics reported by the MnemoCore API."""
            return with_error_handling(adapter.stats)

    def register_memory_health() -> None:
        @server.tool()
        def memory_health() -> Dict[str, Any]:
            """Return the health status reported by the MnemoCore API."""
            return with_error_handling(adapter.health)

    # --- Phase 5: Cognitive Client Tools ---
    def register_store_observation() -> None:
        @server.tool()
        def store_observation(
            agent_id: str,
            content: str,
            kind: str = "observation",
            importance: float = 0.5,
            tags: list[str] | None = None
        ) -> Dict[str, Any]:
            """Record an observation for an agent through the MnemoCore API."""
            payload = ObserveToolInput(
                agent_id=agent_id, content=content, kind=kind, importance=importance, tags=tags
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.observe_context(payload))

    def register_recall_context() -> None:
        @server.tool()
        def recall_context(agent_id: str, limit: int = 16) -> Dict[str, Any]:
            """Fetch an agent's working context through the MnemoCore API."""
            data = ContextToolInput(agent_id=agent_id, limit=limit)
            return with_error_handling(lambda: adapter.get_working_context(data.agent_id, data.limit))

    def register_start_episode() -> None:
        @server.tool()
        def start_episode(agent_id: str, goal: str, context: str | None = None) -> Dict[str, Any]:
            """Start a new episode for an agent with the given goal."""
            payload = EpisodeToolInput(
                agent_id=agent_id, goal=goal, context=context
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.start_episode(payload))

    # Tool name -> registrar; dict insertion order fixes the registration order.
    registrars: Dict[str, Callable[[], None]] = {
        "memory_store": register_memory_store,
        "memory_query": register_memory_query,
        "memory_get": register_memory_get,
        "memory_delete": register_memory_delete,
        "memory_stats": register_memory_stats,
        "memory_health": register_memory_health,
        "store_observation": register_store_observation,
        "recall_context": register_recall_context,
        "start_episode": register_start_episode,
    }
    for tool_name, registrar in registrars.items():
        register_tool(tool_name, registrar)

    # Names in allow_tools that match no tool would otherwise be ignored
    # silently, hiding configuration typos.
    unknown_tools = allow_tools.difference(registrars)
    if unknown_tools:
        logger.warning(
            "Ignoring unknown MCP tool name(s) in allow_tools: {}",
            ", ".join(sorted(map(str, unknown_tools))),
        )
    return server
def main() -> None:
    """Build the MCP server from the global config and run it on the configured transport.

    Supported transports are ``"stdio"`` and ``"sse"``. The ``"sse"`` transport
    binds to ``mcp.host`` / ``mcp.port``.

    Raises:
        UnsupportedTransportError: if ``mcp.transport`` is any other value.
    """
    settings = get_config()
    if not settings.mcp.enabled:
        # Only a warning: the server is still built and started below.
        # NOTE(review): confirm that running while disabled is intended.
        logger.warning("MCP is disabled in config (haim.mcp.enabled=false)")

    # The server is built before the transport is validated, so dependency
    # errors surface first.
    mcp_server = build_server(settings)
    transport = settings.mcp.transport

    if transport == "stdio":
        mcp_server.run(transport="stdio")
    elif transport == "sse":
        mcp_server.run(transport="sse", host=settings.mcp.host, port=settings.mcp.port)
    else:
        raise UnsupportedTransportError(
            transport=transport,
            supported_transports=["stdio", "sse"],
        )
# Start the MCP server when this module is executed directly.
if __name__ == "__main__":
    main()