File size: 4,996 Bytes
6afc01a
 
 
b1d2ecb
 
 
6afc01a
 
 
 
 
 
 
 
 
 
b1d2ecb
6afc01a
 
 
 
 
 
 
 
 
 
 
b1d2ecb
6afc01a
b1d2ecb
6afc01a
 
b1d2ecb
6afc01a
 
 
b1d2ecb
6afc01a
 
b1d2ecb
 
6afc01a
 
 
 
b1d2ecb
 
6afc01a
b1d2ecb
 
6afc01a
b1d2ecb
 
6afc01a
 
b1d2ecb
6afc01a
 
 
 
 
 
 
 
b1d2ecb
 
 
 
 
 
 
 
 
6afc01a
b1d2ecb
 
 
 
6afc01a
b1d2ecb
 
6afc01a
 
 
 
b1d2ecb
6afc01a
 
b1d2ecb
 
 
 
6afc01a
b1d2ecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
MCP Executor - Stage 2
Executes parallel calls to MCP servers based on routing decisions
FIXED: 
1. Proper async handling for FastAPI (no asyncio.run inside existing loop)
2. Fixed double-wrapping of server results
"""

from typing import Dict, Any
import asyncio
import inspect


class MCPExecutor:
    """
    Executes MCP server calls based on routing decisions.

    Each server object is expected to expose ``get_data(lat, lon)``. Coroutine
    implementations are awaited on the running event loop; plain functions are
    offloaded to the loop's default executor so they do not block it. This
    makes the async entry point safe to use from FastAPI's event loop.
    """

    def __init__(self, servers: Dict[str, Any]):
        """
        Initialize executor with MCP server instances.

        Args:
            servers: Dict mapping server names to initialized server objects
        """
        self.servers = servers

    async def execute_parallel_async(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
        """
        Execute MCP server calls in parallel (async version for FastAPI).

        Args:
            routing: Dict with server names as keys and True/False as values.
                Names routed to True but not present in ``self.servers`` are
                silently skipped.
            location: Dict with 'latitude' and 'longitude' keys

        Returns:
            Dict mapping each queried server name to a result envelope, either
            ``{"data": ..., "status": "success"}`` or
            ``{"data": None, "status": "error", "error": <message>}``.
            Returns an empty dict when no server is selected.
        """
        selected = [
            name for name, should_query in routing.items()
            if should_query and name in self.servers
        ]
        if not selected:
            return {}

        # return_exceptions=True: one failing server must not abort the others;
        # failures come back as values and are turned into error envelopes.
        task_results = await asyncio.gather(
            *(self._call_server(self.servers[name], name, location) for name in selected),
            return_exceptions=True,
        )

        results = {}
        for server_name, raw in zip(selected, task_results):
            envelope = self._normalize_result(raw)
            results[server_name] = envelope
            # Report based on the normalized status, so a server that returned
            # {"status": "error", ...} is not logged as a success.
            if envelope["status"] == "success":
                print(f"✓ {server_name.upper()}: Retrieved successfully")
            else:
                print(f"✗ {server_name.upper()}: Error - {envelope['error']}")

        return results

    @staticmethod
    def _normalize_result(result: Any) -> Dict[str, Any]:
        """
        Convert one ``asyncio.gather`` outcome into a result envelope.

        Handles three cases: a raised exception, a server that already returned
        a ``{"status": ..., "data": ...}`` envelope (unwrapped, not
        double-wrapped), and raw data (wrapped as a success).
        """
        # BaseException rather than Exception: since Python 3.8
        # asyncio.CancelledError derives from BaseException, and gather() with
        # return_exceptions=True hands it back as a value instead of raising.
        if isinstance(result, BaseException):
            # str() of a message-less exception is "", so fall back to its type.
            return {
                "data": None,
                "status": "error",
                "error": str(result) or type(result).__name__,
            }

        if isinstance(result, dict) and "status" in result:
            # Server already produced an envelope; re-emit it in our format.
            if result.get("status") == "success":
                return {"data": result.get("data"), "status": "success"}
            return {
                "data": None,
                "status": "error",
                "error": result.get("error") or "Unknown error",
            }

        # Server returned raw data.
        return {"data": result, "status": "success"}

    def execute_parallel(self, routing: Dict[str, bool], location: Dict[str, float]) -> Dict[str, Any]:
        """
        Execute MCP server calls in parallel (sync wrapper).

        Only usable when no event loop is running in the current thread; it
        starts one with ``asyncio.run``.

        Raises:
            RuntimeError: If called while an event loop is already running
                (e.g. from an async FastAPI endpoint). Such callers must
                ``await execute_parallel_async(...)`` instead.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread - safe to start one.
            return asyncio.run(self.execute_parallel_async(routing, location))

        # Raised outside the try block so it reaches the caller instead of
        # being swallowed by the `except RuntimeError` above.
        raise RuntimeError(
            "execute_parallel called from async context. "
            "Use 'await executor.execute_parallel_async()' instead."
        )

    async def _call_server(self, server: Any, server_name: str, location: Dict[str, float]) -> Any:
        """
        Call an individual MCP server's ``get_data(lat, lon)``.

        Handles both async and sync implementations.

        Raises:
            KeyError: If ``location`` lacks 'latitude' or 'longitude'.
            AttributeError: If the server has no ``get_data`` attribute.
        """
        lat = location['latitude']
        lon = location['longitude']

        method = getattr(server, 'get_data', None)
        if method is None:
            raise AttributeError(f"Server {server_name} has no get_data method")

        if inspect.iscoroutinefunction(method):
            return await method(lat, lon)

        # Sync method: run it in the default executor so it doesn't block the loop.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, method, lat, lon)
        # A sync wrapper around an async function (e.g. a decorator without
        # coroutine markers) returns an awaitable; await it instead of leaking
        # an un-awaited coroutine as "data".
        if inspect.isawaitable(result):
            result = await result
        return result