File size: 6,346 Bytes
dbb04e4 c3a3710 dbb04e4 c3a3710 dbb04e4 c3a3710 dbb04e4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 | """
MnemoCore MCP Server
====================
MCP bridge exposing MnemoCore API tools for agent clients.
"""
from typing import Any, Callable, Dict
from loguru import logger
from mnemocore.core.config import get_config, HAIMConfig
from mnemocore.mcp.adapters.api_adapter import MnemoCoreAPIAdapter, MnemoCoreAPIError
from mnemocore.mcp.schemas import (
StoreToolInput, QueryToolInput, MemoryIdInput,
ObserveToolInput, ContextToolInput, EpisodeToolInput
)
from mnemocore.core.exceptions import (
DependencyMissingError,
UnsupportedTransportError,
)
def _result_ok(data: Dict[str, Any]) -> Dict[str, Any]:
return {"ok": True, "data": data}
def _result_error(message: str) -> Dict[str, Any]:
return {"ok": False, "error": message}
def build_server(config: HAIMConfig | None = None):
    """Build a FastMCP server that exposes MnemoCore API operations as MCP tools.

    Every tool forwards its arguments to a ``MnemoCoreAPIAdapter`` method and
    returns a uniform envelope: ``{"ok": True, "data": ...}`` on success or
    ``{"ok": False, "error": ...}`` when the adapter call fails. Only tools
    whose names appear in ``cfg.mcp.allow_tools`` are registered.

    Args:
        config: Configuration to use. When ``None`` (or falsy), the result of
            ``get_config()`` is used instead.

    Returns:
        The configured ``FastMCP`` instance. It is not started here; the
        caller is responsible for calling ``run``.

    Raises:
        DependencyMissingError: If the ``mcp`` package cannot be imported.
    """
    cfg = config or get_config()

    # ``mcp`` is imported lazily so a missing install surfaces as a
    # descriptive project error rather than an ImportError at module import.
    try:
        from mcp.server.fastmcp import FastMCP
    except ImportError as exc:
        raise DependencyMissingError(
            dependency="mcp",
            message="Install package 'mcp' to run the MCP server."
        ) from exc

    adapter = MnemoCoreAPIAdapter(
        base_url=cfg.mcp.api_base_url,
        # Prefer the MCP-specific key; fall back to the general security key.
        api_key=cfg.mcp.api_key or cfg.security.api_key,
        timeout_seconds=cfg.mcp.timeout_seconds,
    )
    server = FastMCP("MnemoCore MCP")
    allow_tools = set(cfg.mcp.allow_tools)

    def register_tool(name: str, fn: Callable[[], None]) -> None:
        """Invoke the registrar ``fn`` only when ``name`` is allow-listed."""
        if name in allow_tools:
            fn()
        else:
            # loguru formats messages with str.format-style "{}" placeholders;
            # a "%s" placeholder would be logged literally and drop ``name``.
            logger.info("Skipping disabled MCP tool: {}", name)

    def with_error_handling(call: Callable[[], Dict[str, Any]]) -> Dict[str, Any]:
        """Run ``call`` and wrap its outcome in the ok/error envelope.

        ``MnemoCoreAPIError`` is reported with its message. Any other
        exception is logged with its traceback and reported as an unexpected
        error, so ``call`` never raises past this wrapper.
        """
        try:
            return _result_ok(call())
        except MnemoCoreAPIError as exc:
            return _result_error(str(exc))
        except Exception as exc:
            # Keep the traceback server-side; the client only sees the message.
            logger.exception("Unexpected error in MCP tool call")
            return _result_error(f"Unexpected error: {exc}")

    def register_memory_store() -> None:
        # Tool: validate the store arguments, then forward them to adapter.store.
        @server.tool()
        def memory_store(
            content: str,
            metadata: Dict[str, Any] | None = None,
            agent_id: str | None = None,
            ttl: int | None = None,
        ) -> Dict[str, Any]:
            # Unset optional fields are dropped from the request payload.
            payload = StoreToolInput(
                content=content,
                metadata=metadata,
                agent_id=agent_id,
                ttl=ttl,
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.store(payload))

    def register_memory_query() -> None:
        # Tool: validate the query arguments, then forward them to adapter.query.
        @server.tool()
        def memory_query(
            query: str,
            top_k: int = 5,
            agent_id: str | None = None,
        ) -> Dict[str, Any]:
            payload = QueryToolInput(
                query=query,
                top_k=top_k,
                agent_id=agent_id,
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.query(payload))

    def register_memory_get() -> None:
        # Tool: validate the id, then fetch it via adapter.get_memory.
        @server.tool()
        def memory_get(memory_id: str) -> Dict[str, Any]:
            data = MemoryIdInput(memory_id=memory_id)
            return with_error_handling(lambda: adapter.get_memory(data.memory_id))

    def register_memory_delete() -> None:
        # Tool: validate the id, then delete it via adapter.delete_memory.
        @server.tool()
        def memory_delete(memory_id: str) -> Dict[str, Any]:
            data = MemoryIdInput(memory_id=memory_id)
            return with_error_handling(lambda: adapter.delete_memory(data.memory_id))

    def register_memory_stats() -> None:
        # Tool: return whatever adapter.stats reports.
        @server.tool()
        def memory_stats() -> Dict[str, Any]:
            return with_error_handling(adapter.stats)

    def register_memory_health() -> None:
        # Tool: return whatever adapter.health reports.
        @server.tool()
        def memory_health() -> Dict[str, Any]:
            return with_error_handling(adapter.health)

    # --- Phase 5: Cognitive Client Tools ---
    def register_store_observation() -> None:
        # Tool: validate an observation, then forward it to adapter.observe_context.
        @server.tool()
        def store_observation(
            agent_id: str,
            content: str,
            kind: str = "observation",
            importance: float = 0.5,
            tags: list[str] | None = None
        ) -> Dict[str, Any]:
            payload = ObserveToolInput(
                agent_id=agent_id, content=content, kind=kind, importance=importance, tags=tags
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.observe_context(payload))

    def register_recall_context() -> None:
        # Tool: validate agent/limit, then call adapter.get_working_context.
        @server.tool()
        def recall_context(agent_id: str, limit: int = 16) -> Dict[str, Any]:
            data = ContextToolInput(agent_id=agent_id, limit=limit)
            return with_error_handling(lambda: adapter.get_working_context(data.agent_id, data.limit))

    def register_start_episode() -> None:
        # Tool: validate episode arguments, then forward them to adapter.start_episode.
        @server.tool()
        def start_episode(agent_id: str, goal: str, context: str | None = None) -> Dict[str, Any]:
            payload = EpisodeToolInput(
                agent_id=agent_id, goal=goal, context=context
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.start_episode(payload))

    # Registration is gated per tool name by ``allow_tools``.
    register_tool("memory_store", register_memory_store)
    register_tool("memory_query", register_memory_query)
    register_tool("memory_get", register_memory_get)
    register_tool("memory_delete", register_memory_delete)
    register_tool("memory_stats", register_memory_stats)
    register_tool("memory_health", register_memory_health)
    register_tool("store_observation", register_store_observation)
    register_tool("recall_context", register_recall_context)
    register_tool("start_episode", register_start_episode)
    return server
