File size: 6,346 Bytes
dbb04e4
 
 
 
 
 
 
 
 
 
 
c3a3710
 
 
 
dbb04e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3a3710
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbb04e4
 
 
 
 
 
c3a3710
 
 
dbb04e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
MnemoCore MCP Server
====================

MCP bridge exposing MnemoCore API tools for agent clients.
"""

from typing import Any, Callable, Dict
from loguru import logger

from mnemocore.core.config import get_config, HAIMConfig
from mnemocore.mcp.adapters.api_adapter import MnemoCoreAPIAdapter, MnemoCoreAPIError
from mnemocore.mcp.schemas import (
    StoreToolInput, QueryToolInput, MemoryIdInput,
    ObserveToolInput, ContextToolInput, EpisodeToolInput
)
from mnemocore.core.exceptions import (
    DependencyMissingError,
    UnsupportedTransportError,
)


def _result_ok(data: Dict[str, Any]) -> Dict[str, Any]:
    return {"ok": True, "data": data}


def _result_error(message: str) -> Dict[str, Any]:
    return {"ok": False, "error": message}


def build_server(config: HAIMConfig | None = None):
    """Build a FastMCP server exposing MnemoCore API operations as MCP tools.

    Only tools whose names appear in ``cfg.mcp.allow_tools`` are registered.
    Each tool forwards its (schema-validated) arguments to a
    ``MnemoCoreAPIAdapter`` and returns either ``{"ok": True, "data": ...}``
    or ``{"ok": False, "error": ...}``.

    Args:
        config: Configuration to use. Falls back to ``get_config()`` when
            ``None`` (or otherwise falsy).

    Returns:
        The configured ``FastMCP`` server instance; it is not started here.

    Raises:
        DependencyMissingError: If the optional ``mcp`` package is not installed.
    """
    cfg = config or get_config()

    # Imported lazily so the rest of the package does not require ``mcp``.
    try:
        from mcp.server.fastmcp import FastMCP
    except ImportError as exc:
        raise DependencyMissingError(
            dependency="mcp",
            message="Install package 'mcp' to run the MCP server."
        ) from exc

    adapter = MnemoCoreAPIAdapter(
        base_url=cfg.mcp.api_base_url,
        # The MCP-specific key wins; otherwise fall back to the general API key.
        api_key=cfg.mcp.api_key or cfg.security.api_key,
        timeout_seconds=cfg.mcp.timeout_seconds,
    )

    server = FastMCP("MnemoCore MCP")
    allow_tools = set(cfg.mcp.allow_tools)

    def register_tool(name: str, fn: Callable[[], None]) -> None:
        """Run registrar ``fn`` only when tool ``name`` is enabled in config."""
        if name in allow_tools:
            fn()
        else:
            # loguru formats messages with ``{}`` placeholders, not ``%s``.
            logger.info("Skipping disabled MCP tool: {}", name)

    def with_error_handling(call: Callable[[], Dict[str, Any]]) -> Dict[str, Any]:
        """Invoke ``call`` and wrap its result or failure in a result envelope."""
        try:
            return _result_ok(call())
        except MnemoCoreAPIError as exc:
            return _result_error(str(exc))
        except Exception as exc:
            # The client only sees the message; keep the traceback server-side.
            logger.exception("Unexpected error in MCP tool call")
            return _result_error(f"Unexpected error: {exc}")

    def register_memory_store() -> None:
        @server.tool()
        def memory_store(
            content: str,
            metadata: Dict[str, Any] | None = None,
            agent_id: str | None = None,
            ttl: int | None = None,
        ) -> Dict[str, Any]:
            """Store ``content`` as a memory with optional ``metadata``, ``agent_id`` and ``ttl``."""
            payload = StoreToolInput(
                content=content,
                metadata=metadata,
                agent_id=agent_id,
                ttl=ttl,
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.store(payload))

    def register_memory_query() -> None:
        @server.tool()
        def memory_query(
            query: str,
            top_k: int = 5,
            agent_id: str | None = None,
        ) -> Dict[str, Any]:
            """Query stored memories with ``query``; ``top_k`` and ``agent_id`` are passed to the API."""
            payload = QueryToolInput(
                query=query,
                top_k=top_k,
                agent_id=agent_id,
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.query(payload))

    def register_memory_get() -> None:
        @server.tool()
        def memory_get(memory_id: str) -> Dict[str, Any]:
            """Fetch the memory identified by ``memory_id``."""
            data = MemoryIdInput(memory_id=memory_id)
            return with_error_handling(lambda: adapter.get_memory(data.memory_id))

    def register_memory_delete() -> None:
        @server.tool()
        def memory_delete(memory_id: str) -> Dict[str, Any]:
            """Delete the memory identified by ``memory_id``."""
            data = MemoryIdInput(memory_id=memory_id)
            return with_error_handling(lambda: adapter.delete_memory(data.memory_id))

    def register_memory_stats() -> None:
        @server.tool()
        def memory_stats() -> Dict[str, Any]:
            """Return statistics reported by the MnemoCore API."""
            return with_error_handling(adapter.stats)

    def register_memory_health() -> None:
        @server.tool()
        def memory_health() -> Dict[str, Any]:
            """Return the health status reported by the MnemoCore API."""
            return with_error_handling(adapter.health)

    # --- Phase 5: Cognitive Client Tools ---

    def register_store_observation() -> None:
        @server.tool()
        def store_observation(
            agent_id: str,
            content: str,
            kind: str = "observation",
            importance: float = 0.5,
            tags: list[str] | None = None,
        ) -> Dict[str, Any]:
            """Record an observation (``kind``, ``importance``, ``tags``) for ``agent_id``."""
            payload = ObserveToolInput(
                agent_id=agent_id, content=content, kind=kind, importance=importance, tags=tags
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.observe_context(payload))

    def register_recall_context() -> None:
        @server.tool()
        def recall_context(agent_id: str, limit: int = 16) -> Dict[str, Any]:
            """Return the working context for ``agent_id``; ``limit`` is passed to the API."""
            data = ContextToolInput(agent_id=agent_id, limit=limit)
            return with_error_handling(lambda: adapter.get_working_context(data.agent_id, data.limit))

    def register_start_episode() -> None:
        @server.tool()
        def start_episode(agent_id: str, goal: str, context: str | None = None) -> Dict[str, Any]:
            """Start an episode for ``agent_id`` with ``goal`` and optional ``context``."""
            payload = EpisodeToolInput(
                agent_id=agent_id, goal=goal, context=context
            ).model_dump(exclude_none=True)
            return with_error_handling(lambda: adapter.start_episode(payload))

    # Registration order is preserved. Names must match the tool function names
    # so that ``allow_tools`` entries line up with what clients see.
    registrars = (
        ("memory_store", register_memory_store),
        ("memory_query", register_memory_query),
        ("memory_get", register_memory_get),
        ("memory_delete", register_memory_delete),
        ("memory_stats", register_memory_stats),
        ("memory_health", register_memory_health),
        ("store_observation", register_store_observation),
        ("recall_context", register_recall_context),
        ("start_episode", register_start_episode),
    )
    for tool_name, registrar in registrars:
        register_tool(tool_name, registrar)

    # A misspelled entry in allow_tools would otherwise silently disable a tool.
    unknown_tools = allow_tools - {tool_name for tool_name, _ in registrars}
    if unknown_tools:
        logger.warning(
            "Ignoring unknown MCP tool name(s) in allow_tools: {}",
            ", ".join(sorted(str(name) for name in unknown_tools)),
        )

    return server


def main() -> None:
    cfg = get_config()
    if not cfg.mcp.enabled:
        logger.warning("MCP is disabled in config (haim.mcp.enabled=false)")

    server = build_server(cfg)

    if cfg.mcp.transport == "stdio":
        server.run(transport="stdio")
        return

    if cfg.mcp.transport == "sse":
        server.run(transport="sse", host=cfg.mcp.host, port=cfg.mcp.port)
        return

    raise UnsupportedTransportError(
        transport=cfg.mcp.transport,
        supported_transports=["stdio", "sse"]
    )


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()