"""LangChain tools backed by the MCP server's REST API."""

import asyncio
import json
import logging
import os
from typing import Any, Dict, List, Optional

import requests
from langchain.tools import BaseTool
from pydantic import Field

logger = logging.getLogger(__name__)


class MCPClient:
    """Client for making authenticated REST API calls to the MCP server.

    Every request carries the API key in an ``x-api-key`` header and a JSON
    body. Failures are not raised: :meth:`post` returns a dict of the form
    ``{"status": "error", "message": ...}`` instead, so the tools below can
    hand the problem back to the agent as text.
    """

    def __init__(self, mcp_url: str, api_key: str, timeout: Optional[float] = 60.0):
        """Create a client.

        Args:
            mcp_url: Base URL of the MCP server; a trailing slash is tolerated.
            api_key: Value sent in the ``x-api-key`` header.
            timeout: Seconds to wait for the server (connect and read) before
                giving up. ``None`` waits indefinitely, which can hang the
                calling agent if the server stops responding.
        """
        self.mcp_url = mcp_url
        self.timeout = timeout
        self.headers = {
            "x-api-key": api_key,
            "Content-Type": "application/json",
        }

    def post(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Send a POST request to a given MCP endpoint.

        Args:
            endpoint: Path relative to ``mcp_url``, e.g. ``"graph/find_join_path"``.
            data: JSON-serializable request payload.

        Returns:
            The decoded JSON object returned by the server, or an error dict
            ``{"status": "error", "message": <str>}`` on HTTP/transport failure,
            an undecodable body, or a body that is not a JSON object.
        """
        # Normalize slashes so "http://host/" + "/endpoint" doesn't yield "//".
        url = f"{self.mcp_url.rstrip('/')}/{endpoint.lstrip('/')}"
        try:
            response = requests.post(
                url,
                headers=self.headers,
                data=json.dumps(data),
                timeout=self.timeout,
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as http_err:
            # raise_for_status() is the only source of HTTPError here, so
            # `response` is always bound in this branch.
            logger.error("HTTP error occurred: %s - %s", http_err, response.text)
            return {
                "status": "error",
                "message": f"HTTP error: {response.status_code} {response.reason}",
            }
        except requests.exceptions.RequestException as req_err:
            logger.error("Request error occurred: %s", req_err)
            return {"status": "error", "message": f"Request failed: {req_err}"}

        # Decode outside the transport try-block: requests' own JSONDecodeError
        # is also a RequestException and would otherwise be reported as a
        # transport failure. Both it and json.JSONDecodeError subclass ValueError.
        try:
            payload = response.json()
        except ValueError:
            logger.error("Failed to decode JSON response.")
            return {"status": "error", "message": "Invalid JSON response from server."}

        # Callers use .get() on the result, so anything other than an object
        # (e.g. a bare list or string) is treated as a malformed response.
        if not isinstance(payload, dict):
            logger.error("Unexpected JSON response type: %s", type(payload).__name__)
            return {"status": "error", "message": "Invalid JSON response from server."}
        return payload


def _first_type(raw_type: Any) -> str:
    """Return a display string for a schema entry's ``type`` value.

    A list/tuple contributes its first element; a plain string is shown whole
    (not truncated to one character); a missing, ``None`` or empty value
    becomes ``'Unknown'``. NOTE(review): the list form is assumed from the
    ``['Unknown']`` default used by this tool — confirm against the MCP API.
    """
    if isinstance(raw_type, (list, tuple)):
        return str(raw_type[0]) if raw_type else "Unknown"
    if raw_type is None or raw_type == "":
        return "Unknown"
    return str(raw_type)


class SchemaSearchTool(BaseTool):
    """LangChain tool for searching database schemas."""

    name: str = "schema_search"
    description: str = """
    Search for relevant database schemas based on a natural language query.
    Use this when you need to find which tables/columns are relevant to a user's question.
    Input should be a descriptive query like 'patient information' or 'drug trials'.
    """
    mcp_client: MCPClient

    def _run(self, query: str) -> str:
        """Execute schema search.

        Returns:
            One ``- database.table.name (type)`` line per schema entry returned
            by the server, a "no results" message, or an error message. Errors
            are returned as text rather than raised so the agent can read them.
        """
        try:
            response = self.mcp_client.post(
                "discovery/get_relevant_schemas", {"query": query}
            )
            if response.get("status") != "success":
                return f"Error searching schemas: {response.get('message', 'Unknown error')}"

            schemas = response.get("schemas", [])
            if not schemas:
                return "No relevant schemas found."

            lines = ["Found relevant schemas:"]
            for schema in schemas:
                lines.append(
                    f"- {schema.get('database', 'Unknown')}."
                    f"{schema.get('table', 'Unknown')}."
                    f"{schema.get('name', 'Unknown')} "
                    f"({_first_type(schema.get('type'))})"
                )
            return "\n".join(lines) + "\n"
        except Exception as e:
            return f"Failed to search schemas: {str(e)}"

    async def _arun(self, query: str) -> str:
        """Async version: run the blocking sync implementation in a worker thread.

        ``_run`` performs a blocking HTTP request; the default executor keeps
        it from stalling the event loop.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, query)


