Humanlearning's picture
feat: Implement interactive query agent with MCP client and Gradio UI, and add verification script.
3847bc0
import os
import asyncio
import logging
from typing import Any, Dict, List, Optional
from langchain_mcp_adapters.client import MultiServerMCPClient
class MCPClient:
    """
    Abstraction for calling MCP tools from multiple servers.

    Manages connections to the NPI, Credential DB, and Alert MCP servers through
    a single ``MultiServerMCPClient``. Server URLs are read from the
    ``NPI_MCP_URL``, ``CRED_DB_MCP_URL`` and ``ALERT_MCP_URL`` environment
    variables when the instance is constructed. When running on Hugging Face
    Spaces (``SPACE_ID`` set) with any server URL pointing at localhost, no
    connection is attempted and canned mock responses are returned instead.
    """

    def __init__(self):
        self.npi_url = os.getenv("NPI_MCP_URL", "http://localhost:8001/sse")
        self.cred_db_url = os.getenv("CRED_DB_MCP_URL", "http://localhost:8002/sse")
        self.alert_url = os.getenv("ALERT_MCP_URL", "http://localhost:8003/sse")
        self._client: Optional[MultiServerMCPClient] = None
        # Cache: tool name -> tool object returned by MultiServerMCPClient.get_tools().
        self._tools: Dict[str, Any] = {}
        # Set by connect() when it decides to serve mock data instead of real servers.
        self._mock_mode = False
        self._connected = False
        # Serializes concurrent connect() calls so initialization happens once.
        self._connect_lock = asyncio.Lock()
        self.logger = logging.getLogger("mcp_client")

    @staticmethod
    def _is_localhost(url: str) -> bool:
        """Return True if *url* mentions ``localhost`` or ``127.0.0.1``."""
        return "localhost" in url or "127.0.0.1" in url

    @staticmethod
    def _normalize_sse_url(url: str) -> str:
        """Return *url* with trailing slashes removed and ending in ``/sse``."""
        url = url.rstrip("/")
        if not url.endswith("/sse"):
            url += "/sse"
        return url

    def _build_server_config(
        self, npi_url: str, cred_db_url: str, alert_url: str
    ) -> Dict[str, Dict[str, Any]]:
        """Build the MultiServerMCPClient server map.

        Adds a Bearer ``Authorization`` header to every server when the
        ``HF_TOKEN`` environment variable is set.
        """
        servers: Dict[str, Dict[str, Any]] = {
            "npi": {"transport": "sse", "url": npi_url},
            "cred_db": {"transport": "sse", "url": cred_db_url},
            "alert": {"transport": "sse", "url": alert_url},
        }
        token = os.getenv("HF_TOKEN")
        if token:
            for server in servers.values():
                # A separate dict per server so mutating one cannot leak into the others.
                server["headers"] = {"Authorization": f"Bearer {token}"}
        return servers

    async def connect(self):
        """Establish connections to all MCP servers and cache their tools.

        Returns immediately once connected, or once mock mode has been
        selected. On failure the error is logged, the client stays
        disconnected, and the next call retries.
        """
        async with self._connect_lock:
            if self._connected:
                return

            npi_url = self._normalize_sse_url(self.npi_url)
            cred_db_url = self._normalize_sse_url(self.cred_db_url)
            alert_url = self._normalize_sse_url(self.alert_url)

            # On HF Spaces, localhost servers are not reachable, so serve mock data.
            is_hf = os.getenv("SPACE_ID") is not None
            if is_hf and any(
                self._is_localhost(url) for url in (npi_url, cred_db_url, alert_url)
            ):
                self.logger.info("Detected Hugging Face Spaces environment with localhost URLs.")
                self.logger.info("Skipping actual MCP connections and defaulting to mock data.")
                self._mock_mode = True
                self._connected = True
                return

            self.logger.info("Initializing MultiServerMCPClient...")
            servers = self._build_server_config(npi_url, cred_db_url, alert_url)
            try:
                client = MultiServerMCPClient(servers)
                # Pre-fetch tools to verify the connection and cache them.
                tools_list = await client.get_tools()
            except Exception as e:
                # Best effort: stay disconnected (no half-built client) so the next
                # call retries; call_tool falls back to mock data meanwhile.
                self.logger.error("Failed to initialize MCP client: %s", e, exc_info=True)
                self._client = None
                return

            self._client = client
            self._tools = {tool.name: tool for tool in tools_list}
            self._connected = True
            self.logger.info("Successfully connected. Loaded %d tools.", len(self._tools))

    async def close(self):
        """Drop the client, cached tools and mode flags; the next call reconnects."""
        # MultiServerMCPClient might not have an explicit close, so just release references.
        self._client = None
        self._tools = {}
        # Reset so a later connect() re-decides between mock and real servers.
        self._mock_mode = False
        self._connected = False
        self.logger.info("MCP connections closed.")

    def _find_tool(self, tool_name: str) -> Optional[Any]:
        """Resolve *tool_name* against the cached tools.

        Tries, in order:
        1. the exact name;
        2. a name ending in ``_<tool_name>`` or ``_<tool_name>_tool``, or equal
           to ``<tool_name>_tool`` (servers may prefix or suffix tool names);
        3. the first name containing *tool_name*, warning if several do.

        Returns None when nothing matches or *tool_name* is empty.
        """
        if not tool_name:
            # An empty name would substring-match every tool.
            return None

        tool = self._tools.get(tool_name)
        if tool is not None:
            return tool

        suffixes = (f"_{tool_name}", f"_{tool_name}_tool")
        for name, candidate in self._tools.items():
            if name.endswith(suffixes) or name == f"{tool_name}_tool":
                return candidate

        partial = [name for name in self._tools if tool_name in name]
        if not partial:
            return None
        if len(partial) > 1:
            self.logger.warning(
                "Tool name '%s' is ambiguous (matches %s); using '%s'.",
                tool_name,
                partial,
                partial[0],
            )
        return self._tools[partial[0]]

    async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Call an MCP tool by name.

        Args:
            server_name: Logical server ("npi", "cred_db", "alert"). Tools are
                flattened across servers, so this only selects mock data.
            tool_name: Tool to call; resolved by ``_find_tool``.
            arguments: Arguments passed to the tool's ``ainvoke``.

        Returns:
            The tool result. Returns mock data when in mock mode or when the
            tool cannot be found, even after a refresh.

        Raises:
            Exception: Any error raised by the tool invocation is logged and re-raised.
        """
        if not self._connected:
            await self.connect()

        if self._mock_mode:
            return self._get_mock_response(server_name, tool_name, arguments)

        tool = self._find_tool(tool_name)
        if tool is None and self._client is not None:
            # The tool may have appeared server-side since connect(); refresh once.
            try:
                tools_list = await self._client.get_tools()
            except Exception as e:
                self.logger.error("Error refreshing tools: %s", e)
            else:
                self._tools = {t.name: t for t in tools_list}
                tool = self._find_tool(tool_name)

        if tool is None:
            self.logger.warning(
                "Tool '%s' not found in loaded tools. Using mock if available.", tool_name
            )
            return self._get_mock_response(server_name, tool_name, arguments)

        self.logger.info("Calling tool '%s' with args: %s", tool_name, arguments)
        try:
            result = await tool.ainvoke(arguments)
        except Exception as e:
            self.logger.error("Error calling tool '%s': %s", tool_name, e, exc_info=True)
            raise
        self.logger.info("Tool '%s' returned successfully.", tool_name)
        return result

    def _get_mock_response(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Return canned data for a known (server, tool) pair, else an error dict."""
        if server_name == "npi":
            if tool_name == "search_providers":
                return {"providers": [{"npi": "1234567890", "name": "Dr. Jane Doe", "taxonomy": "Cardiology"}]}
            if tool_name == "get_provider_by_npi":
                return {"npi": arguments.get("npi"), "name": "Dr. Jane Doe", "licenses": []}
        if server_name == "cred_db":
            if tool_name == "list_expiring_credentials":
                return {"expiring": [{"provider_id": 1, "name": "Dr. Jane Doe", "credential": "Medical License", "days_remaining": 25}]}
            if tool_name == "get_provider_snapshot":
                return {"name": "Dr. Jane Doe", "status": "Active", "credentials": []}
        if server_name == "alert":
            if tool_name == "log_alert":
                return {"success": True, "alert_id": 101}
            if tool_name == "get_open_alerts":
                return {"alerts": []}
        return {"error": "Mock data not found for this tool"}

    def get_tools(self) -> List[Any]:
        """Return the cached tools (empty until connect() succeeds or after close())."""
        return list(self._tools.values())
# Global shared instance. It is constructed at import time, and MCPClient.__init__
# reads NPI_MCP_URL / CRED_DB_MCP_URL / ALERT_MCP_URL then. Those environment
# variables must therefore be set before this module is first imported.
mcp_client = MCPClient()