File size: 16,445 Bytes
399b80c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
"""
Client for managing MCP servers and sessions.

This module provides a high-level client that manages MCP servers, connectors,
and sessions from configuration.
"""
import asyncio
import warnings
from typing import Any, Optional

from openspace.grounding.core.types import SandboxOptions
from openspace.config.utils import get_config_value, save_json_file, load_json_file
from .config import create_connector_from_config
from .session import MCPSession
from .installer import MCPInstallerManager, MCPDependencyError

from openspace.utils.logging import Logger

logger = Logger.get_logger(__name__)


class MCPClient:
    """Client for managing MCP servers and sessions.

    This class provides a unified interface for working with MCP servers,
    handling configuration, connector creation, and session management.

    Server definitions are read from the ``mcpServers`` key of the
    configuration, falling back to the ``servers`` key when ``mcpServers``
    is missing or ``None``. Created sessions are cached per server name in
    ``self.sessions``; ``self.active_sessions`` lists the server names whose
    sessions were created by this client and not yet closed/removed.
    """

    def __init__(
        self,
        config: str | dict[str, Any] | None = None,
        sandbox: bool = False,
        sandbox_options: SandboxOptions | None = None,
        timeout: float = 30.0,
        sse_read_timeout: float = 300.0,
        max_retries: int = 3,
        retry_interval: float = 2.0,
        installer: Optional[MCPInstallerManager] = None,
        check_dependencies: bool = True,
        tool_call_max_retries: int = 3,
        tool_call_retry_delay: float = 1.0,
    ) -> None:
        """Initialize a new MCP client.

        Args:
            config: Either a dict containing configuration or a path to a JSON config file.
                   If None, an empty configuration is used. A dict is stored by
                   reference (not copied), so add_server()/remove_server() mutate it.
            sandbox: Whether to use sandboxed execution mode for running MCP servers.
            sandbox_options: Optional sandbox configuration options.
            timeout: Timeout for operations in seconds (default: 30.0)
            sse_read_timeout: SSE read timeout in seconds (default: 300.0)
            max_retries: Maximum number of session-creation attempts per server
                (default: 3). Values below 1 still result in a single attempt.
            retry_interval: Wait time between retries in seconds (default: 2.0)
            installer: Optional installer manager for dependency installation
            check_dependencies: Whether to check and install dependencies (default: True)
            tool_call_max_retries: Maximum number of retries for tool calls (default: 3)
            tool_call_retry_delay: Initial delay between tool call retries in seconds (default: 1.0)
        """
        self.config: dict[str, Any] = {}
        self.sandbox = sandbox
        self.sandbox_options = sandbox_options
        self.timeout = timeout
        self.sse_read_timeout = sse_read_timeout
        self.max_retries = max_retries
        self.retry_interval = retry_interval
        self.installer = installer
        self.check_dependencies = check_dependencies
        self.tool_call_max_retries = tool_call_max_retries
        self.tool_call_retry_delay = tool_call_retry_delay
        self.sessions: dict[str, MCPSession] = {}
        self.active_sessions: list[str] = []

        # Load configuration if provided
        if config is not None:
            if isinstance(config, str):
                self.config = load_json_file(config)
            else:
                self.config = config

    def _get_mcp_servers(self) -> dict[str, Any]:
        """Internal helper to get mcpServers configuration.

        Tries both 'mcpServers' and 'servers' keys for compatibility;
        'servers' is only consulted when 'mcpServers' resolves to None.

        Returns:
            Dictionary of MCP server configurations, empty dict if none found.
        """
        servers = get_config_value(self.config, "mcpServers", None)
        if servers is None:
            servers = get_config_value(self.config, "servers", {})
        return servers or {}

    def _server_table(self, create: bool = False) -> dict[str, Any] | None:
        """Return the mutable server mapping stored inside ``self.config``.

        Uses the same precedence as ``_get_mcp_servers`` ('mcpServers' first,
        then 'servers', skipping None values) so that writes land in the very
        mapping that reads come from.

        Args:
            create: When no mapping exists, create an empty 'mcpServers' dict
                in the config and return it instead of None.

        Returns:
            The server mapping, or None when absent and ``create`` is False.
        """
        for key in ("mcpServers", "servers"):
            table = self.config.get(key)
            if table is not None:
                return table
        if not create:
            return None
        table = self.config["mcpServers"] = {}
        return table

    @classmethod
    def from_dict(
        cls,
        config: dict[str, Any],
        sandbox: bool = False,
        sandbox_options: SandboxOptions | None = None,
        timeout: float = 30.0,
        sse_read_timeout: float = 300.0,
        max_retries: int = 3,
        retry_interval: float = 2.0,
        installer: Optional[MCPInstallerManager] = None,
        check_dependencies: bool = True,
        tool_call_max_retries: int = 3,
        tool_call_retry_delay: float = 1.0,
    ) -> "MCPClient":
        """Create a MCPClient from a dictionary.

        Args:
            config: The configuration dictionary (stored by reference).
            sandbox: Whether to use sandboxed execution mode for running MCP servers.
            sandbox_options: Optional sandbox configuration options.
            timeout: Timeout for operations in seconds (default: 30.0)
            sse_read_timeout: SSE read timeout in seconds (default: 300.0)
            max_retries: Maximum number of session-creation attempts (default: 3)
            retry_interval: Wait time between retries in seconds (default: 2.0)
            installer: Optional installer manager for dependency installation
            check_dependencies: Whether to check and install dependencies (default: True)
            tool_call_max_retries: Maximum number of retries for tool calls (default: 3)
            tool_call_retry_delay: Initial delay between tool call retries in seconds (default: 1.0)
        """
        return cls(
            config=config,
            sandbox=sandbox,
            sandbox_options=sandbox_options,
            timeout=timeout,
            sse_read_timeout=sse_read_timeout,
            max_retries=max_retries,
            retry_interval=retry_interval,
            installer=installer,
            check_dependencies=check_dependencies,
            tool_call_max_retries=tool_call_max_retries,
            tool_call_retry_delay=tool_call_retry_delay,
        )

    @classmethod
    def from_config_file(
        cls,
        filepath: str,
        sandbox: bool = False,
        sandbox_options: SandboxOptions | None = None,
        timeout: float = 30.0,
        sse_read_timeout: float = 300.0,
        max_retries: int = 3,
        retry_interval: float = 2.0,
        installer: Optional[MCPInstallerManager] = None,
        check_dependencies: bool = True,
        tool_call_max_retries: int = 3,
        tool_call_retry_delay: float = 1.0,
    ) -> "MCPClient":
        """Create a MCPClient from a configuration file.

        Args:
            filepath: The path to the JSON configuration file.
            sandbox: Whether to use sandboxed execution mode for running MCP servers.
            sandbox_options: Optional sandbox configuration options.
            timeout: Timeout for operations in seconds (default: 30.0)
            sse_read_timeout: SSE read timeout in seconds (default: 300.0)
            max_retries: Maximum number of session-creation attempts (default: 3)
            retry_interval: Wait time between retries in seconds (default: 2.0)
            installer: Optional installer manager for dependency installation
            check_dependencies: Whether to check and install dependencies (default: True)
            tool_call_max_retries: Maximum number of retries for tool calls (default: 3)
            tool_call_retry_delay: Initial delay between tool call retries in seconds (default: 1.0)
        """
        return cls(
            config=load_json_file(filepath),
            sandbox=sandbox,
            sandbox_options=sandbox_options,
            timeout=timeout,
            sse_read_timeout=sse_read_timeout,
            max_retries=max_retries,
            retry_interval=retry_interval,
            installer=installer,
            check_dependencies=check_dependencies,
            tool_call_max_retries=tool_call_max_retries,
            tool_call_retry_delay=tool_call_retry_delay,
        )

    def add_server(
        self,
        name: str,
        server_config: dict[str, Any],
    ) -> None:
        """Add (or replace) a server configuration.

        The entry is written into the same mapping ``_get_mcp_servers`` reads
        ('mcpServers', or the legacy 'servers' key when that is the one in
        use). An empty 'mcpServers' mapping is created when neither exists.

        Args:
            name: The name to identify this server.
            server_config: The server configuration.
        """
        # Writing to 'mcpServers' unconditionally would shadow an existing
        # 'servers' mapping and make all of its entries invisible.
        table = self._server_table(create=True)
        table[name] = server_config
        logger.debug(f"Added MCP server configuration: {name}")

    def remove_server(self, name: str) -> None:
        """Remove a server configuration.

        The server is also dropped from ``active_sessions``. Any session
        already created for it stays in ``self.sessions`` and is not closed
        here; use ``close_session`` to release it.

        Args:
            name: The name of the server to remove.
        """
        if name not in self._get_mcp_servers():
            logger.warning(f"Server '{name}' not found in configuration")
            return

        # Remove from whichever mapping actually holds the server entry.
        table = self._server_table()
        if table is not None:
            table.pop(name, None)

        # If we removed an active session, remove it from active_sessions
        if name in self.active_sessions:
            self.active_sessions.remove(name)

        logger.debug(f"Removed MCP server configuration: {name}")

    def get_server_names(self) -> list[str]:
        """Get the list of configured server names.

        Returns:
            List of server names.
        """
        return list(self._get_mcp_servers().keys())

    def save_config(self, filepath: str) -> None:
        """Save the current configuration to a file.

        Args:
            filepath: The path to save the configuration to.
        """
        save_json_file(self.config, filepath)

    async def _discard_session(self, session: MCPSession | None, server_name: str) -> None:
        """Best-effort disconnect of a session whose creation attempt failed.

        The failed attempt may already have connected (MCPSession is built
        with auto_connect=True and may have been initialized), so it is torn
        down before retrying rather than left dangling. Errors are logged
        and ignored because the original failure is what matters.
        """
        if session is None:
            return
        try:
            await session.disconnect()
        except Exception as cleanup_exc:
            logger.debug(
                f"Ignoring error while cleaning up failed session for server '{server_name}': {cleanup_exc}"
            )

    async def create_session(self, server_name: str, auto_initialize: bool = True) -> MCPSession | None:
        """Create a session for the specified server with retry logic.

        Args:
            server_name: The name of the server to create a session for.
            auto_initialize: Whether to automatically initialize the session.

        Returns:
            The created MCPSession (or the cached one if it already exists).
            Returns None, after emitting a UserWarning, when the configuration
            defines no servers at all.

        Raises:
            ValueError: If the specified server doesn't exist.
            MCPDependencyError: Re-raised immediately, without retrying.
            Exception: The last error if session creation fails on every attempt.
        """
        # Check if session already exists
        if server_name in self.sessions:
            logger.debug(f"Session for server '{server_name}' already exists, returning existing session")
            return self.sessions[server_name]

        # Get server config
        servers = self._get_mcp_servers()

        if not servers:
            warnings.warn("No MCP servers defined in config", UserWarning, stacklevel=2)
            return None

        if server_name not in servers:
            raise ValueError(f"Server '{server_name}' not found in config. Available: {list(servers.keys())}")

        server_config = servers[server_name]

        # Always make at least one attempt, even if max_retries is 0 or negative.
        attempts = max(1, self.max_retries)
        last_exc: Exception | None = None

        for attempt in range(1, attempts + 1):
            session: MCPSession | None = None
            try:
                # Create connector with options (async)
                connector = await create_connector_from_config(
                    server_config,
                    server_name=server_name,
                    sandbox=self.sandbox,
                    sandbox_options=self.sandbox_options,
                    timeout=self.timeout,
                    sse_read_timeout=self.sse_read_timeout,
                    installer=self.installer,
                    check_dependencies=self.check_dependencies,
                    tool_call_max_retries=self.tool_call_max_retries,
                    tool_call_retry_delay=self.tool_call_retry_delay,
                )

                session = MCPSession(
                    connector=connector,
                    session_id=f"mcp-{server_name}",
                    auto_connect=True,
                    auto_initialize=False,  # initialization is handled explicitly below
                )

                if auto_initialize:
                    await session.initialize()
                    logger.debug(f"Initialized session for server '{server_name}'")

            except MCPDependencyError as e:
                # Don't retry dependency errors - they won't succeed on retry.
                # The installer has already reported the error; just re-raise.
                logger.debug(f"Dependency error for server '{server_name}': {type(e).__name__}")
                await self._discard_session(session, server_name)
                raise
            except Exception as e:
                last_exc = e
                await self._discard_session(session, server_name)
                if attempt == attempts:
                    break

                # Info level for the first failure (common right after a fresh
                # install), warning level for subsequent ones.
                log = logger.info if attempt == 1 else logger.warning
                log(
                    f"Failed to create session for server '{server_name}' (attempt {attempt}/{attempts}): {e}, "
                    f"retrying in {self.retry_interval} seconds..."
                )
                await asyncio.sleep(self.retry_interval)
            else:
                # Success: cache the session and mark the server active.
                self.sessions[server_name] = session
                if server_name not in self.active_sessions:
                    self.active_sessions.append(server_name)

                logger.info(f"Created session for MCP server '{server_name}' (attempt {attempt}/{attempts})")
                return session

        # All attempts failed
        error_msg = f"Failed to create session for server '{server_name}' after {attempts} attempts"
        logger.error(error_msg)
        raise last_exc or RuntimeError(error_msg)

    async def create_all_sessions(
        self,
        auto_initialize: bool = True,
    ) -> dict[str, MCPSession]:
        """Create sessions for all configured servers.

        Failures for individual servers are logged and do not stop the
        remaining servers from being started.

        Args:
            auto_initialize: Whether to automatically initialize the sessions.

        Returns:
            The client's session mapping (``self.sessions``, not a copy),
            mapping server names to their MCPSession instances.

        Warns:
            UserWarning: If no servers are configured.
        """
        servers = self._get_mcp_servers()

        if not servers:
            warnings.warn("No MCP servers defined in config", UserWarning, stacklevel=2)
            return {}

        # create_session awaits, so iterate over a snapshot of the names in
        # case the configuration is modified meanwhile.
        logger.debug(f"Creating sessions for {len(servers)} servers")
        for name in list(servers):
            try:
                await self.create_session(name, auto_initialize)
            except Exception as e:
                logger.error(f"Failed to create session for server '{name}': {e}")

        logger.info(f"Created {len(self.sessions)} MCP sessions")
        return self.sessions

    def get_session(self, server_name: str | None = None) -> MCPSession:
        """Get an existing session.

        Args:
            server_name: The name of the server to get the session for.
                        If None, uses the first active session.

        Returns:
            The MCPSession for the specified server.

        Raises:
            ValueError: If no active sessions exist or the specified session doesn't exist.
        """
        if server_name is None:
            server_name = next((name for name in self.active_sessions if name in self.sessions), None)
            if server_name is None:
                raise ValueError("No active sessions exist")

        if server_name not in self.sessions:
            raise ValueError(f"No session exists for server '{server_name}'")

        return self.sessions[server_name]

    def get_all_active_sessions(self) -> dict[str, MCPSession]:
        """Get all active sessions.

        Returns:
            Dictionary mapping server names to their MCPSession instances,
            in ``active_sessions`` order.
        """
        return {name: self.sessions[name] for name in self.active_sessions if name in self.sessions}

    async def close_session(self, server_name: str) -> None:
        """Close a session and stop tracking it.

        Disconnect errors are logged, not raised; the session is removed from
        ``sessions`` and ``active_sessions`` either way. Closing a server with
        no session only logs a warning.

        Args:
            server_name: The name of the server to close the session for.
        """
        # Check if the session exists
        if server_name not in self.sessions:
            logger.warning(f"No session exists for server '{server_name}', nothing to close")
            return

        session = self.sessions[server_name]
        error_occurred = False

        try:
            logger.debug(f"Closing session for server '{server_name}'")
            await session.disconnect()
            logger.info(f"Successfully closed session for server '{server_name}'")
        except Exception as e:
            error_occurred = True
            logger.error(f"Error closing session for server '{server_name}': {e}")
        finally:
            # Remove the session regardless of whether disconnect succeeded
            self.sessions.pop(server_name, None)

            if server_name in self.active_sessions:
                self.active_sessions.remove(server_name)

            if error_occurred:
                logger.warning(f"Session for '{server_name}' removed from tracking despite disconnect error")

    async def close_all_sessions(self) -> None:
        """Close all sessions tracked in ``self.sessions``.

        This method ensures all sessions are closed even if some fail.
        """
        # Snapshot the names first: close_session mutates self.sessions.
        server_names = list(self.sessions.keys())
        errors = []

        for server_name in server_names:
            try:
                logger.debug(f"Closing session for server '{server_name}'")
                await self.close_session(server_name)
            except Exception as e:
                error_msg = f"Failed to close session for server '{server_name}': {e}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Log summary if there were errors
        if errors:
            logger.error(f"Encountered {len(errors)} errors while closing sessions")
        else:
            logger.debug("All sessions closed successfully")