aakashdg commited on
Commit
b1d2ecb
·
verified ·
1 Parent(s): ddc122d

fix (asyncio inside FastAPI bug)

Browse files
Files changed (1) hide show
  1. src/executor.py +71 -154
src/executor.py CHANGED
@@ -268,15 +268,15 @@
268
  # return results
269
 
270
  # return results
271
-
272
  """
273
  MCP Executor - Stage 2
274
  Executes parallel calls to MCP servers based on routing decisions
275
- FIXED: Simpler async handling to prevent deadlocks
 
 
276
  """
277
 
278
  from typing import Dict, Any
279
- from concurrent.futures import ThreadPoolExecutor, as_completed
280
  import asyncio
281
  import inspect
282
 
@@ -284,8 +284,7 @@ import inspect
284
  class MCPExecutor:
285
  """
286
  Executes MCP server calls based on routing decisions.
287
- Integrates with existing server implementations in src/servers/
288
- Handles both sync and async server methods safely.
289
  """
290
 
291
  def __init__(self, servers: Dict[str, Any]):
@@ -294,110 +293,38 @@ class MCPExecutor:
294
 
295
  Args:
296
  servers: Dict mapping server names to initialized server objects
297
- e.g., {"weather": WeatherServer(), "soil": SoilPropertiesServer(), ...}
298
  """
299
  self.servers = servers
300
 
301
- def execute_parallel(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
302
  """
303
- Execute MCP server calls in parallel based on routing.
304
 
305
  Args:
306
- routing: Simple dict with server names as keys and True/False as values
307
  location: Dict with 'latitude' and 'longitude' keys
308
 
309
  Returns:
310
- Dict mapping server names to their results with metadata
311
  """
312
  results = {}
313
-
314
- # For async servers, we need to run them differently
315
- # Separate sync and async servers
316
- sync_tasks = []
317
- async_tasks = []
318
 
319
  for server_name, should_query in routing.items():
320
  if should_query and server_name in self.servers:
321
  server = self.servers[server_name]
322
- task = {
323
- "server_name": server_name,
324
- "server": server,
325
- "location": location
326
- }
327
-
328
- # Check if server method is async
329
- if hasattr(server, 'get_data'):
330
- method = getattr(server, 'get_data')
331
- if inspect.iscoroutinefunction(method):
332
- async_tasks.append(task)
333
- else:
334
- sync_tasks.append(task)
335
- else:
336
- sync_tasks.append(task)
337
-
338
- # Execute sync servers in parallel with ThreadPoolExecutor
339
- if sync_tasks:
340
- with ThreadPoolExecutor(max_workers=5) as executor:
341
- futures = {
342
- executor.submit(self._call_sync_server, task): task
343
- for task in sync_tasks
344
- }
345
-
346
- for future in as_completed(futures):
347
- task = futures[future]
348
- server_name = task["server_name"]
349
-
350
- try:
351
- result = future.result(timeout=30)
352
- results[server_name] = {
353
- "data": result,
354
- "status": "success"
355
- }
356
- print(f"✓ {server_name.upper()}: Retrieved successfully")
357
- except Exception as e:
358
- results[server_name] = {
359
- "data": None,
360
- "status": "error",
361
- "error": str(e)
362
- }
363
- print(f"✗ {server_name.upper()}: Error - {str(e)}")
364
 
365
- # Execute async servers together in single event loop
366
- if async_tasks:
367
- try:
368
- async_results = asyncio.run(self._execute_async_batch(async_tasks))
369
- results.update(async_results)
370
- except Exception as e:
371
- # If batch fails, mark all as failed
372
- for task in async_tasks:
373
- results[task["server_name"]] = {
374
- "data": None,
375
- "status": "error",
376
- "error": f"Async batch execution failed: {str(e)}"
377
- }
378
- print(f"✗ {task['server_name'].upper()}: Async batch error")
379
 
380
- return results
381
-
382
- async def _execute_async_batch(self, tasks: list) -> Dict[str, Any]:
383
- """
384
- Execute multiple async server calls concurrently in a single event loop.
385
- This is safer than creating multiple event loops.
386
- """
387
- results = {}
388
-
389
- # Create async tasks for all servers
390
- async_calls = []
391
- for task in tasks:
392
- async_calls.append(self._call_async_server(task))
393
-
394
- # Execute all async calls concurrently
395
- task_results = await asyncio.gather(*async_calls, return_exceptions=True)
396
 
397
  # Process results
398
- for task, result in zip(tasks, task_results):
399
- server_name = task["server_name"]
400
-
401
  if isinstance(result, Exception):
402
  results[server_name] = {
403
  "data": None,
@@ -406,76 +333,66 @@ class MCPExecutor:
406
  }
407
  print(f"✗ {server_name.upper()}: Error - {str(result)}")
408
  else:
409
- results[server_name] = {
410
- "data": result,
411
- "status": "success"
412
- }
413
- print(f"✓ {server_name.upper()}: Retrieved successfully")
414
-
415
- return results
416
-
417
- async def _call_async_server(self, task: Dict[str, Any]) -> Any:
418
- """Call individual async MCP server"""
419
- server = task["server"]
420
- location = task["location"]
421
-
422
- if hasattr(server, 'get_data'):
423
- return await server.get_data(location['latitude'], location['longitude'])
424
- else:
425
- raise AttributeError(f"Server {task['server_name']} has no get_data method")
426
-
427
- def _call_sync_server(self, task: Dict[str, Any]) -> Any:
428
- """Call individual sync MCP server"""
429
- server = task["server"]
430
- location = task["location"]
431
-
432
- if hasattr(server, 'get_data'):
433
- return server.get_data(location['latitude'], location['longitude'])
434
- elif hasattr(server, 'query'):
435
- return server.query(location)
436
- elif hasattr(server, 'fetch_data'):
437
- return server.fetch_data(location['latitude'], location['longitude'])
438
- else:
439
- raise AttributeError(f"Server {task['server_name']} has no compatible query method")
440
-
441
- def execute_sequential(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
442
- """
443
- Execute MCP server calls sequentially (fallback if parallel fails).
444
- """
445
- results = {}
446
-
447
- for server_name, should_query in routing.items():
448
- if should_query and server_name in self.servers:
449
- try:
450
- server = self.servers[server_name]
451
-
452
- # Check if async
453
- if hasattr(server, 'get_data') and inspect.iscoroutinefunction(server.get_data):
454
- # Run async method
455
- result = asyncio.run(server.get_data(location['latitude'], location['longitude']))
456
  else:
457
- # Run sync method
458
- task = {
459
- "server_name": server_name,
460
- "server": server,
461
- "location": location
462
  }
463
- result = self._call_sync_server(task)
464
-
465
  results[server_name] = {
466
  "data": result,
467
  "status": "success"
468
  }
469
- print(f"✓ {server_name.upper()}: Retrieved successfully")
470
-
471
- except Exception as e:
472
- results[server_name] = {
473
- "data": None,
474
- "status": "error",
475
- "error": str(e)
476
- }
477
- print(f"✗ {server_name.upper()}: Error - {str(e)}")
478
 
479
  return results
 
 
 
 
480
 
481
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  # return results
269
 
270
  # return results
 
271
  """
