|
|
|
|
|
async def _load_servers_from_config(self):
    """Load server configurations from config.

    Reads the ``servers`` mapping from ``self.config`` and registers one
    server per entry. A malformed or failing entry is logged and skipped so
    the remaining servers still load (best-effort policy).
    """
    for server_id, raw in self.config.get('servers', {}).items():
        try:
            # Build the typed config; missing keys fall back to defaults.
            parsed = ServerConfig(
                id=server_id,
                name=raw.get('name', server_id),
                url=raw.get('url'),
                auth_token=raw.get('auth_token'),
                protocol_version=raw.get('protocol_version', '2024-11-05'),
                timeout=int(raw.get('timeout', 30)),
                max_retries=int(raw.get('max_retries', 3)),
                retry_backoff=float(raw.get('retry_backoff', 1.5)),
                max_connections=int(raw.get('max_connections', 10)),
                health_check_interval=int(raw.get('health_check_interval', 60)),
                circuit_breaker_threshold=int(raw.get('circuit_breaker_threshold', 5)),
                circuit_breaker_timeout=int(raw.get('circuit_breaker_timeout', 300)),
                connection_pool_ttl=int(raw.get('connection_pool_ttl', 3600)),
                tls_verify=raw.get('tls_verify', True),
                headers=raw.get('headers', {}),
                metadata=raw.get('metadata', {}),
            )
            await self.register_server(parsed)
        except Exception as exc:
            # One bad entry must not abort loading of the others.
            structlog.get_logger().error(
                "Failed to load server config",
                server_id=server_id,
                error=str(exc),
            )
|
|
|
|
|
async def register_server(self, server_config: ServerConfig):
    """Register an MCP server.

    Creates and initializes the server's connection pool, publishes the
    server in the orchestrator's registries, and registers its tools.

    Bug fix: the config is now published in ``self.servers`` only *after*
    the connection pool initialized successfully. Previously the config was
    stored first, so a failing ``ConnectionPool.initialize()`` left behind a
    half-registered server (present in ``self.servers`` but with no pool),
    which the health-check loop and status queries would then trip over.

    Args:
        server_config: Fully-populated configuration for the server.

    Raises:
        Exception: Whatever pool initialization or tool registration raises;
            in that case no registry entry is left behind by this method.
    """
    server_id = server_config.id

    # Bring the pool up first; only a usable server gets published.
    connection_pool = ConnectionPool(server_config, self.metrics)
    await connection_pool.initialize()

    self.servers[server_id] = server_config
    self.connection_pools[server_id] = connection_pool

    # Discover and register the server's tools.
    await self.tool_manager.register_server(server_config, connection_pool)

    structlog.get_logger().info(
        "Server registered",
        server_id=server_id,
        name=server_config.name,
        url=server_config.url
    )
|
|
|
|
|
async def unregister_server(self, server_id: str):
    """Unregister an MCP server and release everything it holds."""
    # Unknown id: nothing to do.
    if server_id not in self.servers:
        return

    # Tear down the pool first so no new work can be routed through it.
    pool = self.connection_pools.pop(server_id, None)
    if pool is not None:
        await pool.shutdown()

    self.servers.pop(server_id, None)

    # Drop any cached tool listings that belonged to this server.
    await self.cache.invalidate_pattern(f"tools:{server_id}")

    structlog.get_logger().info("Server unregistered", server_id=server_id)
