mcp-alert-generator / src /executor.py
aakashdg's picture
cleanup
9bdf3bb verified
"""
MCP Executor - Stage 2.

Executes concurrent calls to MCP servers based on routing decisions and
normalizes each server's output into a common result envelope
(``{"data": ..., "status": ...}``).

FIXED:
1. Proper async handling for FastAPI: async callers should await
   ``MCPExecutor.execute_parallel_async`` instead of running
   ``asyncio.run`` inside an already-running event loop.
2. Fixed double-wrapping of server results: servers that already return
   ``{"status": ..., "data": ...}`` are unwrapped, not wrapped again.
"""
from typing import Dict, Any
import asyncio
import inspect
class MCPExecutor:
    """
    Executes MCP server calls based on routing decisions.

    Each server object is expected to expose a ``get_data(lat, lon)`` method,
    which may be a regular function or a coroutine function. Every queried
    server's outcome is normalized into
    ``{"data": ..., "status": "success"}`` or
    ``{"data": None, "status": "error", "error": <message>}``.

    Inside a running event loop (e.g. a FastAPI async endpoint) use
    ``await execute_parallel_async(...)``; ``execute_parallel`` is a
    convenience wrapper for purely synchronous callers.
    """

    def __init__(self, servers: Dict[str, Any]):
        """
        Initialize executor with MCP server instances.

        Args:
            servers: Dict mapping server names to initialized server objects.
        """
        self.servers = servers

    async def execute_parallel_async(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
        """
        Execute MCP server calls concurrently (async version for FastAPI).

        Only servers whose routing flag is truthy AND that are registered in
        ``self.servers`` are called; other names are silently skipped.
        A failure in one server never aborts the others.

        Args:
            routing: Dict with server names as keys and True/False as values.
            location: Dict with 'latitude' and 'longitude' keys.

        Returns:
            Dict mapping each queried server name (in routing order) to its
            normalized result envelope. Empty dict when nothing is queried.
        """
        server_names = [
            name for name, should_query in routing.items()
            if should_query and name in self.servers
        ]
        if not server_names:
            return {}

        # Run all calls concurrently; with return_exceptions=True a failing
        # call yields its exception object instead of cancelling the batch.
        task_results = await asyncio.gather(
            *(self._call_server(self.servers[name], name, location) for name in server_names),
            return_exceptions=True,
        )

        return {
            name: self._normalize_result(name, result)
            for name, result in zip(server_names, task_results)
        }

    @staticmethod
    def _normalize_result(server_name: str, result: Any) -> Dict[str, Any]:
        """
        Convert one gathered result into the executor's result envelope.

        Handles three shapes: an exception object (from gather), a dict that
        the server already wrapped as ``{"status": ..., "data": ...}`` (it is
        unwrapped rather than double-wrapped), or raw data. Prints a one-line
        status message for the server.
        """
        # BaseException, not Exception: gather(return_exceptions=True) can
        # also return asyncio.CancelledError, which must not count as data.
        if isinstance(result, BaseException):
            error = str(result) or type(result).__name__
        elif isinstance(result, dict) and "status" in result and result.get("status") != "success":
            error = result.get("error", "Unknown error")
        else:
            if isinstance(result, dict) and "status" in result:
                data = result.get("data")  # already wrapped: extract payload
            else:
                data = result  # raw data
            print(f"✓ {server_name.upper()}: Retrieved successfully")
            return {"data": data, "status": "success"}

        print(f"✗ {server_name.upper()}: Error - {error}")
        return {"data": None, "status": "error", "error": error}

    def execute_parallel(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
        """
        Execute MCP server calls in parallel (sync wrapper).

        Drives ``execute_parallel_async`` with ``asyncio.run`` and therefore
        only works when no event loop is running in the current thread.

        Raises:
            RuntimeError: if called while an event loop is already running
                (e.g. from a FastAPI async endpoint); await
                ``execute_parallel_async`` there instead.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread - safe to start one.
            return asyncio.run(self.execute_parallel_async(routing, location))
        # Raised outside the try block so `except RuntimeError` above cannot
        # swallow it and fall through to asyncio.run inside a live loop.
        raise RuntimeError(
            "execute_parallel called from async context. "
            "Use 'await executor.execute_parallel_async()' instead."
        )

    async def _call_server(self, server: Any, server_name: str, location: Dict[str, float]) -> Any:
        """
        Call one server's ``get_data(lat, lon)``, handling sync and async methods.

        Coroutine functions are awaited directly; regular callables run in the
        loop's default executor so they do not block the event loop.

        Raises:
            KeyError: if ``location`` lacks 'latitude' or 'longitude'.
            AttributeError: if the server has no ``get_data`` method.
        """
        lat = location['latitude']
        lon = location['longitude']

        method = getattr(server, 'get_data', None)
        if method is None:
            raise AttributeError(f"Server {server_name} has no get_data method")

        if inspect.iscoroutinefunction(method):
            return await method(lat, lon)

        # get_running_loop() is the preferred call inside coroutines.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, method, lat, lon)
        # A plain callable may still return an awaitable (e.g. an async
        # function behind a non-wrapping decorator); finish it here instead
        # of reporting an unawaited coroutine object as data.
        if inspect.isawaitable(result):
            result = await result
        return result