def main() -> None:
    """Build the MnemoCore MCP server and run it on the configured transport.

    Supported values of ``cfg.mcp.transport`` are ``"stdio"`` and ``"sse"``.
    If ``cfg.mcp.enabled`` is false, a warning is logged but the server is
    still built and started.

    Raises:
        UnsupportedTransportError: If ``cfg.mcp.transport`` is neither
            ``"stdio"`` nor ``"sse"``. The server is built before this check.
    """
    cfg = get_config()
    if not cfg.mcp.enabled:
        # NOTE(review): this only warns and continues — confirm that an
        # explicit launch is meant to override the config flag.
        logger.warning("MCP is disabled in config (haim.mcp.enabled=false)")

    server = build_server(cfg)

    if cfg.mcp.transport == "stdio":
        server.run(transport="stdio")
        return

    if cfg.mcp.transport == "sse":
        # In the official ``mcp`` SDK, ``FastMCP.run()`` accepts only the
        # transport (and a mount path); passing host/port raises TypeError.
        # The SSE bind address is read from ``server.settings`` instead.
        server.settings.host = cfg.mcp.host
        server.settings.port = cfg.mcp.port
        server.run(transport="sse")
        return

    raise UnsupportedTransportError(
        transport=cfg.mcp.transport,
        supported_transports=["stdio", "sse"]
    )
# Start the MCP server when this module is executed directly as a script.
if __name__ == "__main__":
    main()
|