| | import os |
| | import requests |
| | import json |
| | from typing import Dict, Any, List, Optional |
| | from langchain.tools import BaseTool |
| | from pydantic import Field |
| | import logging |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
class MCPClient:
    """Client for making authenticated REST API calls to the MCP server.

    Every request carries the API key in the ``x-api-key`` header and a JSON
    body. Transport, HTTP and decoding failures are not raised; they are
    reported as ``{"status": "error", "message": ...}`` dictionaries.
    """

    def __init__(self, mcp_url: str, api_key: str, timeout: Optional[float] = 60.0):
        """Store the server location and build the shared request headers.

        Args:
            mcp_url: Base URL of the MCP server.
            api_key: Value sent in the ``x-api-key`` header of every request.
            timeout: Seconds to wait for the server before giving up. Pass
                ``None`` to wait indefinitely.
        """
        self.mcp_url = mcp_url
        self.timeout = timeout
        self.headers = {
            "x-api-key": api_key,
            "Content-Type": "application/json",
        }

    def _build_url(self, endpoint: str) -> str:
        """Join the base URL and *endpoint* with exactly one slash between them."""
        return f"{self.mcp_url.rstrip('/')}/{endpoint.lstrip('/')}"

    def post(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Send a POST request to a given MCP endpoint.

        Args:
            endpoint: Path of the endpoint, relative to the base URL.
            data: Payload, serialized to JSON and sent as the request body.

        Returns:
            The decoded JSON response on success. On an HTTP error status, a
            transport failure (including a timeout) or an undecodable body,
            a dictionary with ``"status": "error"`` and a ``"message"``.
        """
        url = self._build_url(endpoint)

        try:
            response = requests.post(
                url,
                headers=self.headers,
                data=json.dumps(data),
                timeout=self.timeout,
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as http_err:
            # Compare against None explicitly: a requests Response is falsy
            # whenever its status code signals an error.
            failed = http_err.response
            if failed is None:
                logger.error("HTTP error occurred: %s", http_err)
                return {"status": "error", "message": f"HTTP error: {http_err}"}
            logger.error("HTTP error occurred: %s - %s", http_err, failed.text)
            return {
                "status": "error",
                "message": f"HTTP error: {failed.status_code} {failed.reason}",
            }
        except requests.exceptions.RequestException as req_err:
            logger.error("Request error occurred: %s", req_err)
            return {"status": "error", "message": f"Request failed: {req_err}"}

        # Decoding is handled separately from the request itself: in recent
        # requests releases the JSON decoding error is also a RequestException,
        # which would otherwise be reported as a generic request failure.
        try:
            return response.json()
        except ValueError:
            logger.error("Failed to decode JSON response.")
            return {"status": "error", "message": "Invalid JSON response from server."}
| |
|
| |
|
class SchemaSearchTool(BaseTool):
    """LangChain tool for searching database schemas."""

    name: str = "schema_search"
    description: str = """
    Search for relevant database schemas based on a natural language query.
    Use this when you need to find which tables/columns are relevant to a user's question.
    Input should be a descriptive query like 'patient information' or 'drug trials'.
    """
    # Client used to reach the MCP server's discovery endpoint.
    mcp_client: MCPClient

    def _run(self, query: str) -> str:
        """Execute schema search.

        Posts *query* to ``discovery/get_relevant_schemas`` and renders each
        returned schema entry as one ``database.table.name (type)`` line.

        Args:
            query: Natural language description of the data being looked for.

        Returns:
            A newline-separated listing of the matching schemas, a notice when
            nothing matched, or an error message when the search failed.
        """
        try:
            response = self.mcp_client.post(
                "discovery/get_relevant_schemas",
                {"query": query},
            )

            if response.get("status") != "success":
                return f"Error searching schemas: {response.get('message', 'Unknown error')}"

            schemas = response.get("schemas", [])
            if not schemas:
                return "No relevant schemas found."

            lines = ["Found relevant schemas:"]
            for schema in schemas:
                # The default suggests "type" is a list whose first entry is
                # the primary type; tolerate an empty list, a missing value or
                # a plain scalar instead of raising or truncating it.
                raw_type = schema.get("type", ["Unknown"])
                if isinstance(raw_type, (list, tuple)):
                    kind = raw_type[0] if raw_type else "Unknown"
                else:
                    kind = raw_type if raw_type else "Unknown"

                lines.append(
                    f"- {schema.get('database', 'Unknown')}"
                    f".{schema.get('table', 'Unknown')}"
                    f".{schema.get('name', 'Unknown')} ({kind})"
                )

            # Real newlines (not a literal backslash-n) separate the entries.
            return "\n".join(lines) + "\n"
        except Exception as e:
            # Tool boundary: report the failure as text, like the other tools
            # in this module do, but keep the traceback in the log.
            logger.exception("Schema search failed")
            return f"Failed to search schemas: {str(e)}"

    async def _arun(self, query: str) -> str:
        """Async version - just calls sync version.

        The synchronous ``_run`` is executed inline; nothing is offloaded to
        a worker thread.
        """
        return self._run(query)
| |
|
| |
|
class JoinPathFinderTool(BaseTool):
    """LangChain tool for finding join paths between tables."""

    name: str = "find_join_path"
    description: str = """
    Find how to join two tables together using foreign key relationships.
    Use this when you need to query across multiple tables.
    Input should be two table names separated by a comma, like 'patients,studies'.
    """
    # Client used to reach the MCP server's graph endpoint.
    mcp_client: MCPClient

    def _run(self, table_names: str) -> str:
        """Find join path.

        Splits *table_names* on commas, and posts the two names to
        ``graph/find_join_path``.

        Args:
            table_names: Two table names separated by a comma.

        Returns:
            The join path reported by the server, a usage hint when the input
            does not contain exactly two non-empty names, or an error message.
        """
        try:
            tables = [part.strip() for part in table_names.split(",")]

            # Reject blank names too: "patients," splits into two items but
            # names only one table.
            if len(tables) != 2 or not all(tables):
                return "Please provide exactly two table names separated by a comma."

            first, second = tables
            response = self.mcp_client.post(
                "graph/find_join_path",
                {"table1": first, "table2": second},
            )

            if response.get("status") != "success":
                return f"Error finding join path: {response.get('message', 'Unknown error')}"

            path = response.get("path", "No path found")
            return f"Join path: {path}"
        except Exception as e:
            # Tool boundary: report the failure as text, but keep the
            # traceback in the log rather than discarding it.
            logger.exception("Join path lookup failed")
            return f"Failed to find join path: {str(e)}"

    async def _arun(self, table_names: str) -> str:
        """Async version - just calls sync version.

        The synchronous ``_run`` is executed inline; nothing is offloaded to
        a worker thread.
        """
        return self._run(table_names)
| |
|
| |
|
class QueryExecutorTool(BaseTool):
    """LangChain tool for executing SQL queries."""

    name: str = "execute_query"
    description: str = """
    Execute a SQL query against the databases and return results.
    Use this after you have a valid SQL query.
    Input should be a valid SQL query string.
    """
    # Client used to reach the MCP server's query execution endpoint.
    mcp_client: MCPClient

    def _run(self, sql: str) -> str:
        """Execute query.

        Posts *sql* to ``intelligence/execute_query`` and renders the returned
        rows as a pipe-separated text table, showing at most the first ten.

        Args:
            sql: SQL statement, forwarded to the server unchanged.
                NOTE(review): no validation happens here, so any restriction
                on what may be executed has to be enforced server-side -
                confirm.

        Returns:
            The formatted rows, a notice when the query returned nothing, or
            an error message when execution failed.
        """
        preview_limit = 10

        try:
            response = self.mcp_client.post(
                "intelligence/execute_query",
                {"sql": sql},
            )

            if response.get("status") != "success":
                return f"Error executing query: {response.get('message', 'Unknown error')}"

            results = response.get("results", [])
            if not results:
                return "Query executed successfully but returned no results."

            # Column order is taken from the first row only.
            headers = list(results[0].keys())
            header_line = " | ".join(headers)

            lines = [
                f"Query returned {len(results)} rows:",
                header_line,
                "-" * len(header_line),
            ]
            lines.extend(
                " | ".join(str(row.get(column, "")) for column in headers)
                for row in results[:preview_limit]
            )

            hidden = len(results) - preview_limit
            if hidden > 0:
                lines.append(f"... and {hidden} more rows")

            # Real newlines (not a literal backslash-n) separate every line.
            return "\n".join(lines) + "\n"
        except Exception as e:
            # Tool boundary: report the failure as text, but keep the
            # traceback in the log rather than discarding it.
            logger.exception("Query execution failed")
            return f"Failed to execute query: {str(e)}"

    async def _arun(self, sql: str) -> str:
        """Async version - just calls sync version.

        The synchronous ``_run`` is executed inline; nothing is offloaded to
        a worker thread.
        """
        return self._run(sql)
| |
|