272
  MCP Executor - Stage 2
273
  Executes parallel calls to MCP servers based on routing decisions
274
+ FIXED:
275
+ 1. Proper async handling for FastAPI (no asyncio.run inside existing loop)
276
+ 2. Fixed double-wrapping of server results
277
  """
278
 
279
  from typing import Dict, Any
 
280
  import asyncio
281
  import inspect
282
 
 
284
  class MCPExecutor:
285
  """
286
  Executes MCP server calls based on routing decisions.
287
+ Properly handles async servers within FastAPI's event loop.
 
288
  """
289
 
290
  def __init__(self, servers: Dict[str, Any]):
 
293
 
294
  Args:
295
  servers: Dict mapping server names to initialized server objects
 
296
  """
297
  self.servers = servers
298
 
299
+ async def execute_parallel_async(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
300
  """
301
+ Execute MCP server calls in parallel (async version for FastAPI).
302
 
303
  Args:
304
+ routing: Dict with server names as keys and True/False as values
305
  location: Dict with 'latitude' and 'longitude' keys
306
 
307
  Returns:
308
+ Dict mapping server names to their results
309
  """
310
  results = {}
311
+ tasks = []
312
+ server_names = []
 
 
 
313
 
314
  for server_name, should_query in routing.items():
315
  if should_query and server_name in self.servers:
316
  server = self.servers[server_name]
317
+ tasks.append(self._call_server(server, server_name, location))
318
+ server_names.append(server_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
+ if not tasks:
321
+ return results
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
+ # Execute all tasks concurrently
324
+ task_results = await asyncio.gather(*tasks, return_exceptions=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
  # Process results
327
+ for server_name, result in zip(server_names, task_results):
 
 
328
  if isinstance(result, Exception):
329
  results[server_name] = {
330
  "data": None,
 
333
  }
334
  print(f"✗ {server_name.upper()}: Error - {str(result)}")
335
  else:
336
+ # FIX: Handle servers that return {"status": ..., "data": ...}
337
+ # Don't double-wrap!
338
+ if isinstance(result, dict) and "status" in result:
339
+ # Server already returned proper format
340
+ if result.get("status") == "success":
341
+ results[server_name] = {
342
+ "data": result.get("data"), # Extract actual data
343
+ "status": "success"
344
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  else:
346
+ results[server_name] = {
347
+ "data": None,
348
+ "status": "error",
349
+ "error": result.get("error", "Unknown error")
 
350
  }
351
+ else:
352
+ # Server returned raw data
353
  results[server_name] = {
354
  "data": result,
355
  "status": "success"
356
  }
357
+ print(f"✓ {server_name.upper()}: Retrieved successfully")
 
 
 
 
 
 
 
 
358
 
359
  return results
360
+
361
+ def execute_parallel(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
362
+ """
363
+ Execute MCP server calls in parallel (sync wrapper).
364
 
365
+ Detects if we're already in an async context and handles appropriately.
366
+ """
367
+ try:
368
+ # Check if there's already a running event loop
369
+ loop = asyncio.get_running_loop()
370
+ # We're in an async context - need to use nest_asyncio or return a coroutine
371
+ # For FastAPI, the endpoint should be async and call execute_parallel_async directly
372
+ raise RuntimeError(
373
+ "execute_parallel called from async context. "
374
+ "Use 'await executor.execute_parallel_async()' instead."
375
+ )
376
+ except RuntimeError:
377
+ # No running loop - safe to use asyncio.run
378
+ return asyncio.run(self.execute_parallel_async(routing, location))
379
+
380
+ async def _call_server(self, server: Any, server_name: str, location: Dict[str, float]) -> Any:
381
+ """
382
+ Call individual MCP server, handling both sync and async methods.
383
+ """
384
+ lat = location['latitude']
385
+ lon = location['longitude']
386
+
387
+ if hasattr(server, 'get_data'):
388
+ method = getattr(server, 'get_data')
389
+
390
+ if inspect.iscoroutinefunction(method):
391
+ # Async method - await it
392
+ return await method(lat, lon)
393
+ else:
394
+ # Sync method - run in executor to not block
395
+ loop = asyncio.get_event_loop()
396
+ return await loop.run_in_executor(None, method, lat, lon)
397
+ else:
398
+ raise AttributeError(f"Server {server_name} has no get_data method")