mcp-alert-generator / src /executor.py
aakashdg's picture
fix (asyncio inside FastAPI bug)
b1d2ecb verified
raw
history blame
15.3 kB
# """
# Stage 2: MCP Executor - Parallel API Execution
# """
# import asyncio
# import time
# from typing import List, Dict, Any
# from .servers.weather import WeatherServer
# from .servers.soil import SoilPropertiesServer
# from .servers.water import WaterServer
# from .servers.elevation import ElevationServer
# from .servers.pests import PestsServer
# # MCP Server Registry
# MCP_SERVER_REGISTRY = {
# "weather": {
# "name": "Weather Server (Open-Meteo)",
# "description": "Current weather and 7-day forecasts: temperature, precipitation, wind, humidity",
# "capabilities": ["current_weather", "weather_forecast", "rainfall_prediction", "temperature_trends"],
# "use_for": ["rain", "temperature", "weather", "forecast", "frost", "wind"]
# },
# "soil_properties": {
# "name": "Soil Properties Server (SoilGrids)",
# "description": "Soil composition: clay, sand, silt, pH, organic matter from global soil database",
# "capabilities": ["soil_texture", "soil_ph", "clay_content", "sand_content", "nutrients"],
# "use_for": ["soil", "pH", "texture", "clay", "sand", "composition", "fertility", "nutrients"]
# },
# "water": {
# "name": "Groundwater Server (GRACE)",
# "description": "Groundwater levels and drought indicators from NASA GRACE satellite data",
# "capabilities": ["groundwater_levels", "drought_status", "water_storage", "soil_moisture"],
# "use_for": ["groundwater", "drought", "water", "irrigation", "water stress", "moisture"]
# },
# "elevation": {
# "name": "Elevation Server (OpenElevation)",
# "description": "Field elevation and terrain data for irrigation planning",
# "capabilities": ["elevation", "terrain_analysis"],
# "use_for": ["elevation", "slope", "terrain", "drainage"]
# },
# "pests": {
# "name": "Pest Observation Server (iNaturalist)",
# "description": "Recent pest and insect observations from community reporting",
# "capabilities": ["pest_observations", "disease_reports", "pest_distribution"],
# "use_for": ["pests", "insects", "disease", "outbreak"]
# }
# }
# class MCPExecutor:
# """Stage 2: Execute API calls in parallel"""
# def __init__(self):
# self.servers = {
# "weather": WeatherServer(),
# "soil_properties": SoilPropertiesServer(),
# "water": WaterServer(),
# "elevation": ElevationServer(),
# "pests": PestsServer()
# }
# async def execute_parallel(self, server_names: List[str], lat: float, lon: float) -> Dict[str, Any]:
# """
# Call multiple servers simultaneously
# Returns:
# {
# "results": {
# "weather": {"status": "success", "data": {...}},
# ...
# },
# "execution_time_seconds": float
# }
# """
# start_time = time.time()
# tasks = []
# valid_servers = []
# for name in server_names:
# if name in self.servers:
# tasks.append(self.servers[name].get_data(lat, lon))
# valid_servers.append(name)
# else:
# print(f"⚠️ Unknown server: {name}")
# # Execute all in parallel
# results = await asyncio.gather(*tasks, return_exceptions=True)
# # Format results
# formatted_results = {}
# for i, server_name in enumerate(valid_servers):
# result = results[i]
# if isinstance(result, Exception):
# formatted_results[server_name] = {
# "status": "error",
# "error": str(result)
# }
# else:
# formatted_results[server_name] = result
# elapsed_time = time.time() - start_time
# return {
# "results": formatted_results,
# "execution_time_seconds": round(elapsed_time, 2)
# }
# """
# MCP Executor - Stage 2
# Executes parallel calls to MCP servers based on routing decisions
# """
# from typing import Dict, Any
# from concurrent.futures import ThreadPoolExecutor, as_completed
# import asyncio
# class MCPExecutor:
# """
# Executes MCP server calls based on routing decisions.
# Integrates with existing server implementations in src/servers/
# Handles both sync and async server methods.
# """
# def __init__(self, servers: Dict[str, Any]):
# """
# Initialize executor with MCP server instances.
# Args:
# servers: Dict mapping server names to initialized server objects
# e.g., {"weather": WeatherServer(), "soil": SoilPropertiesServer(), ...}
# """
# self.servers = servers
# def execute_parallel(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
# """
# Execute MCP server calls in parallel based on routing.
# Args:
# routing: Simple dict with server names as keys and True/False as values
# location: Dict with 'latitude' and 'longitude' keys
# Returns:
# Dict mapping server names to their results with metadata
# """
# results = {}
# tasks = []
# # Prepare tasks for servers marked for querying
# for server_name, should_query in routing.items():
# if should_query and server_name in self.servers:
# tasks.append({
# "server_name": server_name,
# "server": self.servers[server_name],
# "location": location
# })
# # Execute in parallel using ThreadPoolExecutor
# with ThreadPoolExecutor(max_workers=5) as executor:
# futures = {
# executor.submit(self._call_server_sync, task): task
# for task in tasks
# }
# for future in as_completed(futures):
# task = futures[future]
# server_name = task["server_name"]
# try:
# result = future.result(timeout=30)
# results[server_name] = {
# "data": result,
# "status": "success"
# }
# print(f"✓ {server_name.upper()}: Retrieved successfully")
# except Exception as e:
# results[server_name] = {
# "data": None,
# "status": "error",
# "error": str(e)
# }
# print(f"✗ {server_name.upper()}: Error - {str(e)}")
# return results
# def _call_server_sync(self, task: Dict[str, Any]) -> Any:
# """
# Call individual MCP server, handling both sync and async methods.
# Args:
# task: Dict containing server, location, and metadata
# Returns:
# Server response data
# """
# server = task["server"]
# location = task["location"]
# # Try async method first (most of your servers use async)
# if hasattr(server, 'get_data'):
# method = getattr(server, 'get_data')
# # Check if it's async
# if asyncio.iscoroutinefunction(method):
# # Run async method in new event loop
# try:
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# result = loop.run_until_complete(
# method(location['latitude'], location['longitude'])
# )
# loop.close()
# return result
# except Exception as e:
# raise Exception(f"Async execution failed: {str(e)}")
# else:
# # Sync method
# return method(location['latitude'], location['longitude'])
# # Fallback to other method names
# elif hasattr(server, 'query'):
# return server.query(location)
# elif hasattr(server, 'fetch_data'):
# return server.fetch_data(location['latitude'], location['longitude'])
# else:
# raise AttributeError(f"Server {task['server_name']} has no compatible query method")
# def execute_sequential(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
# """
# Execute MCP server calls sequentially (fallback if parallel fails).
# Args:
# routing: Simple dict with server names as keys and True/False as values
# location: Dict with 'latitude' and 'longitude' keys
# Returns:
# Dict mapping server names to their results
# """
# results = {}
# for server_name, should_query in routing.items():
# if should_query and server_name in self.servers:
# try:
# task = {
# "server_name": server_name,
# "server": self.servers[server_name],
# "location": location
# }
# result = self._call_server_sync(task)
# results[server_name] = {
# "data": result,
# "status": "success"
# }
# print(f"✓ {server_name.upper()}: Retrieved successfully")
# except Exception as e:
# results[server_name] = {
# "data": None,
# "status": "error",
# "error": str(e)
# }
# print(f"✗ {server_name.upper()}: Error - {str(e)}")
# return results
# return results
"""
MCP Executor - Stage 2
Executes parallel calls to MCP servers based on routing decisions
FIXED:
1. Proper async handling for FastAPI (no asyncio.run inside existing loop)
2. Fixed double-wrapping of server results
"""
from typing import Dict, Any
import asyncio
import inspect
class MCPExecutor:
    """
    Execute MCP server calls selected by a routing decision.

    Every server object must expose ``get_data(lat, lon)``, either as a
    coroutine function or as a plain blocking callable. The selected servers
    are queried concurrently and each outcome is normalised to one of:

        {"data": <payload>, "status": "success"}
        {"data": None, "status": "error", "error": <message>}
    """

    def __init__(self, servers: Dict[str, Any]):
        """
        Initialize executor with MCP server instances.

        Args:
            servers: Dict mapping server names to initialized server objects.
        """
        self.servers = servers

    async def execute_parallel_async(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
        """
        Execute MCP server calls concurrently (async version for FastAPI).

        Args:
            routing: Dict with server names as keys and True/False as values.
                Names that are falsy or not registered in ``self.servers``
                are skipped.
            location: Dict with 'latitude' and 'longitude' keys.

        Returns:
            Dict mapping each queried server name to its normalised result.
            Empty when no server was selected.
        """
        selected = [
            name
            for name, should_query in routing.items()
            if should_query and name in self.servers
        ]
        if not selected:
            return {}

        # return_exceptions=True keeps one failing server from aborting the
        # others; failures come back as exception instances in the result list.
        outcomes = await asyncio.gather(
            *(self._call_server(self.servers[name], name, location) for name in selected),
            return_exceptions=True,
        )

        results: Dict[str, Any] = {}
        for server_name, outcome in zip(selected, outcomes):
            entry = self._normalize_result(outcome)
            results[server_name] = entry
            # Log according to the final status so a server-reported error
            # is never announced as a success.
            if entry["status"] == "success":
                print(f"✓ {server_name.upper()}: Retrieved successfully")
            else:
                print(f"✗ {server_name.upper()}: Error - {entry['error']}")
        return results

    @staticmethod
    def _normalize_result(outcome: Any) -> Dict[str, Any]:
        """Convert a raw server outcome (payload, status dict or exception) to the result format."""
        # BaseException (not Exception): gather() can hand back
        # asyncio.CancelledError, which is a BaseException on Python 3.8+.
        if isinstance(outcome, BaseException):
            return {
                "data": None,
                "status": "error",
                # Some exceptions stringify to ""; fall back to the type name.
                "error": str(outcome) or type(outcome).__name__,
            }

        # Servers may already return {"status": ..., "data": ...}; unwrap it
        # instead of nesting it inside another "data" key.
        if isinstance(outcome, dict) and "status" in outcome:
            if outcome.get("status") == "success":
                return {"data": outcome.get("data"), "status": "success"}
            return {
                "data": None,
                "status": "error",
                "error": outcome.get("error", "Unknown error"),
            }

        # Anything else is treated as the raw payload.
        return {"data": outcome, "status": "success"}

    def execute_parallel(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
        """
        Execute MCP server calls in parallel (sync wrapper).

        Args:
            routing: Dict with server names as keys and True/False as values.
            location: Dict with 'latitude' and 'longitude' keys.

        Returns:
            Dict mapping server names to their results.

        Raises:
            RuntimeError: If called while an event loop is already running in
                this thread (e.g. inside an async FastAPI endpoint). Await
                ``execute_parallel_async`` there instead.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # get_running_loop() raises when no loop is running in this thread.
            inside_running_loop = False
        else:
            inside_running_loop = True

        # Raise outside the try block so the message is not swallowed by the
        # handler above, and before the coroutine object is created so nothing
        # is left un-awaited.
        if inside_running_loop:
            raise RuntimeError(
                "execute_parallel called from async context. "
                "Use 'await executor.execute_parallel_async()' instead."
            )

        return asyncio.run(self.execute_parallel_async(routing, location))

    async def _call_server(self, server: Any, server_name: str, location: Dict[str, float]) -> Any:
        """
        Call an individual MCP server, handling both sync and async methods.

        Args:
            server: Server object expected to provide ``get_data(lat, lon)``.
            server_name: Name used in the error message.
            location: Dict with 'latitude' and 'longitude' keys.

        Returns:
            Whatever the server's ``get_data`` returns.

        Raises:
            KeyError: If ``location`` lacks 'latitude' or 'longitude'.
            AttributeError: If the server has no ``get_data`` attribute.
        """
        lat = location['latitude']
        lon = location['longitude']

        if not hasattr(server, 'get_data'):
            raise AttributeError(f"Server {server_name} has no get_data method")

        method = getattr(server, 'get_data')
        if inspect.iscoroutinefunction(method):
            # Async method - await it directly on the current loop.
            return await method(lat, lon)

        # Sync method - run it in the default executor so it does not block
        # the event loop while the other servers are being awaited.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, method, lat, lon)