# NOTE: the following is web-UI paste residue, not Python code — kept as a
# comment so the file parses: author rajkumarrawal, "Initial commit", rev 2ec0d39.
async def _load_servers_from_config(self):
    """Build a ServerConfig for each entry under config['servers'] and register it.

    Entries that fail validation or registration are logged and skipped so one
    bad server definition does not prevent the others from loading.
    """
    for server_id, raw in self.config.get('servers', {}).items():
        try:
            # Fill in defaults for any field the config omits.
            cfg = ServerConfig(
                id=server_id,
                name=raw.get('name', server_id),
                url=raw.get('url'),
                auth_token=raw.get('auth_token'),
                protocol_version=raw.get('protocol_version', '2024-11-05'),
                timeout=int(raw.get('timeout', 30)),
                max_retries=int(raw.get('max_retries', 3)),
                retry_backoff=float(raw.get('retry_backoff', 1.5)),
                max_connections=int(raw.get('max_connections', 10)),
                health_check_interval=int(raw.get('health_check_interval', 60)),
                circuit_breaker_threshold=int(raw.get('circuit_breaker_threshold', 5)),
                circuit_breaker_timeout=int(raw.get('circuit_breaker_timeout', 300)),
                connection_pool_ttl=int(raw.get('connection_pool_ttl', 3600)),
                tls_verify=raw.get('tls_verify', True),
                headers=raw.get('headers', {}),
                metadata=raw.get('metadata', {}),
            )
            await self.register_server(cfg)
        except Exception as exc:
            structlog.get_logger().error(
                "Failed to load server config",
                server_id=server_id,
                error=str(exc),
            )
async def register_server(self, server_config: ServerConfig):
    """Register an MCP server: store its config, open its connection pool,
    and expose its tools through the tool manager."""
    sid = server_config.id
    self.servers[sid] = server_config

    # Each server gets its own initialized connection pool.
    pool = ConnectionPool(server_config, self.metrics)
    await pool.initialize()
    self.connection_pools[sid] = pool

    # Make the server's tools discoverable/invokable.
    await self.tool_manager.register_server(server_config, pool)

    structlog.get_logger().info(
        "Server registered",
        server_id=sid,
        name=server_config.name,
        url=server_config.url,
    )
async def unregister_server(self, server_id: str):
    """Unregister an MCP server: shut down its pool, drop its config,
    and invalidate its cached tool entries. No-op for unknown ids."""
    if server_id not in self.servers:
        return

    # Tear down the pool first so no new connections are handed out.
    pool = self.connection_pools.pop(server_id, None)
    if pool is not None:
        await pool.shutdown()

    self.servers.pop(server_id, None)

    # Drop any cached tool listings for this server.
    await self.cache.invalidate_pattern(f"tools:{server_id}")
    structlog.get_logger().info("Server unregistered", server_id=server_id)
