Spaces:
Running
Running
File size: 4,996 Bytes
6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb 6afc01a b1d2ecb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
"""
MCP Executor - Stage 2
Executes parallel calls to MCP servers based on routing decisions
FIXED:
1. Proper async handling for FastAPI (no asyncio.run inside existing loop)
2. Fixed double-wrapping of server results
"""
from typing import Dict, Any
import asyncio
import inspect
class MCPExecutor:
    """
    Executes MCP server calls based on routing decisions.

    Each registered server object is expected to expose a
    ``get_data(latitude, longitude)`` method, which may be either a coroutine
    function or a plain (blocking) callable. Selected servers are queried
    concurrently and every outcome is normalized to a dict of the form
    ``{"data": ..., "status": "success"}`` or
    ``{"data": None, "status": "error", "error": "<message>"}``.
    """

    def __init__(self, servers: Dict[str, Any]):
        """
        Initialize executor with MCP server instances.

        Args:
            servers: Dict mapping server names to initialized server objects
        """
        self.servers = servers

    async def execute_parallel_async(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
        """
        Execute MCP server calls in parallel (async version for FastAPI).

        Args:
            routing: Dict with server names as keys and True/False as values.
                Names that are falsy or not present in ``self.servers`` are
                skipped silently.
            location: Dict with 'latitude' and 'longitude' keys

        Returns:
            Dict mapping each queried server name to its normalized result.
            Empty when no server was selected. Failures of individual servers
            are reported in their entry and never raised to the caller.
        """
        selected = [
            server_name
            for server_name, should_query in routing.items()
            if should_query and server_name in self.servers
        ]
        if not selected:
            return {}

        # return_exceptions=True keeps one failing server from cancelling
        # the others; exceptions come back as ordinary result values.
        outcomes = await asyncio.gather(
            *(
                self._call_server(self.servers[server_name], server_name, location)
                for server_name in selected
            ),
            return_exceptions=True,
        )

        results: Dict[str, Any] = {}
        for server_name, outcome in zip(selected, outcomes):
            entry = self._normalize_result(outcome)
            results[server_name] = entry
            if entry["status"] == "success":
                print(f"✓ {server_name.upper()}: Retrieved successfully")
            else:
                print(f"✗ {server_name.upper()}: Error - {entry['error']}")
        return results

    @staticmethod
    def _normalize_result(outcome: Any) -> Dict[str, Any]:
        """Convert one gathered outcome into the executor's result format."""
        # BaseException (not Exception) so that a CancelledError returned by
        # gather() is reported as an error instead of as successful data.
        if isinstance(outcome, BaseException):
            return {"data": None, "status": "error", "error": str(outcome)}

        # Servers may already answer with {"status": ..., "data": ...};
        # unwrap it rather than nesting it inside another envelope.
        if isinstance(outcome, dict) and "status" in outcome:
            if outcome.get("status") == "success":
                return {"data": outcome.get("data"), "status": "success"}
            return {
                "data": None,
                "status": "error",
                "error": outcome.get("error", "Unknown error"),
            }

        # Anything else is treated as raw payload.
        return {"data": outcome, "status": "success"}

    def execute_parallel(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
        """
        Execute MCP server calls in parallel (sync wrapper).

        Args:
            routing: Dict with server names as keys and True/False as values
            location: Dict with 'latitude' and 'longitude' keys

        Returns:
            Same mapping as ``execute_parallel_async``.

        Raises:
            RuntimeError: If called while an event loop is already running in
                this thread (e.g. from a FastAPI endpoint). Use
                ``await executor.execute_parallel_async(...)`` there instead.
        """
        # Detect the running loop separately from raising: get_running_loop()
        # itself signals "no loop" with RuntimeError, so raising our own
        # RuntimeError inside the same try block would be swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            in_async_context = False
        else:
            in_async_context = True

        if in_async_context:
            raise RuntimeError(
                "execute_parallel called from async context. "
                "Use 'await executor.execute_parallel_async()' instead."
            )

        # No running loop - safe to use asyncio.run
        return asyncio.run(self.execute_parallel_async(routing, location))

    async def _call_server(self, server: Any, server_name: str, location: Dict[str, float]) -> Any:
        """
        Call individual MCP server, handling both sync and async methods.

        Args:
            server: Server object expected to provide ``get_data(lat, lon)``
            server_name: Name used in the error message
            location: Dict with 'latitude' and 'longitude' keys

        Returns:
            Whatever the server's ``get_data`` returns.

        Raises:
            KeyError: If ``location`` lacks 'latitude' or 'longitude'.
            AttributeError: If the server has no ``get_data`` method.
        """
        lat = location['latitude']
        lon = location['longitude']

        if not hasattr(server, 'get_data'):
            raise AttributeError(f"Server {server_name} has no get_data method")

        method = getattr(server, 'get_data')
        if inspect.iscoroutinefunction(method):
            # Async method - await it
            return await method(lat, lon)

        # Sync method - run in the default executor so it does not block the
        # event loop. We are inside a coroutine, so a loop is always running.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, method, lat, lon)
        if inspect.isawaitable(result):
            # A plain callable that hands back a coroutine/future (e.g. a
            # wrapped async function) still needs to be awaited.
            result = await result
        return result