|
|
|
|
|
async def invoke_tool(self, server_id: str, tool_name: str,
                      parameters: Dict[str, Any], session_id: str) -> Dict[str, Any]:
    """Invoke a tool with session management and rate limiting.

    Validates the session, enforces its rate limit, records the request,
    dispatches to the tool manager, and maintains the session's running
    success-rate average.

    Args:
        server_id: Registered server the tool lives on.
        tool_name: Name of the tool to invoke.
        parameters: Tool parameters, passed through unchanged.
        session_id: Session on whose behalf the tool is invoked.

    Returns:
        The result payload from the tool manager.

    Raises:
        ValueError: Unknown session, rate limit exceeded, or unknown server.
        Exception: Re-raised from the underlying tool invocation.
    """
    invocation = ToolInvocation(
        session_id=session_id,
        server_id=server_id,
        tool_name=tool_name,
        parameters=parameters
    )

    session = None
    counted = False  # becomes True once this request is in session.total_requests

    try:
        session = await self.session_manager.get_session(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        allowed, wait_time = await self.session_manager.check_rate_limit(session_id)
        if not allowed:
            raise ValueError(f"Rate limit exceeded. Wait {wait_time:.1f} seconds")

        await self.session_manager.update_session_activity(session_id)
        session.total_requests += 1
        counted = True

        connection_pool = self.connection_pools.get(server_id)
        if not connection_pool:
            raise ValueError(f"Server {server_id} not found")

        result = await self.tool_manager.invoke_tool(
            server_id, tool_name, parameters, session_id
        )

        invocation.status = "success"
        # NOTE(review): utcnow() is naive UTC (and deprecated in 3.12+);
        # kept as-is since ToolInvocation consumers may expect naive stamps.
        invocation.completed_at = datetime.utcnow()
        invocation.result = result

        # Running average over total_requests; this is the success path, so
        # this request contributes a 1. (Bug fix: the old code branched on
        # invocation.error here, but error is only ever set in the except
        # handler below, so the failure branch was dead code.)
        session.success_rate = (
            session.success_rate * (session.total_requests - 1) + 1
        ) / session.total_requests

        return result

    except Exception as e:
        invocation.status = "error"
        invocation.completed_at = datetime.utcnow()
        invocation.error = str(e)

        # Bug fix: failed invocations now pull the running success rate down.
        # Only applied when this request was actually counted above, so the
        # average denominator stays consistent.
        if session is not None and counted:
            session.success_rate = (
                session.success_rate * (session.total_requests - 1)
            ) / session.total_requests

        structlog.get_logger().error(
            "Tool invocation failed",
            server_id=server_id,
            tool_name=tool_name,
            session_id=session_id,
            error=str(e)
        )

        raise
|
|
|
|
|
async def get_server_status(self, server_id: str) -> Dict[str, Any]:
    """Get server status and health information.

    Returns a snapshot of the server's circuit-breaker state, pool usage and
    tool count, or ``{"error": ...}`` when the server or its pool is unknown.
    """
    config = self.servers.get(server_id)
    if config is None:
        return {"error": f"Server {server_id} not found"}

    pool = self.connection_pools.get(server_id)
    if pool is None:
        return {"error": f"No connection pool for server {server_id}"}

    last_failure = pool.circuit_breaker_last_failure
    snapshot = {
        "server_id": server_id,
        "name": config.name,
        "url": config.url,
        "state": pool.circuit_breaker_state.value,
        "failures": pool.circuit_breaker_failures,
        "pool_size": len(pool.pool),
        "available_connections": pool.available.qsize(),
        "in_use_connections": len(pool.in_use),
        "last_failure": last_failure.isoformat() if last_failure else None,
        # Placeholder; no active probe is performed here.
        "health_status": "unknown",
    }

    snapshot["tool_count"] = len(await self.tool_manager.get_server_tools(server_id))
    return snapshot
|
|
|
|
|
async def get_all_servers_status(self) -> Dict[str, Dict[str, Any]]:
    """Get status for all servers, keyed by server id."""
    statuses: Dict[str, Dict[str, Any]] = {}
    for server_id in self.servers:
        statuses[server_id] = await self.get_server_status(server_id)
    return statuses
|
|
|
|
|
async def _health_check_loop(self):
    """Background health check loop.

    Every 30 seconds polls ``get_server_status`` for each registered server;
    per-server failures are logged and do not stop the loop. Exits cleanly
    on task cancellation.
    """
    while True:
        try:
            await asyncio.sleep(30)

            # Bug fix: snapshot the ids with list(). register_server /
            # unregister_server can mutate self.servers while we await
            # inside this loop, which would raise "dictionary changed size
            # during iteration" on a live key view. Also dropped the unused
            # `status` local — the call is made for its side effects only.
            for server_id in list(self.servers):
                try:
                    await self.get_server_status(server_id)
                except Exception as e:
                    structlog.get_logger().error(
                        "Health check failed",
                        server_id=server_id,
                        error=str(e)
                    )

        except asyncio.CancelledError:
            break
        except Exception as e:
            structlog.get_logger().error("Health check loop error", error=str(e))
|
|
|
|
|
async def _cache_cleanup_loop(self):
    """Background loop that purges expired cache entries every five minutes."""
    while True:
        try:
            await asyncio.sleep(300)
            await self.cache.cleanup_expired()
        except asyncio.CancelledError:
            # Normal shutdown path: the background task was cancelled.
            return
        except Exception as exc:
            # Keep the loop alive; cleanup will be retried next cycle.
            structlog.get_logger().error("Cache cleanup loop error", error=str(exc))
|
|
|
|
|
async def get_metrics(self) -> Dict[str, Any]:
    """Get comprehensive metrics.

    Aggregates counts from the server registry and connection pools with the
    stats reported by the cache, session manager, and tool manager.
    """
    return {
        # Number of registered servers.
        "servers": len(self.servers),
        # Total connections held across all pools.
        "active_connections": sum(len(pool.pool) for pool in self.connection_pools.values()),
        "cache_stats": self.cache.get_stats(),
        "session_stats": await self.session_manager.get_session_stats(),
        "tool_stats": self.tool_manager.get_stats(),
        # NOTE(review): this is the current epoch timestamp, not an uptime
        # duration. Computing real uptime needs a recorded start time
        # (e.g. now - self.start_time) — confirm intent and fix or rename.
        "uptime": time.time()
    }
|
|
|
|
|
async def shutdown(self):
    """Shutdown the orchestrator: stop background tasks, then close all subsystems."""
    log = structlog.get_logger()
    log.info("Shutting down MCP Orchestrator")

    # Let anything waiting on the event know we are going down.
    self.shutdown_event.set()

    # Cancel background tasks and wait for them to unwind; exceptions
    # (including CancelledError) are collected rather than raised.
    for task in self.background_tasks:
        task.cancel()
    await asyncio.gather(*self.background_tasks, return_exceptions=True)

    # Subsystems next.
    await self.cache.shutdown()
    await self.session_manager.shutdown()

    # Then every server connection pool.
    for pool in self.connection_pools.values():
        await pool.shutdown()

    # Finally the thread pool; block until in-flight work completes.
    self.executor.shutdown(wait=True)

    log.info("MCP Orchestrator shutdown complete")
|
|
|
|
|
@property
def is_healthy(self) -> bool:
    """Check if orchestrator is healthy.

    Unhealthy once shutdown has been signalled or any server's circuit
    breaker is open.
    """
    if self.shutdown_event.is_set():
        return False

    # Healthy only while no breaker has tripped open.
    return all(
        pool.circuit_breaker_state != CircuitBreakerState.OPEN
        for pool in self.connection_pools.values()
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_logging(log_level: str = "INFO"):
    """Setup structured logging.

    Configures structlog to emit ISO-timestamped JSON records through the
    stdlib logging machinery, and points the root logger at stdout at the
    requested level.
    """
    processor_chain = [
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.processors.JSONRenderer(),
    ]
    structlog.configure(
        processors=processor_chain,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )

    # Route stdlib logging (which structlog wraps) to stdout; the JSON
    # renderer above produces the whole message, hence the bare format.
    import logging
    logging.basicConfig(
        stream=sys.stdout,
        format="%(message)s",
        level=getattr(logging, log_level.upper()),
    )
|
|
|
|
|
|
|
|
def setup_metrics(port: int = 9090):
    """Setup Prometheus metrics.

    Best-effort: if prometheus_client is not installed, log a warning and
    continue without a metrics endpoint.
    """
    try:
        # Imported lazily so the orchestrator runs without the dependency.
        from prometheus_client import start_http_server
        start_http_server(port)
        structlog.get_logger().info("Metrics server started", port=port)
    except ImportError:
        structlog.get_logger().warning("Prometheus client not available, skipping metrics server")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def main():
    """Main entry point for the MCP Orchestrator."""
    setup_logging()
    setup_metrics()

    app = MCPOrchestrator()
    await app.initialize()

    try:
        # Block until something (signal handler, admin API, ...) flags shutdown.
        await app.shutdown_event.wait()
    except KeyboardInterrupt:
        structlog.get_logger().info("Received shutdown signal")
    finally:
        # Always release resources, whatever ended the wait.
        await app.shutdown()
|
|
|
|
|
|
|
|
# Script entry point: run the orchestrator's async main loop.
if __name__ == "__main__":
    asyncio.run(main())
|
|
|