Spaces:
No application file
No application file
| import os | |
| import requests | |
| import json | |
| from typing import Dict, Any, List, Optional | |
| from langchain.tools import BaseTool | |
| from pydantic import Field | |
| import logging | |
# Module-scoped logger, named after this module so handlers/levels can be configured per module.
logger = logging.getLogger(name=__name__)
class MCPClient:
    """Client for making authenticated REST API calls to the MCP server.

    Every request is a JSON ``POST`` that carries the API key in the
    ``x-api-key`` header. HTTP errors, transport errors and undecodable or
    non-object responses are not raised to the caller. Instead,
    :meth:`post` turns them into ``{"status": "error", "message": ...}`` so
    tool code can handle every outcome the same way. Errors while *encoding*
    the request body (e.g. a non-JSON-serialisable value) still propagate,
    because they are caller bugs.
    """

    def __init__(self, mcp_url: str, api_key: str, timeout: Optional[float] = 60.0):
        """Create a client.

        Args:
            mcp_url: Base URL of the MCP server. A trailing ``/`` is removed so
                endpoint paths can be appended without producing ``//``.
            api_key: Value sent in the ``x-api-key`` header.
            timeout: Seconds ``requests`` waits for the server before giving
                up. ``None`` waits forever, which was the previous behaviour.
        """
        self.mcp_url = mcp_url.rstrip("/")
        self.timeout = timeout
        self.headers = {
            "x-api-key": api_key,
            "Content-Type": "application/json",
        }

    def _url(self, endpoint: str) -> str:
        """Join the base URL and *endpoint* with exactly one ``/`` between them."""
        return f"{self.mcp_url}/{endpoint.lstrip('/')}"

    def post(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Send a POST request to a given MCP endpoint.

        Args:
            endpoint: Path relative to the base URL, e.g. ``"graph/find_join_path"``.
            data: JSON-serialisable request body.

        Returns:
            The decoded JSON object on success. Otherwise an error dict
            ``{"status": "error", "message": <description>}``.
        """
        url = self._url(endpoint)
        # Serialise outside the try block: encoding failures are caller bugs
        # and keep propagating exactly as before.
        body = json.dumps(data)
        try:
            response = requests.post(url, headers=self.headers, data=body, timeout=self.timeout)
            response.raise_for_status()
        except requests.exceptions.HTTPError as http_err:
            # HTTPError only comes from raise_for_status(), so `response` is bound here.
            logger.error("HTTP error occurred: %s - %s", http_err, response.text)
            return {"status": "error", "message": f"HTTP error: {response.status_code} {response.reason}"}
        except requests.exceptions.RequestException as req_err:
            # Connection failures, timeouts, invalid URLs, ...
            logger.error("Request error occurred: %s", req_err)
            return {"status": "error", "message": f"Request failed: {req_err}"}

        try:
            payload = response.json()
        except ValueError:
            # Decoding gets its own try block. Since requests 2.27 its
            # JSONDecodeError also subclasses RequestException, so in the block
            # above it would have been misreported as "Request failed".
            # ValueError covers both json's and simplejson's decode errors.
            logger.error("Failed to decode JSON response.")
            return {"status": "error", "message": "Invalid JSON response from server."}

        if not isinstance(payload, dict):
            # Callers use .get(...) on the result, so a JSON list/scalar is unusable.
            logger.error("Expected a JSON object from %s, got %s.", url, type(payload).__name__)
            return {"status": "error", "message": "Invalid JSON response from server."}
        return payload
class SchemaSearchTool(BaseTool):
    """LangChain tool for searching database schemas.

    Sends the agent's natural-language query to the MCP
    ``discovery/get_relevant_schemas`` endpoint. The matching columns are
    rendered as a newline-separated bullet list of
    ``database.table.column (type)``.
    """

    name: str = "schema_search"
    description: str = """
    Search for relevant database schemas based on a natural language query.
    Use this when you need to find which tables/columns are relevant to a user's question.
    Input should be a descriptive query like 'patient information' or 'drug trials'.
    """
    mcp_client: MCPClient

    @staticmethod
    def _format_type(raw_type: Any) -> str:
        """Return a display string for a schema entry's ``type`` value.

        For a list/tuple, the first element is shown, as before. A plain
        string is shown whole rather than being cut to its first character.
        A missing, ``None``, empty-string or empty-list value becomes
        ``'Unknown'`` instead of raising.
        """
        if isinstance(raw_type, (list, tuple)):
            return str(raw_type[0]) if raw_type else "Unknown"
        if raw_type is None or raw_type == "":
            return "Unknown"
        return str(raw_type)

    def _run(self, query: str) -> str:
        """Execute schema search.

        Returns a bullet list of matches, ``"No relevant schemas found."``,
        or an error message when the server's status is not ``"success"``.
        """
        response = self.mcp_client.post("discovery/get_relevant_schemas", {"query": query})
        if response.get("status") != "success":
            return f"Error searching schemas: {response.get('message', 'Unknown error')}"

        schemas = response.get("schemas", [])
        if not schemas:
            return "No relevant schemas found."

        lines = ["Found relevant schemas:"]
        for schema in schemas:
            lines.append(
                f"- {schema.get('database', 'Unknown')}.{schema.get('table', 'Unknown')}."
                f"{schema.get('name', 'Unknown')} ({self._format_type(schema.get('type'))})"
            )
        # Real newlines; the previous "\\n" emitted a literal backslash-n.
        return "\n".join(lines) + "\n"

    async def _arun(self, query: str) -> str:
        """Async version - runs the sync version in the default executor.

        ``_run`` performs a blocking HTTP request. Calling it directly from a
        coroutine would stall the event loop for the whole round trip.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, query)
