Spaces:
Sleeping
Sleeping
fix (asyncio inside FastAPI bug)
Browse files- src/executor.py +71 -154
src/executor.py
CHANGED
|
@@ -268,15 +268,15 @@
|
|
| 268 |
# return results
|
| 269 |
|
| 270 |
# return results
|
| 271 |
-
|
| 272 |
"""
|
| 273 |
MCP Executor - Stage 2
|
| 274 |
Executes parallel calls to MCP servers based on routing decisions
|
| 275 |
-
FIXED:
|
|
|
|
|
|
|
| 276 |
"""
|
| 277 |
|
| 278 |
from typing import Dict, Any
|
| 279 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 280 |
import asyncio
|
| 281 |
import inspect
|
| 282 |
|
|
@@ -284,8 +284,7 @@ import inspect
|
|
| 284 |
class MCPExecutor:
|
| 285 |
"""
|
| 286 |
Executes MCP server calls based on routing decisions.
|
| 287 |
-
|
| 288 |
-
Handles both sync and async server methods safely.
|
| 289 |
"""
|
| 290 |
|
| 291 |
def __init__(self, servers: Dict[str, Any]):
|
|
@@ -294,110 +293,38 @@ class MCPExecutor:
|
|
| 294 |
|
| 295 |
Args:
|
| 296 |
servers: Dict mapping server names to initialized server objects
|
| 297 |
-
e.g., {"weather": WeatherServer(), "soil": SoilPropertiesServer(), ...}
|
| 298 |
"""
|
| 299 |
self.servers = servers
|
| 300 |
|
| 301 |
-
def
|
| 302 |
"""
|
| 303 |
-
Execute MCP server calls in parallel
|
| 304 |
|
| 305 |
Args:
|
| 306 |
-
routing:
|
| 307 |
location: Dict with 'latitude' and 'longitude' keys
|
| 308 |
|
| 309 |
Returns:
|
| 310 |
-
Dict mapping server names to their results
|
| 311 |
"""
|
| 312 |
results = {}
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
# Separate sync and async servers
|
| 316 |
-
sync_tasks = []
|
| 317 |
-
async_tasks = []
|
| 318 |
|
| 319 |
for server_name, should_query in routing.items():
|
| 320 |
if should_query and server_name in self.servers:
|
| 321 |
server = self.servers[server_name]
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
"server": server,
|
| 325 |
-
"location": location
|
| 326 |
-
}
|
| 327 |
-
|
| 328 |
-
# Check if server method is async
|
| 329 |
-
if hasattr(server, 'get_data'):
|
| 330 |
-
method = getattr(server, 'get_data')
|
| 331 |
-
if inspect.iscoroutinefunction(method):
|
| 332 |
-
async_tasks.append(task)
|
| 333 |
-
else:
|
| 334 |
-
sync_tasks.append(task)
|
| 335 |
-
else:
|
| 336 |
-
sync_tasks.append(task)
|
| 337 |
-
|
| 338 |
-
# Execute sync servers in parallel with ThreadPoolExecutor
|
| 339 |
-
if sync_tasks:
|
| 340 |
-
with ThreadPoolExecutor(max_workers=5) as executor:
|
| 341 |
-
futures = {
|
| 342 |
-
executor.submit(self._call_sync_server, task): task
|
| 343 |
-
for task in sync_tasks
|
| 344 |
-
}
|
| 345 |
-
|
| 346 |
-
for future in as_completed(futures):
|
| 347 |
-
task = futures[future]
|
| 348 |
-
server_name = task["server_name"]
|
| 349 |
-
|
| 350 |
-
try:
|
| 351 |
-
result = future.result(timeout=30)
|
| 352 |
-
results[server_name] = {
|
| 353 |
-
"data": result,
|
| 354 |
-
"status": "success"
|
| 355 |
-
}
|
| 356 |
-
print(f"✓ {server_name.upper()}: Retrieved successfully")
|
| 357 |
-
except Exception as e:
|
| 358 |
-
results[server_name] = {
|
| 359 |
-
"data": None,
|
| 360 |
-
"status": "error",
|
| 361 |
-
"error": str(e)
|
| 362 |
-
}
|
| 363 |
-
print(f"✗ {server_name.upper()}: Error - {str(e)}")
|
| 364 |
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
try:
|
| 368 |
-
async_results = asyncio.run(self._execute_async_batch(async_tasks))
|
| 369 |
-
results.update(async_results)
|
| 370 |
-
except Exception as e:
|
| 371 |
-
# If batch fails, mark all as failed
|
| 372 |
-
for task in async_tasks:
|
| 373 |
-
results[task["server_name"]] = {
|
| 374 |
-
"data": None,
|
| 375 |
-
"status": "error",
|
| 376 |
-
"error": f"Async batch execution failed: {str(e)}"
|
| 377 |
-
}
|
| 378 |
-
print(f"✗ {task['server_name'].upper()}: Async batch error")
|
| 379 |
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
async def _execute_async_batch(self, tasks: list) -> Dict[str, Any]:
|
| 383 |
-
"""
|
| 384 |
-
Execute multiple async server calls concurrently in a single event loop.
|
| 385 |
-
This is safer than creating multiple event loops.
|
| 386 |
-
"""
|
| 387 |
-
results = {}
|
| 388 |
-
|
| 389 |
-
# Create async tasks for all servers
|
| 390 |
-
async_calls = []
|
| 391 |
-
for task in tasks:
|
| 392 |
-
async_calls.append(self._call_async_server(task))
|
| 393 |
-
|
| 394 |
-
# Execute all async calls concurrently
|
| 395 |
-
task_results = await asyncio.gather(*async_calls, return_exceptions=True)
|
| 396 |
|
| 397 |
# Process results
|
| 398 |
-
for
|
| 399 |
-
server_name = task["server_name"]
|
| 400 |
-
|
| 401 |
if isinstance(result, Exception):
|
| 402 |
results[server_name] = {
|
| 403 |
"data": None,
|
|
@@ -406,76 +333,66 @@ class MCPExecutor:
|
|
| 406 |
}
|
| 407 |
print(f"✗ {server_name.upper()}: Error - {str(result)}")
|
| 408 |
else:
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
"""Call individual async MCP server"""
|
| 419 |
-
server = task["server"]
|
| 420 |
-
location = task["location"]
|
| 421 |
-
|
| 422 |
-
if hasattr(server, 'get_data'):
|
| 423 |
-
return await server.get_data(location['latitude'], location['longitude'])
|
| 424 |
-
else:
|
| 425 |
-
raise AttributeError(f"Server {task['server_name']} has no get_data method")
|
| 426 |
-
|
| 427 |
-
def _call_sync_server(self, task: Dict[str, Any]) -> Any:
|
| 428 |
-
"""Call individual sync MCP server"""
|
| 429 |
-
server = task["server"]
|
| 430 |
-
location = task["location"]
|
| 431 |
-
|
| 432 |
-
if hasattr(server, 'get_data'):
|
| 433 |
-
return server.get_data(location['latitude'], location['longitude'])
|
| 434 |
-
elif hasattr(server, 'query'):
|
| 435 |
-
return server.query(location)
|
| 436 |
-
elif hasattr(server, 'fetch_data'):
|
| 437 |
-
return server.fetch_data(location['latitude'], location['longitude'])
|
| 438 |
-
else:
|
| 439 |
-
raise AttributeError(f"Server {task['server_name']} has no compatible query method")
|
| 440 |
-
|
| 441 |
-
def execute_sequential(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
|
| 442 |
-
"""
|
| 443 |
-
Execute MCP server calls sequentially (fallback if parallel fails).
|
| 444 |
-
"""
|
| 445 |
-
results = {}
|
| 446 |
-
|
| 447 |
-
for server_name, should_query in routing.items():
|
| 448 |
-
if should_query and server_name in self.servers:
|
| 449 |
-
try:
|
| 450 |
-
server = self.servers[server_name]
|
| 451 |
-
|
| 452 |
-
# Check if async
|
| 453 |
-
if hasattr(server, 'get_data') and inspect.iscoroutinefunction(server.get_data):
|
| 454 |
-
# Run async method
|
| 455 |
-
result = asyncio.run(server.get_data(location['latitude'], location['longitude']))
|
| 456 |
else:
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
"
|
| 460 |
-
"
|
| 461 |
-
"location": location
|
| 462 |
}
|
| 463 |
-
|
| 464 |
-
|
| 465 |
results[server_name] = {
|
| 466 |
"data": result,
|
| 467 |
"status": "success"
|
| 468 |
}
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
except Exception as e:
|
| 472 |
-
results[server_name] = {
|
| 473 |
-
"data": None,
|
| 474 |
-
"status": "error",
|
| 475 |
-
"error": str(e)
|
| 476 |
-
}
|
| 477 |
-
print(f"✗ {server_name.upper()}: Error - {str(e)}")
|
| 478 |
|
| 479 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
|
| 481 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
# return results
|
| 269 |
|
| 270 |
# return results
|
|
|
|
| 271 |
"""
|
| 272 |
MCP Executor - Stage 2
|
| 273 |
Executes parallel calls to MCP servers based on routing decisions
|
| 274 |
+
FIXED:
|
| 275 |
+
1. Proper async handling for FastAPI (no asyncio.run inside existing loop)
|
| 276 |
+
2. Fixed double-wrapping of server results
|
| 277 |
"""
|
| 278 |
|
| 279 |
from typing import Dict, Any
|
|
|
|
| 280 |
import asyncio
|
| 281 |
import inspect
|
| 282 |
|
|
|
|
| 284 |
class MCPExecutor:
|
| 285 |
"""
|
| 286 |
Executes MCP server calls based on routing decisions.
|
| 287 |
+
Properly handles async servers within FastAPI's event loop.
|
|
|
|
| 288 |
"""
|
| 289 |
|
| 290 |
def __init__(self, servers: Dict[str, Any]):
|
|
|
|
| 293 |
|
| 294 |
Args:
|
| 295 |
servers: Dict mapping server names to initialized server objects
|
|
|
|
| 296 |
"""
|
| 297 |
self.servers = servers
|
| 298 |
|
| 299 |
+
async def execute_parallel_async(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
|
| 300 |
"""
|
| 301 |
+
Execute MCP server calls in parallel (async version for FastAPI).
|
| 302 |
|
| 303 |
Args:
|
| 304 |
+
routing: Dict with server names as keys and True/False as values
|
| 305 |
location: Dict with 'latitude' and 'longitude' keys
|
| 306 |
|
| 307 |
Returns:
|
| 308 |
+
Dict mapping server names to their results
|
| 309 |
"""
|
| 310 |
results = {}
|
| 311 |
+
tasks = []
|
| 312 |
+
server_names = []
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
for server_name, should_query in routing.items():
|
| 315 |
if should_query and server_name in self.servers:
|
| 316 |
server = self.servers[server_name]
|
| 317 |
+
tasks.append(self._call_server(server, server_name, location))
|
| 318 |
+
server_names.append(server_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
+
if not tasks:
|
| 321 |
+
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
|
| 323 |
+
# Execute all tasks concurrently
|
| 324 |
+
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
# Process results
|
| 327 |
+
for server_name, result in zip(server_names, task_results):
|
|
|
|
|
|
|
| 328 |
if isinstance(result, Exception):
|
| 329 |
results[server_name] = {
|
| 330 |
"data": None,
|
|
|
|
| 333 |
}
|
| 334 |
print(f"✗ {server_name.upper()}: Error - {str(result)}")
|
| 335 |
else:
|
| 336 |
+
# FIX: Handle servers that return {"status": ..., "data": ...}
|
| 337 |
+
# Don't double-wrap!
|
| 338 |
+
if isinstance(result, dict) and "status" in result:
|
| 339 |
+
# Server already returned proper format
|
| 340 |
+
if result.get("status") == "success":
|
| 341 |
+
results[server_name] = {
|
| 342 |
+
"data": result.get("data"), # Extract actual data
|
| 343 |
+
"status": "success"
|
| 344 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
else:
|
| 346 |
+
results[server_name] = {
|
| 347 |
+
"data": None,
|
| 348 |
+
"status": "error",
|
| 349 |
+
"error": result.get("error", "Unknown error")
|
|
|
|
| 350 |
}
|
| 351 |
+
else:
|
| 352 |
+
# Server returned raw data
|
| 353 |
results[server_name] = {
|
| 354 |
"data": result,
|
| 355 |
"status": "success"
|
| 356 |
}
|
| 357 |
+
print(f"✓ {server_name.upper()}: Retrieved successfully")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
|
| 359 |
return results
|
| 360 |
+
|
| 361 |
+
def execute_parallel(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
|
| 362 |
+
"""
|
| 363 |
+
Execute MCP server calls in parallel (sync wrapper).
|
| 364 |
|
| 365 |
+
Detects if we're already in an async context and handles appropriately.
|
| 366 |
+
"""
|
| 367 |
+
try:
|
| 368 |
+
# Check if there's already a running event loop
|
| 369 |
+
loop = asyncio.get_running_loop()
|
| 370 |
+
# We're in an async context - need to use nest_asyncio or return a coroutine
|
| 371 |
+
# For FastAPI, the endpoint should be async and call execute_parallel_async directly
|
| 372 |
+
raise RuntimeError(
|
| 373 |
+
"execute_parallel called from async context. "
|
| 374 |
+
"Use 'await executor.execute_parallel_async()' instead."
|
| 375 |
+
)
|
| 376 |
+
except RuntimeError:
|
| 377 |
+
# No running loop - safe to use asyncio.run
|
| 378 |
+
return asyncio.run(self.execute_parallel_async(routing, location))
|
| 379 |
+
|
| 380 |
+
async def _call_server(self, server: Any, server_name: str, location: Dict[str, float]) -> Any:
|
| 381 |
+
"""
|
| 382 |
+
Call individual MCP server, handling both sync and async methods.
|
| 383 |
+
"""
|
| 384 |
+
lat = location['latitude']
|
| 385 |
+
lon = location['longitude']
|
| 386 |
+
|
| 387 |
+
if hasattr(server, 'get_data'):
|
| 388 |
+
method = getattr(server, 'get_data')
|
| 389 |
+
|
| 390 |
+
if inspect.iscoroutinefunction(method):
|
| 391 |
+
# Async method - await it
|
| 392 |
+
return await method(lat, lon)
|
| 393 |
+
else:
|
| 394 |
+
# Sync method - run in executor to not block
|
| 395 |
+
loop = asyncio.get_event_loop()
|
| 396 |
+
return await loop.run_in_executor(None, method, lat, lon)
|
| 397 |
+
else:
|
| 398 |
+
raise AttributeError(f"Server {server_name} has no get_data method")
|