File size: 8,809 Bytes
afefaea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""MCP Tool Registry for dynamic tool discovery and management."""

import asyncio
import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Optional

from app.utils.logging import get_logger

# Module-level logger for the registry, obtained from the project's logging helper.
logger = get_logger(__name__)


class ToolStatus(Enum):
    """Lifecycle / health state the registry records for each tool."""

    # No health information (default for a fresh ToolDefinition).
    UNKNOWN = "unknown"
    # The most recent health check completed without raising.
    HEALTHY = "healthy"
    # The most recent health check raised an exception.
    UNHEALTHY = "unhealthy"
    # Assigned on registration, before any health check has run.
    INITIALIZING = "initializing"
    # Assigned to every tool when the registry shuts down.
    SHUTDOWN = "shutdown"


@dataclass
class ToolDefinition:
    """One registry entry: a tool's identity, callable, schema and current state."""

    # Unique key under which the tool is registered.
    name: str
    # Human-readable summary of the tool.
    description: str
    # Callable invoked with keyword arguments when the tool executes.
    handler: Callable[..., Any]
    # JSON schema describing the handler's parameters (fresh dict per instance).
    parameters: dict[str, Any] = field(default_factory=dict)
    # Current lifecycle / health state.
    status: ToolStatus = ToolStatus.UNKNOWN
    # Free-form extra information supplied at registration (fresh dict per instance).
    metadata: dict[str, Any] = field(default_factory=dict)