class JoinPathFinderTool(BaseTool):
    """LangChain tool for finding join paths between tables.

    Parses ``"table_a,table_b"`` and asks the MCP ``graph/find_join_path``
    endpoint how the two tables can be joined.
    """

    name: str = "find_join_path"
    description: str = """
    Find how to join two tables together using foreign key relationships.
    Use this when you need to query across multiple tables.
    Input should be two table names separated by a comma, like 'patients,studies'.
    """
    mcp_client: MCPClient

    def _run(self, table_names: str) -> str:
        """Find join path.

        Returns ``"Join path: ..."`` on success. Otherwise it returns an
        explanatory message; unexpected exceptions are logged and reported
        as text rather than raised, so the agent can recover.
        """
        try:
            # Strip surrounding quotes as well as whitespace: the description's
            # example is quoted, and agents often copy the quotes into the input.
            tables = [t.strip().strip("'\"").strip() for t in table_names.split(",")]
            # `not all(tables)` rejects inputs such as "patients," or ",studies".
            if len(tables) != 2 or not all(tables):
                return "Please provide exactly two table names separated by a comma."

            response = self.mcp_client.post(
                "graph/find_join_path",
                {"table1": tables[0], "table2": tables[1]},
            )
            if response.get("status") == "success":
                return f"Join path: {response.get('path', 'No path found')}"
            return f"Error finding join path: {response.get('message', 'Unknown error')}"
        except Exception as e:
            # Tool boundary: report to the agent, but keep a traceback for operators.
            logger.exception("Join path lookup failed for input %r", table_names)
            return f"Failed to find join path: {str(e)}"

    async def _arun(self, table_names: str) -> str:
        """Async version - runs the sync version in the default executor.

        ``_run`` performs a blocking HTTP request. Calling it directly from a
        coroutine would stall the event loop for the whole round trip.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, table_names)
class QueryExecutorTool(BaseTool):
    """LangChain tool for executing SQL queries.

    Forwards the SQL string unchanged to the MCP
    ``intelligence/execute_query`` endpoint and renders the returned rows as
    a pipe-separated text table. This tool does no validation of the SQL
    itself.
    NOTE(review): confirm the server enforces read-only/permission rules,
    since the SQL comes from an LLM.
    """

    name: str = "execute_query"
    description: str = """
    Execute a SQL query against the databases and return results.
    Use this after you have a valid SQL query.
    Input should be a valid SQL query string.
    """
    mcp_client: MCPClient
    # Maximum number of rows rendered back to the agent. The total row count
    # is always reported. Values below 0 are treated as 0.
    max_display_rows: int = 10

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Render a non-empty list of row dicts as a text table.

        Column headers come from the first row's keys. Cells missing from a
        later row render as an empty string.
        """
        headers = list(results[0].keys())
        header_line = " | ".join(headers)
        limit = max(self.max_display_rows, 0)

        lines = [f"Query returned {len(results)} rows:", header_line, "-" * len(header_line)]
        lines.extend(" | ".join(str(row.get(h, "")) for h in headers) for row in results[:limit])
        if len(results) > limit:
            lines.append(f"... and {len(results) - limit} more rows")
        # Real newlines throughout; the row-count line previously ended in a literal "\\n".
        return "\n".join(lines) + "\n"

    def _run(self, sql: str) -> str:
        """Execute query.

        Returns the formatted result table, a "no results" message, or an
        error message. Unexpected exceptions are logged and reported as text
        rather than raised.
        """
        try:
            response = self.mcp_client.post(
                "intelligence/execute_query",
                {"sql": sql},
            )
            if response.get("status") != "success":
                return f"Error executing query: {response.get('message', 'Unknown error')}"

            results = response.get("results", [])
            if not results:
                return "Query executed successfully but returned no results."
            return self._format_results(results)
        except Exception as e:
            # Tool boundary: report to the agent, but keep a traceback for operators.
            logger.exception("Query execution failed")
            return f"Failed to execute query: {str(e)}"

    async def _arun(self, sql: str) -> str:
        """Async version - runs the sync version in the default executor.

        ``_run`` performs a blocking HTTP request. Calling it directly from a
        coroutine would stall the event loop for the whole round trip.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, sql)