Spaces:
Sleeping
Sleeping
File size: 8,085 Bytes
cfd07be 9d3c7fc cfd07be 9d3c7fc cfd07be 9d3c7fc b749aa6 9d3c7fc cfd07be 9d3c7fc b749aa6 4f4a9eb 9d3c7fc 4f4a9eb 9d3c7fc cfd07be b749aa6 9d3c7fc b749aa6 9d3c7fc cfd07be 9d3c7fc cfd07be 9d3c7fc cfd07be 9d3c7fc cfd07be 3847bc0 cfd07be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import asyncio
import logging
import os
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

from langchain_mcp_adapters.client import MultiServerMCPClient
class MCPClient:
    """
    Abstraction for calling MCP tools from multiple servers.

    Manages connections to the NPI, Credential DB, and Alert MCP servers through
    a single ``MultiServerMCPClient``. Tools from all servers are flattened into
    one name -> tool cache. Canned mock responses are returned in two cases:
    when running on Hugging Face Spaces with localhost server URLs, and when a
    requested tool cannot be resolved.
    """

    def __init__(self):
        self.npi_url = os.getenv("NPI_MCP_URL", "http://localhost:8001/sse")
        self.cred_db_url = os.getenv("CRED_DB_MCP_URL", "http://localhost:8002/sse")
        self.alert_url = os.getenv("ALERT_MCP_URL", "http://localhost:8003/sse")
        self._client: Optional[MultiServerMCPClient] = None
        self._tools: Dict[str, Any] = {}  # tool name -> tool, flattened across servers
        self._mock_mode = False
        self._connected = False
        # Created lazily in connect(). The module-level instance is built at
        # import time, and on Python < 3.10 asyncio.Lock() binds to whatever
        # event loop is current at construction. That may not be the loop that
        # later awaits connect().
        self._connect_lock: Optional[asyncio.Lock] = None
        self.logger = logging.getLogger("mcp_client")

    @staticmethod
    def _is_localhost(url: str) -> bool:
        """Return True if the host part of *url* is localhost or a loopback address."""
        # Parse the hostname instead of substring-matching the whole URL, so that
        # e.g. "notlocalhost.example.com" is not misclassified. A scheme-less
        # URL ("localhost:8001") gets a "//" prefix so urlparse sees a netloc.
        parsed = urlparse(url if "://" in url else f"//{url}")
        host = (parsed.hostname or "").lower()
        return (
            host in ("localhost", "::1")
            or host.endswith(".localhost")
            or host.startswith("127.")
        )

    @staticmethod
    def _normalize_sse_url(url: str) -> str:
        """Return *url* with trailing slashes removed and ``/sse`` appended if missing."""
        url = url.rstrip("/")
        return url if url.endswith("/sse") else f"{url}/sse"

    async def connect(self):
        """Establish connections to all MCP servers (no-op if already connected).

        Mock mode is entered without any network call when ``SPACE_ID`` is set
        (Hugging Face Spaces) and any server URL points at localhost. Otherwise
        a ``MultiServerMCPClient`` is created and its tools are fetched and
        cached. Connection failures are logged, not raised. The client then
        stays disconnected, so the next ``call_tool`` retries.
        """
        if self._connect_lock is None:
            # No await between the check and the assignment, so this cannot
            # race within a single event loop.
            self._connect_lock = asyncio.Lock()
        async with self._connect_lock:
            if self._connected:
                return

            npi_url = self._normalize_sse_url(self.npi_url)
            cred_db_url = self._normalize_sse_url(self.cred_db_url)
            alert_url = self._normalize_sse_url(self.alert_url)

            is_hf = os.getenv("SPACE_ID") is not None
            if is_hf and any(self._is_localhost(u) for u in (npi_url, cred_db_url, alert_url)):
                self.logger.info("Detected Hugging Face Spaces environment with localhost URLs.")
                self.logger.info("Skipping actual MCP connections and defaulting to mock data.")
                self._mock_mode = True
                self._connected = True
                return

            self.logger.info("Initializing MultiServerMCPClient...")
            servers: Dict[str, Dict[str, Any]] = {
                "npi": {"transport": "sse", "url": npi_url},
                "cred_db": {"transport": "sse", "url": cred_db_url},
                "alert": {"transport": "sse", "url": alert_url},
            }
            hf_token = os.getenv("HF_TOKEN")
            if hf_token:
                for server in servers.values():
                    # Give each server its own headers dict so the configs don't share state.
                    server["headers"] = {"Authorization": f"Bearer {hf_token}"}

            try:
                client = MultiServerMCPClient(servers)
                # Fetching tools both verifies connectivity and fills the cache.
                tools_list = await client.get_tools()
            except Exception as e:
                # Best-effort: leave the client disconnected (no client object,
                # not in mock mode) so the next call_tool() retries connect().
                self.logger.error(f"Failed to initialize MCP client: {e}", exc_info=True)
                self._client = None
                return

            self._client = client
            self._tools = {tool.name: tool for tool in tools_list}
            self.logger.info("Successfully connected. Loaded %d tools.", len(self._tools))
            self._connected = True

    async def close(self):
        """Drop the client and all cached state so the next call reconnects from scratch."""
        # The client object is only released here, not explicitly closed.
        # NOTE(review): confirm whether the installed langchain_mcp_adapters
        # version needs an explicit shutdown.
        self._client = None
        self._tools = {}
        self._mock_mode = False
        self._connected = False
        self.logger.info("MCP connections closed.")

    def _resolve_tool(self, tool_name: str) -> Optional[Any]:
        """Look up a cached tool by exact or fuzzy name.

        Precedence:
          1. exact name;
          2. the first tool named ``{tool_name}_tool`` or ending with
             ``_{tool_name}`` / ``_{tool_name}_tool`` (e.g. server-prefixed names);
          3. otherwise the last tool whose name merely contains *tool_name*.

        Returns None if nothing matches or *tool_name* is empty.
        """
        if not tool_name:
            # "" is a substring of every name and would match an arbitrary tool.
            return None
        exact = self._tools.get(tool_name)
        if exact is not None:
            return exact
        suffixes = (f"_{tool_name}", f"_{tool_name}_tool")
        bare_tool_name = f"{tool_name}_tool"
        contains_match = None
        for name, candidate in self._tools.items():
            if name.endswith(suffixes) or name == bare_tool_name:
                return candidate
            if tool_name in name:
                contains_match = candidate  # weaker match; keep scanning for a suffix match
        return contains_match

    async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Call an MCP tool by name.

        Args:
            server_name: Used only to choose mock data. Real tools are looked up by
                name across all servers, because the client flattens them.
            tool_name: Exact or partial tool name (see ``_resolve_tool``).
            arguments: Passed unchanged to the tool's ``ainvoke``.

        Returns:
            The tool's result. Mock data is returned instead in mock mode, or
            when the tool cannot be found even after refreshing the tool list.

        Raises:
            Exception: any error raised by the tool invocation is logged and re-raised.
        """
        if not self._connected:
            await self.connect()

        if self._mock_mode:
            return self._get_mock_response(server_name, tool_name, arguments)

        tool = self._resolve_tool(tool_name)
        if tool is None and self._client is not None:
            # The server-side tool set may have changed since connect(); refresh once.
            try:
                tools_list = await self._client.get_tools()
            except Exception as e:
                self.logger.error("Error refreshing tools: %s", e)
            else:
                self._tools = {t.name: t for t in tools_list}
                tool = self._resolve_tool(tool_name)

        if tool is None:
            # NOTE(review): on a live deployment this silently substitutes canned
            # mock data for a missing tool. Consider raising instead.
            self.logger.warning(
                "Tool '%s' not found in loaded tools. Using mock if available.", tool_name
            )
            return self._get_mock_response(server_name, tool_name, arguments)

        try:
            self.logger.info("Calling tool '%s' with args: %s", tool_name, arguments)
            result = await tool.ainvoke(arguments)
        except Exception as e:
            self.logger.error(f"Error calling tool '{tool_name}': {e}", exc_info=True)
            raise
        self.logger.info("Tool '%s' returned successfully.", tool_name)
        return result

    def _get_mock_response(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Return canned mock data for a known (server_name, tool_name) pair.

        Unknown pairs yield ``{"error": "Mock data not found for this tool"}``.
        """
        arguments = arguments or {}  # tolerate None so the fallback path cannot crash
        if server_name == "npi":
            if tool_name == "search_providers":
                return {"providers": [{"npi": "1234567890", "name": "Dr. Jane Doe", "taxonomy": "Cardiology"}]}
            if tool_name == "get_provider_by_npi":
                return {"npi": arguments.get("npi"), "name": "Dr. Jane Doe", "licenses": []}
        if server_name == "cred_db":
            if tool_name == "list_expiring_credentials":
                return {"expiring": [{"provider_id": 1, "name": "Dr. Jane Doe", "credential": "Medical License", "days_remaining": 25}]}
            if tool_name == "get_provider_snapshot":
                return {"name": "Dr. Jane Doe", "status": "Active", "credentials": []}
        if server_name == "alert":
            if tool_name == "log_alert":
                return {"success": True, "alert_id": 101}
            if tool_name == "get_open_alerts":
                return {"alerts": []}
        return {"error": "Mock data not found for this tool"}

    def get_tools(self) -> List[Any]:
        """Return the currently cached tools (empty until connected)."""
        return list(self._tools.values())
# Shared module-level MCPClient, constructed once when this module is imported.
mcp_client: MCPClient = MCPClient()
|