async def invoke_tool(self, server_id: str, tool_name: str,
                      parameters: Dict[str, Any], session_id: str) -> Dict[str, Any]:
    """Invoke a tool with session management and rate limiting.

    Validates the session, enforces its rate limit, dispatches the call via
    the tool manager, and maintains the session's running success rate.

    Raises:
        ValueError: unknown session, rate limit exceeded, or unknown server.
        Exception: any error propagated from the tool manager (re-raised
            after being recorded on the invocation).
    """
    invocation = ToolInvocation(
        session_id=session_id,
        server_id=server_id,
        tool_name=tool_name,
        parameters=parameters
    )
    session = None
    counted = False  # True once this request is included in session.total_requests
    try:
        session = await self.session_manager.get_session(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        allowed, wait_time = await self.session_manager.check_rate_limit(session_id)
        if not allowed:
            raise ValueError(f"Rate limit exceeded. Wait {wait_time:.1f} seconds")

        await self.session_manager.update_session_activity(session_id)
        session.total_requests += 1
        counted = True

        if server_id not in self.connection_pools:
            raise ValueError(f"Server {server_id} not found")

        result = await self.tool_manager.invoke_tool(
            server_id, tool_name, parameters, session_id
        )

        invocation.status = "success"
        # NOTE(review): utcnow() is naive and deprecated in 3.12; kept for
        # consistency with the rest of the file — confirm before migrating.
        invocation.completed_at = datetime.utcnow()
        invocation.result = result

        # BUGFIX: the original tested `invocation.error is None` here, which is
        # always true on the success path (error is only set in the handler),
        # leaving the failure branch unreachable. Fold a success (1) into the
        # running average directly; the handler below folds in failures (0).
        session.success_rate = (
            session.success_rate * (session.total_requests - 1) + 1
        ) / session.total_requests
        return result
    except Exception as e:
        invocation.status = "error"
        invocation.completed_at = datetime.utcnow()
        invocation.error = str(e)

        # Only adjust the success rate if this request was actually counted
        # (i.e. the failure happened after total_requests was incremented).
        if session is not None and counted:
            session.success_rate = (
                session.success_rate * (session.total_requests - 1)
            ) / session.total_requests

        structlog.get_logger().error(
            "Tool invocation failed",
            server_id=server_id,
            tool_name=tool_name,
            session_id=session_id,
            error=str(e)
        )
        raise
async def get_server_status(self, server_id: str) -> Dict[str, Any]:
    """Return health/pool statistics for one server, or an ``error`` dict
    when the server or its connection pool is unknown."""
    server_config = self.servers.get(server_id)
    if server_config is None:
        return {"error": f"Server {server_id} not found"}

    pool = self.connection_pools.get(server_id)
    if pool is None:
        return {"error": f"No connection pool for server {server_id}"}

    last_failure = pool.circuit_breaker_last_failure
    status: Dict[str, Any] = {
        "server_id": server_id,
        "name": server_config.name,
        "url": server_config.url,
        "state": pool.circuit_breaker_state.value,
        "failures": pool.circuit_breaker_failures,
        "pool_size": len(pool.pool),
        "available_connections": pool.available.qsize(),
        "in_use_connections": len(pool.in_use),
        "last_failure": last_failure.isoformat() if last_failure else None,
        # Placeholder — would be populated by a health check.
        "health_status": "unknown",
    }

    # Attach the number of tools this server exposes.
    tools = await self.tool_manager.get_server_tools(server_id)
    status["tool_count"] = len(tools)
    return status
async def get_all_servers_status(self) -> Dict[str, Dict[str, Any]]:
    """Return the status dict of every registered server, keyed by server id.

    Snapshots the server ids before iterating: each ``await`` yields control,
    and a concurrent register/unregister would otherwise raise
    "dictionary changed size during iteration".
    """
    statuses: Dict[str, Dict[str, Any]] = {}
    for server_id in list(self.servers):
        statuses[server_id] = await self.get_server_status(server_id)
    return statuses
async def _health_check_loop(self):
    """Background health check loop.

    Polls every server's status every 30 seconds until the task is cancelled.
    Per-server failures are logged and skipped; unexpected loop errors are
    logged and the loop continues.
    """
    while True:
        try:
            await asyncio.sleep(30)  # Check every 30 seconds
            # Snapshot the ids: awaiting inside the loop yields control, and
            # a concurrent register/unregister mutating self.servers would
            # break live-dict iteration.
            for server_id in list(self.servers):
                try:
                    # Result currently unused — status retrieval itself
                    # exercises the pool; TODO: act on unhealthy results.
                    await self.get_server_status(server_id)
                except Exception as e:
                    structlog.get_logger().error(
                        "Health check failed",
                        server_id=server_id,
                        error=str(e)
                    )
        except asyncio.CancelledError:
            break
        except Exception as e:
            structlog.get_logger().error("Health check loop error", error=str(e))
async def _cache_cleanup_loop(self):
    """Periodically purge expired cache entries until the task is cancelled."""
    cleanup_interval = 300  # seconds — clean every 5 minutes
    while True:
        try:
            await asyncio.sleep(cleanup_interval)
            await self.cache.cleanup_expired()
        except asyncio.CancelledError:
            # Cooperative shutdown: exit the loop when cancelled.
            break
        except Exception as exc:
            structlog.get_logger().error("Cache cleanup loop error", error=str(exc))
async def get_metrics(self) -> Dict[str, Any]:
    """Return application-level metrics as a dict.

    Prometheus metrics are scraped externally; this aggregates counts from
    the in-process components instead.
    """
    pools = self.connection_pools.values()
    return {
        "servers": len(self.servers),
        "active_connections": sum(len(p.pool) for p in pools),
        "cache_stats": self.cache.get_stats(),
        "session_stats": await self.session_manager.get_session_stats(),
        "tool_stats": self.tool_manager.get_stats(),
        # NOTE(review): this is the current epoch timestamp, not an uptime
        # duration — no start time is recorded here to subtract from.
        "uptime": time.time(),
    }
async def shutdown(self):
    """Shut down the orchestrator: stop background tasks, then tear down
    cache, session manager, connection pools, and the thread executor."""
    log = structlog.get_logger()
    log.info("Shutting down MCP Orchestrator")

    # Signal shutdown to anything watching the event.
    self.shutdown_event.set()

    # Cancel background tasks and wait for them to finish; exceptions
    # (including CancelledError) are collected rather than raised.
    for task in self.background_tasks:
        task.cancel()
    await asyncio.gather(*self.background_tasks, return_exceptions=True)

    # Tear down the long-lived components.
    await self.cache.shutdown()
    await self.session_manager.shutdown()
    for pool in self.connection_pools.values():
        await pool.shutdown()

    # Block until the thread pool has drained.
    self.executor.shutdown(wait=True)
    log.info("MCP Orchestrator shutdown complete")
@property
def is_healthy(self) -> bool:
    """True unless shutdown has started or any server's circuit breaker is open."""
    if self.shutdown_event.is_set():
        return False
    # Healthy only when no pool has tripped its circuit breaker.
    return all(
        pool.circuit_breaker_state != CircuitBreakerState.OPEN
        for pool in self.connection_pools.values()
    )
# =============================================================================
# Logging and Monitoring Setup
# =============================================================================
def setup_logging(log_level: str = "INFO"):
    """Configure structlog for JSON output and route stdlib logging to stdout.

    Args:
        log_level: stdlib logging level name (case-insensitive), e.g. "INFO".
    """
    import logging

    # Processor order matters: filtering/enrichment first, rendering last.
    processors = [
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.processors.JSONRenderer(),
    ]
    structlog.configure(
        processors=processors,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )

    # Stdlib logging emits the pre-rendered structlog message as-is.
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format="%(message)s",
        stream=sys.stdout,
    )
def setup_metrics(port: int = 9090):
    """Start the Prometheus metrics HTTP server if prometheus_client is installed.

    Args:
        port: TCP port for the metrics endpoint (default 9090).
    """
    try:
        from prometheus_client import start_http_server
    except ImportError:
        # Metrics are optional — degrade gracefully when the client is absent.
        structlog.get_logger().warning("Prometheus client not available, skipping metrics server")
        return
    start_http_server(port)
    structlog.get_logger().info("Metrics server started", port=port)
# =============================================================================
# Main Entry Point
# =============================================================================
async def main():
    """Entry point: wire up logging/metrics, run the orchestrator until
    shutdown is signalled, then tear it down."""
    setup_logging()
    setup_metrics()

    orchestrator = MCPOrchestrator()
    await orchestrator.initialize()
    try:
        # Block here until something sets the shutdown event.
        await orchestrator.shutdown_event.wait()
    except KeyboardInterrupt:
        structlog.get_logger().info("Received shutdown signal")
    finally:
        # Always attempt a clean teardown, even on interrupt.
        await orchestrator.shutdown()
# Script entry point: run the async main() on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())