Spaces:
No application file
No application file
| import sqlparse | |
| import logging | |
| from typing import List, Dict, Any | |
| from sqlalchemy import text | |
| from .database import get_db_connections | |
| logger = logging.getLogger(__name__) | |
| def _get_database_for_table(table_name: str) -> str | None: | |
| """ | |
| Finds which database a table belongs to by checking the graph. | |
| (This is a simplified helper; assumes GraphStore is accessible or passed) | |
| """ | |
| # This is a placeholder for the logic to find a table's database. | |
| # In a real scenario, this would query Neo4j. We'll simulate it. | |
| # A simple mapping for our known databases: | |
| if table_name in ["studies", "patients", "adverse_events"]: | |
| return "clinical_trials" | |
| if table_name in ["lab_tests", "test_results", "biomarkers"]: | |
| return "laboratory" | |
| if table_name in ["compounds", "assay_results", "drug_targets", "compound_targets"]: | |
| return "drug_discovery" | |
| return None | |
# Database names that may appear as schema qualifiers in incoming SQL
# (e.g. ``clinical_trials.patients``). Each is served by its own engine, so
# the qualifier has to be removed before the SQL is executed.
_KNOWN_DATABASES = ("clinical_trials", "laboratory", "drug_discovery")


def _extract_target_table(parsed: sqlparse.sql.Statement) -> str | None:
    """Return the real (unqualified) name of the first table after ``FROM``.

    Scans the top-level tokens of *parsed*. Once the ``FROM`` keyword has
    been seen, the table name comes from one of two places:

    * the first ``Identifier`` token, or
    * the first ``Identifier`` directly inside a grouped token (such as an
      ``IdentifierList``).

    Returns ``None`` when no such identifier yields a name.
    """
    from_found = False
    for token in parsed.tokens:
        if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == "FROM":
            from_found = True
            continue
        if not from_found:
            continue
        # Identifier is itself a group, so it must be tested before the generic case.
        if isinstance(token, sqlparse.sql.Identifier):
            return token.get_real_name()
        if token.is_group:
            first_identifier = next(
                (sub for sub in token.tokens if isinstance(sub, sqlparse.sql.Identifier)),
                None,
            )
            if first_identifier is not None:
                name = first_identifier.get_real_name()
                if name:
                    return name
    return None


def _strip_database_prefixes(sql: str) -> str:
    """Remove known ``<database>.`` qualifiers from *sql*.

    A qualifier is removed only when all of the following hold:

    * it is a whole word, so ``old_laboratory.value`` is left intact;
    * it is immediately followed by the start of an identifier (a letter,
      underscore, or opening quote/bracket);
    * case is ignored, as in SQLite.

    This is still a textual rewrite: a qualifier-shaped substring inside a
    string literal is stripped as well.
    """
    import re

    alternatives = "|".join(re.escape(name) for name in _KNOWN_DATABASES)
    pattern = rf'\b(?:{alternatives})\.(?=[^\W\d]|["`\[])'
    return re.sub(pattern, "", sql, flags=re.IGNORECASE)


async def execute_federated_query(sql: str) -> List[Dict[str, Any]]:
    """Execute *sql* against the SQLite database that owns its target table.

    The target table is the first table after ``FROM``, and its database is
    resolved with ``_get_database_for_table``. Database qualifiers are
    stripped before execution (e.g. ``clinical_trials.patients`` becomes
    ``patients``). Only that one database is queried, so joins spanning
    several databases are not supported.

    Args:
        sql: A single SQL statement.

    Returns:
        One dict per result row, keyed by column name.

    Raises:
        ValueError: If *sql* is empty, no target table can be identified, or
            the table is not mapped to a known database.
        ConnectionError: If no engine is registered for the resolved database.
        Exception: Errors raised while executing the query are logged and
            re-raised unchanged.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        # sqlparse.parse("") returns an empty tuple; indexing it would raise IndexError.
        raise ValueError("SQL query is empty.")

    target_table = _extract_target_table(statements[0])
    if not target_table:
        raise ValueError("Could not identify a target table in the SQL query.")
    logger.info("Identified target table: %s", target_table)

    db_name = _get_database_for_table(target_table)
    if not db_name:
        raise ValueError(f"Table '{target_table}' not found in any known database.")

    sql = _strip_database_prefixes(sql)
    logger.info("Cleaned SQL for database '%s': %s", db_name, sql)

    engine = get_db_connections().get(db_name)
    if engine is None:
        raise ConnectionError(f"No active connection for database '{db_name}'.")

    logger.info("Executing query against database: %s", db_name)
    # NOTE(review): this coroutine performs blocking DB I/O on the event-loop
    # thread. Consider offloading (e.g. asyncio.to_thread) once it is
    # confirmed that the engines' connection pools are safe to use from a
    # worker thread.
    try:
        with engine.connect() as connection:
            result = connection.execute(text(sql))
            return [dict(row._mapping) for row in result.fetchall()]
    except Exception as exc:
        logger.exception("Failed to execute query on %s: %s", db_name, exc)
        raise