async def _load_servers_from_config(self): """Load server configurations from config.""" servers_config = self.config.get('servers', {}) for server_id, config_dict in servers_config.items(): try: # Create ServerConfig with defaults server_config = ServerConfig( id=server_id, name=config_dict.get('name', server_id), url=config_dict.get('url'), auth_token=config_dict.get('auth_token'), protocol_version=config_dict.get('protocol_version', '2024-11-05'), timeout=int(config_dict.get('timeout', 30)), max_retries=int(config_dict.get('max_retries', 3)), retry_backoff=float(config_dict.get('retry_backoff', 1.5)), max_connections=int(config_dict.get('max_connections', 10)), health_check_interval=int(config_dict.get('health_check_interval', 60)), circuit_breaker_threshold=int(config_dict.get('circuit_breaker_threshold', 5)), circuit_breaker_timeout=int(config_dict.get('circuit_breaker_timeout', 300)), connection_pool_ttl=int(config_dict.get('connection_pool_ttl', 3600)), tls_verify=config_dict.get('tls_verify', True), headers=config_dict.get('headers', {}), metadata=config_dict.get('metadata', {}) ) # Register server await self.register_server(server_config) except Exception as e: structlog.get_logger().error( "Failed to load server config", server_id=server_id, error=str(e) ) async def register_server(self, server_config: ServerConfig): """Register an MCP server.""" server_id = server_config.id self.servers[server_id] = server_config # Create connection pool connection_pool = ConnectionPool(server_config, self.metrics) await connection_pool.initialize() self.connection_pools[server_id] = connection_pool # Register tools with tool manager await self.tool_manager.register_server(server_config, connection_pool) structlog.get_logger().info( "Server registered", server_id=server_id, name=server_config.name, url=server_config.url ) async def unregister_server(self, server_id: str): """Unregister an MCP server.""" if server_id in self.servers: # Shutdown connection pool connection_pool = self.connection_pools.pop(server_id, None) if connection_pool: await connection_pool.shutdown() # Remove server config self.servers.pop(server_id, None) # Clear cached tools await self.cache.invalidate_pattern(f"tools:{server_id}") structlog.get_logger().info("Server unregistered", server_id=server_id) async def invoke_tool(self, server_id: str, tool_name: str, parameters: Dict[str, Any], session_id: str) -> Dict[str, Any]: """Invoke a tool with session management and rate limiting.""" # Create tool invocation invocation = ToolInvocation( session_id=session_id, server_id=server_id, tool_name=tool_name, parameters=parameters ) try: # Check session exists session = await self.session_manager.get_session(session_id) if not session: raise ValueError(f"Session {session_id} not found") # Check rate limiting allowed, wait_time = await self.session_manager.check_rate_limit(session_id) if not allowed: raise ValueError(f"Rate limit exceeded. Wait {wait_time:.1f} seconds") # Update session activity await self.session_manager.update_session_activity(session_id) session.total_requests += 1 # Get connection pool connection_pool = self.connection_pools.get(server_id) if not connection_pool: raise ValueError(f"Server {server_id} not found") # Invoke tool via tool manager result = await self.tool_manager.invoke_tool( server_id, tool_name, parameters, session_id ) # Update invocation invocation.status = "success" invocation.completed_at = datetime.utcnow() invocation.result = result # Update session success rate if invocation.error is None: session.success_rate = (session.success_rate * (session.total_requests - 1) + 1) / session.total_requests else: session.success_rate = (session.success_rate * (session.total_requests - 1)) / session.total_requests return result except Exception as e: # Update invocation invocation.status = "error" invocation.completed_at = datetime.utcnow() invocation.error = str(e) structlog.get_logger().error( "Tool invocation failed", server_id=server_id, tool_name=tool_name, session_id=session_id, error=str(e) ) raise async def get_server_status(self, server_id: str) -> Dict[str, Any]: """Get server status and health information.""" server_config = self.servers.get(server_id) if not server_config: return {"error": f"Server {server_id} not found"} connection_pool = self.connection_pools.get(server_id) if not connection_pool: return {"error": f"No connection pool for server {server_id}"} # Get connection pool stats stats = { "server_id": server_id, "name": server_config.name, "url": server_config.url, "state": connection_pool.circuit_breaker_state.value, "failures": connection_pool.circuit_breaker_failures, "pool_size": len(connection_pool.pool), "available_connections": connection_pool.available.qsize(), "in_use_connections": len(connection_pool.in_use), "last_failure": connection_pool.circuit_breaker_last_failure.isoformat() if connection_pool.circuit_breaker_last_failure else None, "health_status": "unknown" # Would be populated by health check } # Get tool count tools = await self.tool_manager.get_server_tools(server_id) stats["tool_count"] = len(tools) return stats async def get_all_servers_status(self) -> Dict[str, Dict[str, Any]]: """Get status for all servers.""" return { server_id: await self.get_server_status(server_id) for server_id in self.servers.keys() } async def _health_check_loop(self): """Background health check loop.""" while True: try: await asyncio.sleep(30) # Check every 30 seconds # Check all servers for server_id in self.servers.keys(): try: status = await self.get_server_status(server_id) # Process health check result except Exception as e: structlog.get_logger().error( "Health check failed", server_id=server_id, error=str(e) ) except asyncio.CancelledError: break except Exception as e: structlog.get_logger().error("Health check loop error", error=str(e)) async def _cache_cleanup_loop(self): """Background cache cleanup loop.""" while True: try: await asyncio.sleep(300) # Clean every 5 minutes await self.cache.cleanup_expired() except asyncio.CancelledError: break except Exception as e: structlog.get_logger().error("Cache cleanup loop error", error=str(e)) async def get_metrics(self) -> Dict[str, Any]: """Get comprehensive metrics.""" # Prometheus metrics are scraped externally # Return application-level metrics return { "servers": len(self.servers), "active_connections": sum(len(pool.pool) for pool in self.connection_pools.values()), "cache_stats": self.cache.get_stats(), "session_stats": await self.session_manager.get_session_stats(), "tool_stats": self.tool_manager.get_stats(), "uptime": time.time() } async def shutdown(self): """Shutdown the orchestrator.""" structlog.get_logger().info("Shutting down MCP Orchestrator") # Set shutdown event self.shutdown_event.set() # Cancel background tasks for task in self.background_tasks: task.cancel() # Wait for tasks to complete await asyncio.gather(*self.background_tasks, return_exceptions=True) # Shutdown components await self.cache.shutdown() await self.session_manager.shutdown() # Shutdown connection pools for pool in self.connection_pools.values(): await pool.shutdown() # Shutdown executor self.executor.shutdown(wait=True) structlog.get_logger().info("MCP Orchestrator shutdown complete") @property def is_healthy(self) -> bool: """Check if orchestrator is healthy.""" # Check if shutdown is in progress if self.shutdown_event.is_set(): return False # Check if servers are responding for pool in self.connection_pools.values(): if pool.circuit_breaker_state == CircuitBreakerState.OPEN: return False return True # ============================================================================= # Logging and Monitoring Setup # ============================================================================= def setup_logging(log_level: str = "INFO"): """Setup structured logging.""" structlog.configure( processors=[ structlog.stdlib.filter_by_level, structlog.stdlib.add_logger_name, structlog.stdlib.add_log_level, structlog.stdlib.PositionalArgumentsFormatter(), structlog.processors.TimeStamper(fmt="iso"), structlog.processors.StackInfoRenderer(), structlog.processors.format_exc_info, structlog.processors.UnicodeDecoder(), structlog.processors.JSONRenderer() ], context_class=dict, logger_factory=structlog.stdlib.LoggerFactory(), wrapper_class=structlog.stdlib.BoundLogger, cache_logger_on_first_use=True, ) import logging logging.basicConfig( level=getattr(logging, log_level.upper()), format="%(message)s", stream=sys.stdout, ) def setup_metrics(port: int = 9090): """Setup Prometheus metrics.""" try: # Start metrics server from prometheus_client import start_http_server start_http_server(port) structlog.get_logger().info("Metrics server started", port=port) except ImportError: structlog.get_logger().warning("Prometheus client not available, skipping metrics server") # ============================================================================= # Main Entry Point # ============================================================================= async def main(): """Main entry point for the MCP Orchestrator.""" # Setup logging and metrics setup_logging() setup_metrics() # Create and initialize orchestrator orchestrator = MCPOrchestrator() await orchestrator.initialize() try: # Keep running until shutdown await orchestrator.shutdown_event.wait() except KeyboardInterrupt: structlog.get_logger().info("Received shutdown signal") finally: await orchestrator.shutdown() if __name__ == "__main__": asyncio.run(main())