import logging
import re
from typing import Any, Dict, List

import sqlparse
from sqlalchemy import text

from .database import get_db_connections

logger = logging.getLogger(__name__)

# Static routing table: bare table name -> name of the database that owns it.
# This is a placeholder for a real catalogue lookup (the original note says
# that in a real scenario this would query Neo4j).
_TABLE_TO_DATABASE: Dict[str, str] = {
    "studies": "clinical_trials",
    "patients": "clinical_trials",
    "adverse_events": "clinical_trials",
    "lab_tests": "laboratory",
    "test_results": "laboratory",
    "biomarkers": "laboratory",
    "compounds": "drug_discovery",
    "assay_results": "drug_discovery",
    "drug_targets": "drug_discovery",
    "compound_targets": "drug_discovery",
}

# Unique database names, derived from the routing table so that the two can
# never drift apart.
_KNOWN_DATABASES = tuple(dict.fromkeys(_TABLE_TO_DATABASE.values()))

# Matches a known database name used as a qualifier ("clinical_trials.").
# The leading \b stops the pattern from firing inside a longer identifier
# such as "my_laboratory.x", which a plain str.replace would corrupt.
_DB_PREFIX_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(name) for name in _KNOWN_DATABASES) + r")\.",
    re.IGNORECASE,
)


def _get_database_for_table(table_name: str) -> str | None:
    """Return the name of the database that owns *table_name*.

    The lookup uses the static ``_TABLE_TO_DATABASE`` map and is
    case-insensitive.

    Args:
        table_name: Bare (unqualified) table name.

    Returns:
        The owning database name, or ``None`` if the table is unknown.
    """
    return _TABLE_TO_DATABASE.get(table_name.lower())


def _find_target_table(statement) -> str | None:
    """Return the real name of the first table identifier after ``FROM``.

    Only the top-level tokens of *statement* are scanned, plus the direct
    children of any group token that follows ``FROM`` (for example a
    comma-separated identifier list).

    Args:
        statement: A parsed ``sqlparse`` statement.

    Returns:
        The table name with any qualifier removed, or ``None`` if no table
        identifier was found.
    """
    from_seen = False
    for token in statement.tokens:
        if not from_seen:
            from_seen = (
                token.ttype is sqlparse.tokens.Keyword
                and token.value.upper() == "FROM"
            )
            continue

        if isinstance(token, sqlparse.sql.Identifier):
            return token.get_real_name()

        if token.is_group:
            # Only the first identifier inside the group is considered.
            for child in token.tokens:
                if isinstance(child, sqlparse.sql.Identifier):
                    name = child.get_real_name()
                    if name:
                        return name
                    break
    return None


async def execute_federated_query(sql: str) -> List[Dict[str, Any]]:
    """Execute a SQL query against the database that owns its target table.

    The first table in the ``FROM`` clause of the first statement decides
    which database is used. Known database prefixes are stripped from the
    SQL before execution (e.g., clinical_trials.patients → patients).

    NOTE: the SQL text is executed as given (after prefix stripping), so
    callers are responsible for passing trusted SQL.
    NOTE(review): the database call below is synchronous and therefore
    blocks inside this coroutine — confirm that this is acceptable for
    the callers.

    Args:
        sql: The SQL query text.

    Returns:
        One dict per result row, mapping column name to value.

    Raises:
        ValueError: If the SQL is empty, no target table can be identified,
            or the table does not belong to any known database.
        ConnectionError: If there is no engine for the resolved database.
        Exception: Any error raised during execution is logged and re-raised.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        # Empty / whitespace-only input parses to nothing; report it the same
        # way as any other query with no recognisable table instead of
        # leaking an IndexError.
        raise ValueError("Could not identify a target table in the SQL query.")

    target_table = _find_target_table(statements[0])
    if not target_table:
        raise ValueError("Could not identify a target table in the SQL query.")
    logger.info("Identified target table: %s", target_table)

    # Determine which database this table belongs to.
    db_name = _get_database_for_table(target_table)
    if not db_name:
        raise ValueError(
            f"Table '{target_table}' not found in any known database."
        )

    # Strip all database prefixes ("clinical_trials.patients" → "patients").
    # This is a textual rewrite: a known prefix that appears inside a string
    # literal is stripped as well.
    sql = _DB_PREFIX_RE.sub("", sql)
    logger.info("Cleaned SQL for database '%s': %s", db_name, sql)

    engine = get_db_connections().get(db_name)
    if not engine:
        raise ConnectionError(f"No active connection for database '{db_name}'.")

    logger.info("Executing query against database: %s", db_name)
    try:
        with engine.connect() as connection:
            result = connection.execute(text(sql))
            return [dict(row._mapping) for row in result.fetchall()]
    except Exception as e:
        # Log with traceback at this boundary, then let the caller decide.
        logger.exception("Failed to execute query on %s: %s", db_name, e)
        raise