# agent-mcp-sql: mcp/core/intelligence.py
# Last commit 8595be6 by ohmygaugh: "Fixed one more bug during stress testing."
import logging
import re
from typing import List, Dict, Any

import sqlparse
from sqlalchemy import text

from .database import get_db_connections
# Module-level logger named after this module, so its level/handlers can be
# configured through the standard logging hierarchy.
logger = logging.getLogger(__name__)
def _get_database_for_table(table_name: str) -> str | None:
"""
Finds which database a table belongs to by checking the graph.
(This is a simplified helper; assumes GraphStore is accessible or passed)
"""
# This is a placeholder for the logic to find a table's database.
# In a real scenario, this would query Neo4j. We'll simulate it.
# A simple mapping for our known databases:
if table_name in ["studies", "patients", "adverse_events"]:
return "clinical_trials"
if table_name in ["lab_tests", "test_results", "biomarkers"]:
return "laboratory"
if table_name in ["compounds", "assay_results", "drug_targets", "compound_targets"]:
return "drug_discovery"
return None
async def execute_federated_query(sql: str) -> List[Dict[str, Any]]:
"""
Executes a SQL query against the correct SQLite database.
Strips database prefixes from table names (e.g., clinical_trials.patients → patients).
"""
parsed = sqlparse.parse(sql)[0]
target_table = None
# Find table name from FROM clause
from_found = False
for token in parsed.tokens:
if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'FROM':
from_found = True
continue
elif from_found and isinstance(token, sqlparse.sql.Identifier):
target_table = token.get_real_name()
break
elif from_found and token.is_group:
for sub_token in token.tokens:
if isinstance(sub_token, sqlparse.sql.Identifier):
target_table = sub_token.get_real_name()
break
if target_table:
break
if not target_table:
raise ValueError("Could not identify a target table in the SQL query.")
logger.info(f"Identified target table: {target_table}")
# Determine which database this table belongs to
db_name = _get_database_for_table(target_table)
if not db_name:
raise ValueError(f"Table '{target_table}' not found in any known database.")
# Strip all database prefixes from SQL (e.g., "clinical_trials.patients" → "patients")
for known_db in ["clinical_trials", "laboratory", "drug_discovery"]:
sql = sql.replace(f"{known_db}.", "")
logger.info(f"Cleaned SQL for database '{db_name}': {sql}")
db_engines = get_db_connections()
engine = db_engines.get(db_name)
if not engine:
raise ConnectionError(f"No active connection for database '{db_name}'.")
logger.info(f"Executing query against database: {db_name}")
try:
with engine.connect() as connection:
result = connection.execute(text(sql))
return [dict(row._mapping) for row in result.fetchall()]
except Exception as e:
logger.error(f"Failed to execute query on {db_name}: {e}", exc_info=True)
raise