| | import sqlparse |
| | import logging |
| | from typing import List, Dict, Any |
| | from sqlalchemy import text |
| |
|
| | from .database import get_db_connections |
| |
|
# Module-level logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(__name__)
| |
|
| | def _get_database_for_table(table_name: str) -> str | None: |
| | """ |
| | Finds which database a table belongs to by checking the graph. |
| | (This is a simplified helper; assumes GraphStore is accessible or passed) |
| | """ |
| | |
| | |
| | |
| | if table_name in ["studies", "patients", "adverse_events"]: |
| | return "clinical_trials" |
| | if table_name in ["lab_tests", "test_results", "biomarkers"]: |
| | return "laboratory" |
| | if table_name in ["compounds", "assay_results", "drug_targets", "compound_targets"]: |
| | return "drug_discovery" |
| | return None |
| |
|
| |
|
def _extract_target_table(sql: str) -> str:
    """
    Return the real (unqualified) name of the first table following ``FROM``.

    Only the first statement produced by ``sqlparse.parse`` is inspected: its
    top-level tokens, plus one level of grouped tokens below them.

    Raises:
        ValueError: if ``sql`` parses to no statements (e.g. an empty string)
            or no table identifier can be found after a ``FROM`` keyword.
    """
    not_found = "Could not identify a target table in the SQL query."
    statements = sqlparse.parse(sql)
    if not statements:
        # sqlparse yields no statements for empty input; indexing the result
        # directly would surface as IndexError instead of this ValueError.
        raise ValueError(not_found)

    from_found = False
    for token in statements[0].tokens:
        if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == "FROM":
            from_found = True
            continue
        if not from_found:
            continue
        if isinstance(token, sqlparse.sql.Identifier):
            table = token.get_real_name()
            if table:
                return table
            break
        if token.is_group:
            # e.g. a grouped identifier list: take its first identifier.
            for sub_token in token.tokens:
                if isinstance(sub_token, sqlparse.sql.Identifier):
                    table = sub_token.get_real_name()
                    if table:
                        return table
                    break
    raise ValueError(not_found)


def _strip_database_prefixes(sql: str) -> str:
    """
    Remove ``<database>.`` qualifiers for the known databases from ``sql``.

    Each database is queried through its own engine, so a qualified name such
    as ``clinical_trials.patients`` must be reduced to ``patients`` first.
    """
    import re

    known_databases = ("clinical_trials", "laboratory", "drug_discovery")
    # ``\b`` prevents matching inside a longer identifier: the previous plain
    # ``str.replace`` turned ``old_laboratory.x`` into ``old_x``.
    # IGNORECASE because SQLite identifiers are case-insensitive.
    # NOTE(review): qualifiers inside string literals are rewritten too.
    pattern = r"\b(?:" + "|".join(map(re.escape, known_databases)) + r")\."
    return re.sub(pattern, "", sql, flags=re.IGNORECASE)


async def execute_federated_query(sql: str) -> List[Dict[str, Any]]:
    """
    Execute ``sql`` against the SQLite database that owns its target table.

    The target table is the first table after ``FROM`` in the first parsed
    statement. Its owning database is resolved with
    ``_get_database_for_table``. Known database qualifiers are then stripped
    from the SQL (e.g. ``clinical_trials.patients`` -> ``patients``), and the
    query runs on that database's engine.

    Args:
        sql: SQL text, executed verbatim (no parameter binding) — callers must
            make sure it comes from a trusted source.

    Returns:
        One dict per result row, mapping column name to value.

    Raises:
        ValueError: if no target table can be identified, or the table does
            not belong to any known database.
        ConnectionError: if no engine is available for the resolved database.
        Exception: any error raised while executing the query is logged and
            re-raised unchanged.

    NOTE(review): ``engine.connect()`` and ``execute`` are synchronous calls,
    so the event loop is blocked while the query runs.
    NOTE(review): only the first ``FROM`` table picks the database; tables
    from other databases referenced elsewhere (e.g. in a JOIN) are not
    checked and will presumably fail at execution time.
    NOTE(review): the connection is never committed, so presumably only
    read queries are intended — confirm.
    """
    target_table = _extract_target_table(sql)
    logger.info("Identified target table: %s", target_table)

    db_name = _get_database_for_table(target_table)
    if not db_name:
        raise ValueError(f"Table '{target_table}' not found in any known database.")

    sql = _strip_database_prefixes(sql)
    logger.info("Cleaned SQL for database '%s': %s", db_name, sql)

    engine = get_db_connections().get(db_name)
    if not engine:
        raise ConnectionError(f"No active connection for database '{db_name}'.")

    logger.info("Executing query against database: %s", db_name)
    try:
        with engine.connect() as connection:
            result = connection.execute(text(sql))
            return [dict(row._mapping) for row in result.fetchall()]
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("Failed to execute query on %s: %s", db_name, e)
        raise
| |
|