File size: 6,171 Bytes
86cbe3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0eb181
86cbe3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0eb181
 
86cbe3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0eb181
 
86cbe3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6422ca4
 
86cbe3c
 
 
6422ca4
86cbe3c
 
6422ca4
86cbe3c
 
 
 
 
 
 
 
 
 
a0eb181
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import requests
import json
from typing import Dict, Any, List, Optional
from langchain.tools import BaseTool
from pydantic import Field
import logging

logger = logging.getLogger(__name__)

class MCPClient:
    """Client for making authenticated REST API calls to the MCP server.

    Every request is a JSON ``POST`` carrying the API key in the ``x-api-key``
    header. :meth:`post` does not raise for network, HTTP-status or response
    decoding problems; it logs them and returns an error dictionary of the
    form ``{"status": "error", "message": ...}`` instead.
    """

    def __init__(self, mcp_url: str, api_key: str, timeout: Optional[float] = 30.0):
        """Create a client bound to one MCP server.

        Args:
            mcp_url: Base URL of the MCP server. Trailing slashes are removed so
                that joining it with an endpoint never produces ``//``.
            api_key: Secret sent in the ``x-api-key`` header of every request.
            timeout: Seconds ``requests`` waits to connect and for each read
                before giving up. ``None`` waits indefinitely, which was the
                behaviour before this parameter existed.
        """
        self.mcp_url = mcp_url.rstrip("/")
        self.timeout = timeout
        self.headers = {
            "x-api-key": api_key,
            "Content-Type": "application/json",
        }

    def post(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Send a POST request to a given MCP endpoint.

        Args:
            endpoint: Path relative to ``mcp_url``, e.g. ``"graph/find_join_path"``.
            data: JSON-serializable request body.

        Returns:
            The decoded JSON response on success; otherwise an error dictionary
            ``{"status": "error", "message": <description>}``.

        Raises:
            TypeError, ValueError: if ``data`` cannot be serialized to JSON.
                That is a caller bug, so it is not converted into an error dict.
        """
        url = f"{self.mcp_url}/{endpoint.lstrip('/')}"
        # Serialize before the try so a bad payload propagates as before instead
        # of being mistaken for a transport or server problem.
        payload = json.dumps(data)
        try:
            response = requests.post(
                url, headers=self.headers, data=payload, timeout=self.timeout
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as http_err:
            # Only raise_for_status() raises HTTPError, so `response` is bound here.
            logger.error("HTTP error occurred: %s - %s", http_err, response.text)
            return {"status": "error", "message": f"HTTP error: {response.status_code} {response.reason}"}
        except requests.exceptions.RequestException as req_err:
            # Connection failures, timeouts, invalid URLs, redirect loops, ...
            logger.error("Request error occurred: %s", req_err)
            return {"status": "error", "message": f"Request failed: {req_err}"}

        # Decode in a separate try. Since requests 2.27, Response.json() raises
        # requests.exceptions.JSONDecodeError, which subclasses both
        # json.JSONDecodeError (a ValueError) and RequestException. Inside the
        # block above it would be reported as "Request failed", and this
        # dedicated message would never be produced.
        try:
            return response.json()
        except ValueError:
            logger.error("Failed to decode JSON response.")
            return {"status": "error", "message": "Invalid JSON response from server."}


class SchemaSearchTool(BaseTool):
    """LangChain tool for searching database schemas.

    Forwards the agent's natural-language query to the MCP
    ``discovery/get_relevant_schemas`` endpoint. Each returned entry is rendered
    as one ``- database.table.name (type)`` line.
    """

    name: str = "schema_search"
    description: str = """
    Search for relevant database schemas based on a natural language query.
    Use this when you need to find which tables/columns are relevant to a user's question.
    Input should be a descriptive query like 'patient information' or 'drug trials'.
    """
    mcp_client: MCPClient

    @staticmethod
    def _format_type(raw_type: Any) -> str:
        """Return a display string for a schema entry's ``type`` value."""
        # Indexing [0] unconditionally truncated a plain string to its first
        # character and crashed on an empty list or None.
        if isinstance(raw_type, (list, tuple)):
            return str(raw_type[0]) if raw_type else "Unknown"
        if raw_type is None:
            return "Unknown"
        return str(raw_type)

    def _run(self, query: str) -> str:
        """Execute schema search.

        Args:
            query: Natural-language description of the data being looked for.

        Returns:
            A newline-separated listing of matching schemas, a "no results"
            message, or an error message. Errors are returned as text rather
            than raised, so the agent can read them.
        """
        try:
            response = self.mcp_client.post("discovery/get_relevant_schemas", {"query": query})

            if response.get("status") != "success":
                return f"Error searching schemas: {response.get('message', 'Unknown error')}"

            schemas = response.get("schemas", [])
            if not schemas:
                return "No relevant schemas found."

            # Real newlines: the previous "\\n" emitted a literal backslash-n.
            entries = "".join(
                f"- {schema.get('database', 'Unknown')}"
                f".{schema.get('table', 'Unknown')}"
                f".{schema.get('name', 'Unknown')}"
                f" ({self._format_type(schema.get('type'))})\n"
                for schema in schemas
            )
            return "Found relevant schemas:\n" + entries
        except Exception as e:
            # Tool boundary: report the failure to the agent instead of
            # crashing the agent loop; same policy as the other tools in this module.
            logger.exception("Schema search failed")
            return f"Failed to search schemas: {e}"

    async def _arun(self, query: str) -> str:
        """Async version - just calls sync version (blocks the event loop)."""
        return self._run(query)


class JoinPathFinderTool(BaseTool):
    """LangChain tool for finding join paths between tables.

    Sends a pair of table names to the MCP ``graph/find_join_path`` endpoint
    and returns the path it reports as text.
    """

    name: str = "find_join_path"
    description: str = """
    Find how to join two tables together using foreign key relationships.
    Use this when you need to query across multiple tables.
    Input should be two table names separated by a comma, like 'patients,studies'.
    """
    mcp_client: MCPClient

    def _run(self, table_names: str) -> str:
        """Find join path.

        Args:
            table_names: Two table names separated by a comma, e.g.
                ``"patients,studies"``.

        Returns:
            ``"Join path: ..."`` on success, or a usage/error message. Errors
            are returned as text rather than raised, so the agent can react.
        """
        try:
            # Strip whitespace and any surrounding quotes. Agents often copy
            # the quoted example from the description verbatim
            # ("'patients,studies'").
            tables = [t.strip().strip("'\"").strip() for t in table_names.split(',')]
            # Also reject blank names such as "patients," or ",studies", which
            # previously passed the length check.
            if len(tables) != 2 or not all(tables):
                return "Please provide exactly two table names separated by a comma."
            table1, table2 = tables

            response = self.mcp_client.post(
                "graph/find_join_path",
                {"table1": table1, "table2": table2},
            )

            if response.get("status") == "success":
                return f"Join path: {response.get('path', 'No path found')}"
            return f"Error finding join path: {response.get('message', 'Unknown error')}"
        except Exception as e:
            # Tool boundary: surface the failure to the agent, but keep a traceback.
            logger.exception("Join path lookup failed")
            return f"Failed to find join path: {e}"

    async def _arun(self, table_names: str) -> str:
        """Async version - just calls sync version (blocks the event loop)."""
        return self._run(table_names)


class QueryExecutorTool(BaseTool):
    """LangChain tool for executing SQL queries.

    Sends SQL to the MCP ``intelligence/execute_query`` endpoint. It renders
    at most ``max_display_rows`` result rows as a pipe-separated text table.
    """

    name: str = "execute_query"
    description: str = """
    Execute a SQL query against the databases and return results.
    Use this after you have a valid SQL query.
    Input should be a valid SQL query string.
    """
    mcp_client: MCPClient
    # Maximum number of rows rendered into the tool output; any remainder is
    # summarized as "... and N more rows". Keeps the agent's context small.
    max_display_rows: int = 10

    def _render_rows(self, results: List[Dict[str, Any]]) -> str:
        """Format non-empty result rows as a text table with a row-count header."""
        # Column order comes from the first row; rows missing a column print "".
        headers = list(results[0].keys())
        header_line = " | ".join(headers)
        limit = max(self.max_display_rows, 0)

        # Real newlines throughout: the count line previously used "\\n",
        # which emitted a literal backslash-n unlike every other line.
        lines = [
            f"Query returned {len(results)} rows:",
            header_line,
            "-" * len(header_line),
        ]
        lines.extend(
            " | ".join(str(row.get(h, "")) for h in headers)
            for row in results[:limit]
        )
        hidden = len(results) - limit
        if hidden > 0:
            lines.append(f"... and {hidden} more rows")
        return "\n".join(lines) + "\n"

    def _run(self, sql: str) -> str:
        """Execute query.

        Args:
            sql: SQL statement, passed to the MCP server unchanged.

        Returns:
            A text table of (at most ``max_display_rows``) result rows, a "no
            results" message, or an error message. Errors are returned as text
            rather than raised, so the agent can react.
        """
        try:
            response = self.mcp_client.post(
                "intelligence/execute_query",
                {"sql": sql}
            )

            if response.get("status") != "success":
                return f"Error executing query: {response.get('message', 'Unknown error')}"

            results = response.get("results", [])
            if not results:
                return "Query executed successfully but returned no results."
            return self._render_rows(results)
        except Exception as e:
            # Tool boundary: surface the failure to the agent, but keep a traceback.
            logger.exception("Query execution failed")
            return f"Failed to execute query: {e}"

    async def _arun(self, sql: str) -> str:
        """Async version - just calls sync version (blocks the event loop)."""
        return self._run(sql)