class JoinPathFinderTool(BaseTool):
    """LangChain tool for finding join paths between tables."""

    name: str = "find_join_path"
    description: str = """
    Find how to join two tables together using foreign key relationships.
    Use this when you need to query across multiple tables.
    Input should be two table names separated by a comma, like 'patients,studies'.
    """
    mcp_client: MCPClient

    def _run(self, table_names: str) -> str:
        """Find join path.

        Args:
            table_names: Two table names separated by a comma; surrounding
                whitespace is ignored and neither name may be empty.

        Returns:
            The join path reported by the server, or a usage/error message.
        """
        try:
            tables = [t.strip() for t in table_names.split(",")]
            # Reject inputs like "patients," that split into an empty name.
            if len(tables) != 2 or not all(tables):
                return "Please provide exactly two table names separated by a comma."

            response = self.mcp_client.post(
                "graph/find_join_path",
                {"table1": tables[0], "table2": tables[1]},
            )
            if response.get("status") == "success":
                path = response.get("path", "No path found")
                return f"Join path: {path}"
            return f"Error finding join path: {response.get('message', 'Unknown error')}"
        except Exception as e:
            return f"Failed to find join path: {str(e)}"

    async def _arun(self, table_names: str) -> str:
        """Async version: run the blocking sync implementation in a worker thread."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, table_names)


class QueryExecutorTool(BaseTool):
    """LangChain tool for executing SQL queries."""

    name: str = "execute_query"
    description: str = """
    Execute a SQL query against the databases and return results.
    Use this after you have a valid SQL query.
    Input should be a valid SQL query string.
    """
    mcp_client: MCPClient
    # How many result rows are rendered in the tool output; the total row
    # count is always reported.
    max_display_rows: int = 10

    def _run(self, sql: str) -> str:
        """Execute query.

        Returns:
            A pipe-separated text table of the first ``max_display_rows`` rows
            (plus a count of omitted rows), a "no results" message, or an
            error message.
        """
        try:
            response = self.mcp_client.post(
                "intelligence/execute_query",
                {"sql": sql},
            )
            if response.get("status") != "success":
                return f"Error executing query: {response.get('message', 'Unknown error')}"

            results = response.get("results", [])
            if not results:
                return "Query executed successfully but returned no results."
            return self._format_results(results)
        except Exception as e:
            return f"Failed to execute query: {str(e)}"

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Render result rows as a readable pipe-separated table.

        Column headers are taken from the keys of the first row; rows missing
        a column render it as an empty string.
        """
        headers = list(results[0].keys())
        header_line = " | ".join(headers)
        lines = [
            f"Query returned {len(results)} rows:",
            header_line,
            "-" * len(header_line),
        ]
        shown = results[: max(self.max_display_rows, 0)]
        for row in shown:
            lines.append(" | ".join(str(row.get(h, "")) for h in headers))
        hidden = len(results) - len(shown)
        if hidden > 0:
            lines.append(f"... and {hidden} more rows")
        return "\n".join(lines) + "\n"

    async def _arun(self, sql: str) -> str:
        """Async version: run the blocking sync implementation in a worker thread."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, sql)