agent-mcp-sql / agent /tools.py
ohmygaugh's picture
All major query types now work:
6422ca4
import os
import requests
import json
from typing import Dict, Any, List, Optional
from langchain.tools import BaseTool
from pydantic import Field
import logging
# Module-level logger named after this module; no handlers or levels are
# configured in this file, so output depends on the application's logging setup.
logger = logging.getLogger(__name__)
class MCPClient:
    """Client for making authenticated REST API calls to the MCP server.

    Every request is a JSON ``POST`` that carries the API key in the
    ``x-api-key`` header. :meth:`post` never raises for transport, HTTP or
    decoding failures. Instead it returns ``{"status": "error", "message": ...}``,
    so the calling tool can surface the problem to the agent as text.
    """

    def __init__(self, mcp_url: str, api_key: str, timeout: Optional[float] = None):
        """Create a client for the MCP server at ``mcp_url``.

        Args:
            mcp_url: Base URL of the MCP server. A trailing ``/`` is tolerated.
            api_key: Value sent in the ``x-api-key`` header of every request.
            timeout: Seconds to wait for the server, passed through to
                ``requests`` as ``timeout``. ``None`` (the default) waits
                indefinitely, as before. Setting a value stops a stalled
                server from hanging the agent.
        """
        self.mcp_url = mcp_url
        self.timeout = timeout
        self.headers = {
            "x-api-key": api_key,
            "Content-Type": "application/json",
        }

    def post(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Send a JSON POST request to a given MCP endpoint.

        Args:
            endpoint: Path relative to ``mcp_url``. A leading ``/`` is optional.
            data: JSON-serialisable request payload.

        Returns:
            The decoded JSON object on success. On any request, HTTP or
            decoding failure, ``{"status": "error", "message": <description>}``.
        """
        # Join with exactly one slash, whatever the caller passed on either side.
        url = f"{self.mcp_url.rstrip('/')}/{endpoint.lstrip('/')}"

        try:
            response = requests.post(
                url,
                headers=self.headers,
                data=json.dumps(data),
                timeout=self.timeout,
            )
        except requests.exceptions.RequestException as req_err:
            logger.error("Request error occurred: %s", req_err)
            return {"status": "error", "message": f"Request failed: {req_err}"}

        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as http_err:
            logger.error("HTTP error occurred: %s - %s", http_err, response.text)
            return {"status": "error", "message": f"HTTP error: {response.status_code} {response.reason}"}

        try:
            payload = response.json()
        except ValueError:
            # The decode error raised here may be json.JSONDecodeError,
            # simplejson's JSONDecodeError, or (requests >= 2.27)
            # requests.exceptions.JSONDecodeError. All of them subclass
            # ValueError. The last one is also a RequestException, which is why
            # decoding must be guarded separately from the transport call above.
            logger.error("Failed to decode JSON response.")
            return {"status": "error", "message": "Invalid JSON response from server."}

        if not isinstance(payload, dict):
            # Callers rely on dict access (`.get("status")`). Reject lists and scalars here.
            logger.error("Unexpected JSON payload type: %s", type(payload).__name__)
            return {"status": "error", "message": "Invalid JSON response from server."}
        return payload
class SchemaSearchTool(BaseTool):
    """LangChain tool for searching database schemas.

    Forwards the agent's natural-language query to the MCP
    ``discovery/get_relevant_schemas`` endpoint and renders each returned
    schema entry as a ``- database.table.column (type)`` line.
    """

    name: str = "schema_search"
    description: str = """
Search for relevant database schemas based on a natural language query.
Use this when you need to find which tables/columns are relevant to a user's question.
Input should be a descriptive query like 'patient information' or 'drug trials'.
"""
    mcp_client: MCPClient

    @staticmethod
    def _format_schema(schema: Dict[str, Any]) -> str:
        """Render one schema entry as ``- database.table.column (type)``."""
        col_type = schema.get("type", ["Unknown"])
        # The original code took element 0 of `type`, which suggests a list.
        # Also accept a bare string, rather than slicing it to its first
        # character, and an empty or None value, rather than raising.
        if isinstance(col_type, (list, tuple)):
            col_type = col_type[0] if col_type else "Unknown"
        elif not col_type:
            col_type = "Unknown"
        return (
            f"- {schema.get('database', 'Unknown')}."
            f"{schema.get('table', 'Unknown')}."
            f"{schema.get('name', 'Unknown')} ({col_type})"
        )

    def _run(self, query: str) -> str:
        """Execute schema search.

        Returns a newline-separated listing of the matching schemas,
        ``"No relevant schemas found."`` when there are none, or an error
        description. Errors are reported as text, never raised, matching the
        other tools in this module.
        """
        try:
            response = self.mcp_client.post("discovery/get_relevant_schemas", {"query": query})
            if response.get("status") != "success":
                return f"Error searching schemas: {response.get('message', 'Unknown error')}"
            schemas = response.get("schemas", [])
            if not schemas:
                return "No relevant schemas found."
            lines = ["Found relevant schemas:"]
            lines.extend(self._format_schema(schema) for schema in schemas)
            # Use real newlines; the previous "\\n" emitted a literal backslash-n.
            return "\n".join(lines) + "\n"
        except Exception as e:
            logger.exception("Schema search failed")
            return f"Failed to search schemas: {str(e)}"

    async def _arun(self, query: str) -> str:
        """Async version - just calls sync version.

        Note: this runs the blocking HTTP call on the event loop thread.
        """
        return self._run(query)
class JoinPathFinderTool(BaseTool):
    """LangChain tool for finding join paths between tables.

    Parses two comma-separated table names from the agent's input and asks
    the MCP ``graph/find_join_path`` endpoint how to join them.
    """

    name: str = "find_join_path"
    description: str = """
Find how to join two tables together using foreign key relationships.
Use this when you need to query across multiple tables.
Input should be two table names separated by a comma, like 'patients,studies'.
"""
    mcp_client: MCPClient

    @staticmethod
    def _parse_table_names(table_names: str) -> List[str]:
        """Split ``table_names`` on commas into cleaned, non-empty names."""
        # LLM agents often wrap tool input in quotes ('patients,studies' or
        # "patients, studies"). Strip those along with whitespace, and drop
        # empty pieces so an input like "patients," is not accepted with a
        # blank second table.
        cleaned = (part.strip().strip("'\"`").strip() for part in table_names.split(","))
        return [name for name in cleaned if name]

    def _run(self, table_names: str) -> str:
        """Find join path.

        Args:
            table_names: Two table names separated by a comma.

        Returns:
            ``"Join path: ..."`` on success, or a usage or error message.
            Never raises.
        """
        try:
            tables = self._parse_table_names(table_names)
            if len(tables) != 2:
                return "Please provide exactly two table names separated by a comma."
            response = self.mcp_client.post(
                "graph/find_join_path",
                {"table1": tables[0], "table2": tables[1]},
            )
            if response.get("status") != "success":
                return f"Error finding join path: {response.get('message', 'Unknown error')}"
            # Treat a null or empty path the same as a missing one,
            # instead of rendering "Join path: None".
            path = response.get("path") or "No path found"
            return f"Join path: {path}"
        except Exception as e:
            logger.exception("Join path lookup failed")
            return f"Failed to find join path: {str(e)}"

    async def _arun(self, table_names: str) -> str:
        """Async version - just calls sync version.

        Note: this runs the blocking HTTP call on the event loop thread.
        """
        return self._run(table_names)
class QueryExecutorTool(BaseTool):
    """LangChain tool for executing SQL queries.

    Sends the SQL string to the MCP ``intelligence/execute_query`` endpoint
    and renders the returned rows as a pipe-separated text table. Only the
    first ``max_display_rows`` rows are shown; a summary line counts the rest.
    """

    name: str = "execute_query"
    description: str = """
Execute a SQL query against the databases and return results.
Use this after you have a valid SQL query.
Input should be a valid SQL query string.
"""
    mcp_client: MCPClient
    # Maximum number of result rows rendered for the LLM (previously hard-coded to 10).
    max_display_rows: int = 10

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Render result rows as a text table.

        Columns are taken from the first row's keys. Presumably each row is a
        JSON object (dict); verify this against the MCP server.
        """
        headers = list(results[0].keys())
        header_line = " | ".join(headers)
        lines = [
            f"Query returned {len(results)} rows:",
            header_line,
            "-" * len(header_line),
        ]
        limit = max(self.max_display_rows, 0)
        for row in results[:limit]:
            lines.append(" | ".join(str(row.get(h, "")) for h in headers))
        hidden = len(results) - limit
        if hidden > 0:
            lines.append(f"... and {hidden} more rows")
        # Use real newlines throughout; the summary line previously used "\\n",
        # which emitted a literal backslash-n and put the header on the same line.
        return "\n".join(lines) + "\n"

    def _run(self, sql: str) -> str:
        """Execute query.

        Returns a formatted result table, a no-results notice, or an error
        description. Never raises.
        """
        try:
            response = self.mcp_client.post(
                "intelligence/execute_query",
                {"sql": sql},
            )
            if response.get("status") != "success":
                return f"Error executing query: {response.get('message', 'Unknown error')}"
            results = response.get("results", [])
            if not results:
                return "Query executed successfully but returned no results."
            return self._format_results(results)
        except Exception as e:
            logger.exception("Query execution failed")
            return f"Failed to execute query: {str(e)}"

    async def _arun(self, sql: str) -> str:
        """Async version - just calls sync version.

        Note: this runs the blocking HTTP call on the event loop thread.
        """
        return self._run(sql)