class MCPToolRegistry:
    """
    Registry for MCP tools with dynamic discovery and execution.

    Manages tool lifecycle including registration, periodic health checks,
    and execution routing. A tool handler (and its optional ``health_check``
    attribute) may be a plain callable or may return an awaitable; awaitable
    results are awaited transparently.
    """

    def __init__(self) -> None:
        # Registered tools keyed by their unique name.
        self._tools: dict[str, ToolDefinition] = {}
        self._initialized: bool = False
        # Seconds between background health-check sweeps.
        self._health_check_interval: float = 30.0
        # Background task running _health_check_loop while initialized.
        self._health_check_task: Optional[asyncio.Task[None]] = None

    async def initialize(self) -> None:
        """
        Initialize the registry and start health monitoring.

        Schedules the periodic health-check loop with ``asyncio.create_task``,
        so it must be awaited from within a running event loop. Calling it
        again while already initialized only logs a warning.
        """
        if self._initialized:
            logger.warning("Registry already initialized")
            return

        logger.info("Initializing MCP Tool Registry")

        # Start health check background task
        self._health_check_task = asyncio.create_task(self._health_check_loop())
        self._initialized = True

        logger.info("MCP Tool Registry initialized")

    async def shutdown(self) -> None:
        """
        Shut down the registry and clean up resources.

        Cancels the background health-check task (if any), waits for it to
        finish, and marks every registered tool as ``ToolStatus.SHUTDOWN``.
        Calling it more than once is harmless.
        """
        logger.info("Shutting down MCP Tool Registry")

        # Detach the task first so a later initialize()/shutdown() never sees
        # a stale, already-finished task.
        task = self._health_check_task
        self._health_check_task = None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Mark all tools as shutdown
        for tool in self._tools.values():
            tool.status = ToolStatus.SHUTDOWN

        self._initialized = False
        logger.info("MCP Tool Registry shutdown complete")

    def register(
        self,
        name: str,
        handler: Callable[..., Any],
        description: str = "",
        parameters: Optional[dict[str, Any]] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> ToolDefinition:
        """
        Register a new tool with the registry.

        The tool starts in ``ToolStatus.INITIALIZING`` until a health check
        updates it.

        Args:
            name: Unique tool name
            handler: Callable that implements the tool
            description: Human-readable description
            parameters: JSON schema for tool parameters
            metadata: Additional tool metadata

        Returns:
            The registered ToolDefinition

        Raises:
            ValueError: If a tool with the same name already exists
        """
        if name in self._tools:
            raise ValueError(f"Tool '{name}' is already registered")

        tool = ToolDefinition(
            name=name,
            description=description,
            handler=handler,
            parameters=parameters or {},
            status=ToolStatus.INITIALIZING,
            metadata=metadata or {},
        )

        self._tools[name] = tool
        logger.info(f"Registered tool: {name}")

        return tool

    def unregister(self, name: str) -> bool:
        """
        Unregister a tool from the registry.

        Args:
            name: Tool name to unregister

        Returns:
            True if tool was removed, False if not found
        """
        if self._tools.pop(name, None) is None:
            return False
        logger.info(f"Unregistered tool: {name}")
        return True

    def get(self, name: str) -> Optional[ToolDefinition]:
        """
        Get a tool definition by name.

        Args:
            name: Tool name to retrieve

        Returns:
            ToolDefinition if found, None otherwise
        """
        return self._tools.get(name)

    def list_tools(
        self,
        include_unhealthy: bool = False,
    ) -> list[ToolDefinition]:
        """
        List registered tools.

        Args:
            include_unhealthy: When False (default), tools whose status is
                UNHEALTHY or SHUTDOWN are omitted.

        Returns:
            List of tool definitions, in registration order
        """
        tools = list(self._tools.values())
        if include_unhealthy:
            return tools

        excluded = (ToolStatus.UNHEALTHY, ToolStatus.SHUTDOWN)
        return [t for t in tools if t.status not in excluded]

    async def execute(
        self,
        name: str,
        **kwargs: Any,
    ) -> Any:
        """
        Execute a tool by name with the given parameters.

        The handler is called with ``**kwargs``. If it returns an awaitable
        (an ``async def`` function, an object with an async ``__call__``, or
        a wrapper returning a coroutine), the awaitable is awaited.

        Args:
            name: Tool name to execute
            **kwargs: Tool parameters. Because ``name`` is this method's own
                parameter, a tool argument literally called ``name`` cannot
                be forwarded this way.

        Returns:
            Tool execution result

        Raises:
            KeyError: If tool is not found
            RuntimeError: If tool is unhealthy or has been shut down
            Exception: Anything raised by the handler is logged and re-raised
        """
        tool = self.get(name)

        if tool is None:
            raise KeyError(f"Tool '{name}' not found")

        if tool.status == ToolStatus.UNHEALTHY:
            raise RuntimeError(f"Tool '{name}' is unhealthy")

        if tool.status == ToolStatus.SHUTDOWN:
            raise RuntimeError(f"Tool '{name}' has been shut down")

        logger.debug(f"Executing tool: {name} with params: {kwargs}")

        try:
            result = tool.handler(**kwargs)
            # Checking the result (rather than iscoroutinefunction on the
            # handler) also catches async __call__ objects and wrappers that
            # return coroutines, which would otherwise leak unawaited.
            if inspect.isawaitable(result):
                result = await result
            return result
        except Exception as e:
            logger.error(f"Tool execution failed: {name} - {e}")
            raise

    async def health_check(self, name: str) -> ToolStatus:
        """
        Check the health of a specific tool.

        If the tool's handler exposes a ``health_check`` attribute, it is
        called (and awaited if it returns an awaitable). The tool becomes
        HEALTHY if that succeeds (or if there is no such attribute), and
        UNHEALTHY if it raises.

        Args:
            name: Tool name to check

        Returns:
            Updated tool status, or ``ToolStatus.UNKNOWN`` if not registered
        """
        tool = self.get(name)
        if tool is None:
            return ToolStatus.UNKNOWN

        try:
            health_fn = getattr(tool.handler, "health_check", None)
            if health_fn is not None:
                outcome = health_fn()
                if inspect.isawaitable(outcome):
                    await outcome

            tool.status = ToolStatus.HEALTHY
        except Exception as e:
            logger.warning(f"Health check failed for {name}: {e}")
            tool.status = ToolStatus.UNHEALTHY

        return tool.status

    async def health_check_all(self) -> dict[str, ToolStatus]:
        """
        Check the health of all registered tools, one after another.

        Returns:
            Dictionary mapping tool names to their status
        """
        results: dict[str, ToolStatus] = {}

        # Iterate over a snapshot: each check awaits, and a concurrent
        # register()/unregister() would otherwise raise "dictionary changed
        # size during iteration".
        for name in list(self._tools):
            results[name] = await self.health_check(name)

        return results

    async def _health_check_loop(self) -> None:
        """Run health_check_all() every ``_health_check_interval`` seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(self._health_check_interval)
                await self.health_check_all()
            except asyncio.CancelledError:
                # Propagate so the task is reported as cancelled; shutdown()
                # expects and absorbs this.
                raise
            except Exception as e:
                # Keep the loop alive on unexpected errors.
                logger.error(f"Health check loop error: {e}")

    def get_tool_schema(self, name: str) -> Optional[dict[str, Any]]:
        """
        Get the schema description of a tool.

        Args:
            name: Tool name

        Returns:
            Dict with ``name``, ``description`` and ``parameters`` keys, or
            None if the tool is not registered
        """
        tool = self.get(name)
        if tool is None:
            return None

        return {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.parameters,
        }

    def list_schemas(self) -> list[dict[str, Any]]:
        """
        Get schemas for all registered tools.

        Returns:
            List of tool schema dictionaries, in registration order
        """
        return [
            schema
            for name in self._tools
            if (schema := self.get_tool_schema(name))
        ]

    @property
    def is_initialized(self) -> bool:
        """Check if the registry has been initialized."""
        return self._initialized

    @property
    def tool_count(self) -> int:
        """Get the number of registered tools."""
        return len(self._tools)