| """ |
| DungeonMaster AI - MCP Connection Manager |
| |
| Manages MCP connection lifecycle with health checks, automatic reconnection, |
| circuit breaker pattern, and graceful degradation. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import contextlib |
| import logging |
| import random |
| import time |
| from collections.abc import Sequence |
| from datetime import datetime |
| from typing import TYPE_CHECKING |
|
|
| from llama_index.core.tools import FunctionTool |
|
|
| from src.config.settings import get_settings |
|
|
| from .exceptions import ( |
| MCPCircuitBreakerOpenError, |
| MCPUnavailableError, |
| ) |
| from .fallbacks import FallbackHandler |
| from .models import ( |
| CircuitBreakerState, |
| ConnectionState, |
| MCPConnectionStatus, |
| ) |
| from .toolkit_client import TTRPGToolkitClient |
|
|
| if TYPE_CHECKING: |
| from typing import Any |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| class CircuitBreaker: |
| """ |
| Circuit breaker pattern implementation. |
| |
| Prevents repeated calls to a failing service by tracking failures |
| and temporarily rejecting requests when failure threshold is reached. |
| |
| States: |
| - CLOSED: Normal operation, requests allowed |
| - OPEN: Too many failures, requests rejected for reset_timeout |
| - HALF_OPEN: Testing if service recovered with a single request |
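
    Example:
        A minimal sketch of the failure/recovery cycle (thresholds shortened
        for illustration):

        ```python
        breaker = CircuitBreaker(failure_threshold=2, reset_timeout=5.0)

        breaker.record_failure()
        breaker.record_failure()
        assert breaker.is_open  # opens once the failure threshold is hit

        # Once reset_timeout has elapsed, a single probe is allowed
        # (HALF_OPEN); a success closes the circuit again.
        if breaker.allow_request():
            breaker.record_success()
        ```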
| """ |
|
|
| def __init__( |
| self, |
| failure_threshold: int = 5, |
| reset_timeout: float = 30.0, |
| half_open_max_calls: int = 1, |
| ) -> None: |
| """ |
| Initialize circuit breaker. |
| |
| Args: |
| failure_threshold: Failures before opening circuit |
| reset_timeout: Seconds to wait before trying again |
| half_open_max_calls: Max calls allowed in half-open state |
| """ |
| self.failure_threshold = failure_threshold |
| self.reset_timeout = reset_timeout |
| self.half_open_max_calls = half_open_max_calls |
|
|
| self._state = CircuitBreakerState.CLOSED |
| self._failure_count = 0 |
| self._last_failure_time: float | None = None |
| self._half_open_calls = 0 |
|
|
| @property |
| def state(self) -> CircuitBreakerState: |
| """Get current circuit breaker state.""" |
| |
| if ( |
| self._state == CircuitBreakerState.OPEN |
| and self._last_failure_time is not None |
| ): |
| elapsed = time.time() - self._last_failure_time |
| if elapsed >= self.reset_timeout: |
| self._state = CircuitBreakerState.HALF_OPEN |
| self._half_open_calls = 0 |
| logger.info("Circuit breaker transitioned to HALF_OPEN") |
|
|
| return self._state |
|
|
| @property |
| def is_open(self) -> bool: |
| """Check if circuit is open (rejecting requests).""" |
| return self.state == CircuitBreakerState.OPEN |
|
|
| @property |
| def time_until_retry(self) -> float | None: |
| """Seconds until retry is allowed, or None if not in OPEN state.""" |
| if self._state != CircuitBreakerState.OPEN: |
| return None |
| if self._last_failure_time is None: |
| return None |
| elapsed = time.time() - self._last_failure_time |
| remaining = self.reset_timeout - elapsed |
| return max(0.0, remaining) |
|
|
| def record_success(self) -> None: |
| """Record a successful call.""" |
| if self._state == CircuitBreakerState.HALF_OPEN: |
| |
| self._state = CircuitBreakerState.CLOSED |
| logger.info("Circuit breaker closed after successful recovery") |
|
|
| self._failure_count = 0 |
| self._last_failure_time = None |
|
|
| def record_failure(self) -> None: |
| """Record a failed call.""" |
| self._failure_count += 1 |
| self._last_failure_time = time.time() |
|
|
| if self._state == CircuitBreakerState.HALF_OPEN: |
| |
| self._state = CircuitBreakerState.OPEN |
| logger.warning("Circuit breaker reopened after half-open failure") |
| elif self._failure_count >= self.failure_threshold: |
| self._state = CircuitBreakerState.OPEN |
| logger.warning( |
| f"Circuit breaker opened after {self._failure_count} failures" |
| ) |
|
|
| def allow_request(self) -> bool: |
| """ |
| Check if a request should be allowed. |
| |
| Returns: |
| True if request is allowed, False if should be rejected. |
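
        Example:
            A guarded call, sketched (``call_service`` is a hypothetical
            stand-in for the protected operation):

            ```python
            if breaker.allow_request():
                try:
                    call_service()
                    breaker.record_success()
                except Exception:
                    breaker.record_failure()
                    raise
            ```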
| """ |
| state = self.state |
|
|
| if state == CircuitBreakerState.CLOSED: |
| return True |
|
|
| if state == CircuitBreakerState.HALF_OPEN: |
| if self._half_open_calls < self.half_open_max_calls: |
| self._half_open_calls += 1 |
| return True |
| return False |
|
|
| |
| return False |
|
|
| def reset(self) -> None: |
| """Reset circuit breaker to initial state.""" |
| self._state = CircuitBreakerState.CLOSED |
| self._failure_count = 0 |
| self._last_failure_time = None |
| self._half_open_calls = 0 |
|
|
|
|
| class ConnectionManager: |
| """ |
| Manages MCP connection lifecycle with health checks and reconnection. |
| |
| Features: |
| - Automatic reconnection with exponential backoff |
| - Circuit breaker to prevent hammering failed server |
| - Health check monitoring |
| - Graceful degradation via FallbackHandler |
| - Connection status tracking |
| |
| Example: |
| ```python |
| manager = ConnectionManager() |
| connected = await manager.connect() |
| |
| if manager.is_available: |
| tools = await manager.get_tools() |
| result = await manager.execute_tool("roll", {"notation": "1d20"}) |
| else: |
| # Fallback handling |
| result = await manager.execute_with_fallback("roll", {"notation": "1d20"}) |
| ``` |
| """ |
|
|
| def __init__( |
| self, |
| toolkit_client: TTRPGToolkitClient | None = None, |
| max_retries: int | None = None, |
| retry_delay: float | None = None, |
| fallback_handler: FallbackHandler | None = None, |
| ) -> None: |
| """ |
| Initialize connection manager. |
| |
| Args: |
| toolkit_client: Pre-configured client, or None to create new one |
| max_retries: Max reconnection attempts (default from settings) |
| retry_delay: Base delay between retries (default from settings) |
| fallback_handler: Handler for graceful degradation |
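
        Example:
            Overriding the retry policy explicitly (values are illustrative):

            ```python
            manager = ConnectionManager(
                toolkit_client=TTRPGToolkitClient(),
                max_retries=5,
                retry_delay=1.0,
            )
            ```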
| """ |
| settings = get_settings() |
|
|
| self._client = toolkit_client or TTRPGToolkitClient() |
| self._max_retries: int = max_retries or settings.mcp.mcp_retry_attempts |
| self._retry_delay: float = retry_delay or settings.mcp.mcp_retry_delay |
| self._fallback_handler = fallback_handler or FallbackHandler() |
|
|
| |
| self._state = ConnectionState.DISCONNECTED |
| self._last_successful_call: datetime | None = None |
| self._consecutive_failures = 0 |
| self._last_error: str | None = None |
|
|
| |
| self._circuit_breaker = CircuitBreaker( |
| failure_threshold=5, |
| reset_timeout=30.0, |
| ) |
|
|
| |
| self._health_check_task: asyncio.Task[None] | None = None |
| self._health_check_interval = 60.0 |
|
|
| @property |
| def state(self) -> ConnectionState: |
| """Get current connection state.""" |
| return self._state |
|
|
| @property |
| def is_available(self) -> bool: |
| """Check if MCP is available for use.""" |
| return ( |
| self._state == ConnectionState.CONNECTED |
| and not self._circuit_breaker.is_open |
| ) |
|
|
| @property |
| def client(self) -> TTRPGToolkitClient: |
| """Get the underlying toolkit client.""" |
| return self._client |
|
|
| def get_status(self) -> MCPConnectionStatus: |
| """Get detailed connection status.""" |
| return MCPConnectionStatus( |
| state=self._state, |
| is_available=self.is_available, |
| url=self._client.url, |
| last_successful_call=self._last_successful_call, |
| consecutive_failures=self._consecutive_failures, |
| circuit_breaker_state=self._circuit_breaker.state, |
| tools_count=self._client.tools_count, |
| error_message=self._last_error, |
| ) |
|
|
| async def connect(self) -> bool: |
| """ |
| Connect to MCP server with retry logic. |
| |
| Returns: |
| True if connection successful, False otherwise. |
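
        Example:
            A typical startup sequence, sketched:

            ```python
            manager = ConnectionManager()
            if not await manager.connect():
                logger.warning(manager.get_unavailable_message())
            ```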
| """ |
| self._state = ConnectionState.CONNECTING |
| logger.info("Attempting to connect to MCP server...") |
|
|
| for attempt in range(self._max_retries): |
| try: |
| await self._client.connect() |
| self._state = ConnectionState.CONNECTED |
| self._consecutive_failures = 0 |
| self._last_successful_call = datetime.now() |
| self._last_error = None |
| self._circuit_breaker.reset() |
|
|
| logger.info("Successfully connected to MCP server") |
| return True |
|
|
| except Exception as e: |
| self._consecutive_failures += 1 |
| self._last_error = str(e) |
| logger.warning( |
| f"Connection attempt {attempt + 1}/{self._max_retries} failed: {e}" |
| ) |
|
|
| if attempt < self._max_retries - 1: |
| delay = self._calculate_backoff_delay(attempt) |
| logger.info(f"Retrying in {delay:.2f} seconds...") |
| await asyncio.sleep(delay) |
|
|
| self._state = ConnectionState.ERROR |
| logger.error(f"Failed to connect after {self._max_retries} attempts") |
| return False |
|
|
| async def disconnect(self) -> None: |
| """Disconnect from MCP server.""" |
| |
| if self._health_check_task and not self._health_check_task.done(): |
| self._health_check_task.cancel() |
| with contextlib.suppress(asyncio.CancelledError): |
| await self._health_check_task |
|
|
| await self._client.disconnect() |
| self._state = ConnectionState.DISCONNECTED |
| logger.info("Disconnected from MCP server") |
|
|
| async def health_check(self) -> bool: |
| """ |
| Perform health check by listing tools. |
| |
| Returns: |
| True if healthy, False otherwise. |
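
        Example:
            A manual probe outside the background loop, sketched:

            ```python
            if not await manager.health_check():
                logger.warning("MCP health check failed")
            ```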
| """ |
| try: |
| await self._client.list_tool_names() |
| self._last_successful_call = datetime.now() |
| self._consecutive_failures = 0 |
| self._circuit_breaker.record_success() |
| return True |
|
|
| except Exception as e: |
| self._consecutive_failures += 1 |
| self._circuit_breaker.record_failure() |
| logger.warning(f"Health check failed: {e}") |
| return False |
|
|
| async def get_tools( |
| self, |
| categories: Sequence[str] | None = None, |
| ) -> Sequence[FunctionTool]: |
| """ |
| Get tools with automatic reconnection on failure. |
| |
| Args: |
| categories: Optional list of categories to filter. |
| |
| Returns: |
| Sequence of FunctionTool objects. |
| |
| Raises: |
| MCPUnavailableError: If MCP is unavailable after reconnection attempt. |
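
        Example:
            Fetching a filtered subset (the ``"dice"`` category name is
            illustrative, not a guaranteed category):

            ```python
            tools = await manager.get_tools(categories=["dice"])
            ```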
| """ |
| if not self.is_available: |
| await self._attempt_reconnect() |
|
|
| if not self.is_available: |
| raise MCPUnavailableError( |
| "MCP server is unavailable", |
| reason=self._last_error, |
| ) |
|
|
| try: |
| tools: Sequence[FunctionTool] |
| if categories: |
| tools = await self._client.get_tools_by_category(categories) |
| else: |
| tools = await self._client.get_all_tools() |
|
|
| self._last_successful_call = datetime.now() |
| self._circuit_breaker.record_success() |
| return tools |
|
|
| except Exception as e: |
| self._circuit_breaker.record_failure() |
| self._last_error = str(e) |
| logger.error(f"Failed to get tools: {e}") |
|
|
| |
| if await self._attempt_reconnect(): |
| |
| if categories: |
| return await self._client.get_tools_by_category(categories) |
| return await self._client.get_all_tools() |
|
|
| raise MCPUnavailableError( |
| "Unable to get tools after reconnection", |
| reason=str(e), |
| ) from e |
|
|
| async def execute_tool( |
| self, |
| tool_name: str, |
| arguments: dict[str, Any], |
| ) -> Any: |
| """ |
| Execute a tool with connection management. |
| |
| Args: |
| tool_name: Name of the tool to call. |
| arguments: Tool arguments. |
| |
| Returns: |
| Tool result. |
| |
| Raises: |
| MCPCircuitBreakerOpenError: If circuit breaker is open. |
| MCPUnavailableError: If MCP is unavailable. |
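
        Example:
            Handling an open circuit explicitly, sketched:

            ```python
            try:
                result = await manager.execute_tool("roll", {"notation": "1d20"})
            except MCPCircuitBreakerOpenError:
                logger.warning("Circuit breaker open; backing off")
            ```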
| """ |
| |
| if not self._circuit_breaker.allow_request(): |
| retry_after = self._circuit_breaker.time_until_retry |
| raise MCPCircuitBreakerOpenError(retry_after_seconds=retry_after) |
|
|
| if not self.is_available: |
| await self._attempt_reconnect() |
|
|
| if not self.is_available: |
| raise MCPUnavailableError( |
| "MCP server is unavailable", |
| reason=self._last_error, |
| ) |
|
|
| try: |
| result = await self._client.call_tool(tool_name, arguments) |
| self._last_successful_call = datetime.now() |
| self._consecutive_failures = 0 |
| self._circuit_breaker.record_success() |
| return result |
|
|
| except Exception as e: |
| self._consecutive_failures += 1 |
| self._circuit_breaker.record_failure() |
| self._last_error = str(e) |
| raise |
|
|
| async def execute_with_fallback( |
| self, |
| tool_name: str, |
| arguments: dict[str, Any], |
| ) -> Any: |
| """ |
| Execute tool with automatic fallback on failure. |
| |
| If MCP fails and a fallback handler can handle the tool, |
| uses the fallback. Otherwise, raises the original error. |
| |
| Args: |
| tool_name: Name of the tool to call. |
| arguments: Tool arguments. |
| |
| Returns: |
| Tool result (from MCP or fallback). |
| |
| Raises: |
| MCPUnavailableError: If MCP fails and no fallback available. |
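
        Example:
            The resilient call path for when MCP may be down (mirrors the
            class-level example):

            ```python
            result = await manager.execute_with_fallback(
                "roll", {"notation": "1d20"}
            )
            ```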
| """ |
| try: |
| return await self.execute_tool(tool_name, arguments) |
|
|
| except (MCPUnavailableError, MCPCircuitBreakerOpenError) as e: |
| |
| if self._fallback_handler.can_handle(tool_name): |
| logger.info(f"Using fallback for tool '{tool_name}'") |
| return await self._fallback_handler.handle(tool_name, arguments) |
|
|
| |
| raise MCPUnavailableError( |
| f"MCP unavailable and no fallback for '{tool_name}'", |
| reason=str(e), |
| ) from e |
|
|
| async def _attempt_reconnect(self) -> bool: |
| """ |
| Attempt reconnection with exponential backoff. |
| |
| Returns: |
| True if reconnection successful, False otherwise. |
| """ |
| if self._state == ConnectionState.RECONNECTING: |
| |
| return False |
|
|
| self._state = ConnectionState.RECONNECTING |
| logger.info("Attempting to reconnect to MCP server...") |
|
|
| for attempt in range(self._max_retries): |
| try: |
| await self._client.disconnect() |
| await self._client.connect() |
|
|
| self._state = ConnectionState.CONNECTED |
| self._consecutive_failures = 0 |
| self._last_successful_call = datetime.now() |
| self._circuit_breaker.reset() |
|
|
| logger.info("Reconnection successful") |
| return True |
|
|
| except Exception as e: |
| logger.warning(f"Reconnection attempt {attempt + 1} failed: {e}") |
| delay = self._calculate_backoff_delay(attempt) |
| await asyncio.sleep(delay) |
|
|
| self._state = ConnectionState.ERROR |
| logger.error("All reconnection attempts failed") |
| return False |
|
|
| def _calculate_backoff_delay(self, attempt: int) -> float: |
| """ |
| Calculate delay with exponential backoff and jitter. |
| |
| Args: |
| attempt: Current attempt number (0-indexed). |
| |
| Returns: |
| Delay in seconds. |
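
        Example:
            With retry_delay=2.0 the nominal delays are 2s, 4s, 8s, 16s,
            then capped at 30s; each is then increased by up to 10% random
            jitter, so attempt 0 yields between 2.0s and 2.2s.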
| """ |
| |
| delay: float = self._retry_delay * (2**attempt) |
|
|
| |
| delay = min(delay, 30.0) |
|
|
| |
| jitter: float = delay * 0.1 * random.random() |
| delay += jitter |
|
|
| return delay |
|
|
| async def start_health_monitoring(self) -> None: |
| """Start background health check monitoring.""" |
        if self._health_check_task and not self._health_check_task.done():
            logger.warning("Health monitoring already running")
            return

        self._health_check_task = asyncio.create_task(self._health_check_loop())
        logger.info("Started health check monitoring")

    async def stop_health_monitoring(self) -> None:
        """Stop background health check monitoring."""
        if self._health_check_task and not self._health_check_task.done():
            self._health_check_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._health_check_task
            logger.info("Stopped health check monitoring")

    async def _health_check_loop(self) -> None:
        """Background task for periodic health checks."""
        while True:
            try:
                await asyncio.sleep(self._health_check_interval)

                if self._state == ConnectionState.CONNECTED:
                    healthy = await self.health_check()
                    if not healthy:
                        logger.warning("Health check failed, attempting reconnection")
                        await self._attempt_reconnect()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Health check loop error: {e}")

    def get_unavailable_message(self) -> str:
        """Get user-friendly message when MCP is unavailable."""
        return self._fallback_handler.get_unavailable_message()

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"ConnectionManager(state={self._state.value}, "
            f"available={self.is_available}, "
            f"failures={self._consecutive_failures})"
        )