File size: 3,095 Bytes
86cbe3c
 
 
 
 
9d411a7
86cbe3c
9d411a7
86cbe3c
9d411a7
86cbe3c
9d411a7
 
86cbe3c
9d411a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8595be6
9d411a7
 
 
 
8595be6
 
9d411a7
8595be6
 
 
 
9d411a7
 
8595be6
9d411a7
 
 
 
8595be6
 
9d411a7
 
 
 
 
 
8595be6
9d411a7
 
 
8595be6
 
 
 
 
 
86cbe3c
9d411a7
 
86cbe3c
9d411a7
 
86cbe3c
9d411a7
86cbe3c
9d411a7
86cbe3c
9d411a7
86cbe3c
9d411a7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import sqlparse
import logging
from typing import List, Dict, Any
from sqlalchemy import text

from .database import get_db_connections

logger = logging.getLogger(__name__)

def _get_database_for_table(table_name: str) -> str | None:
    """
    Finds which database a table belongs to by checking the graph.
    (This is a simplified helper; assumes GraphStore is accessible or passed)
    """
    # This is a placeholder for the logic to find a table's database.
    # In a real scenario, this would query Neo4j. We'll simulate it.
    # A simple mapping for our known databases:
    if table_name in ["studies", "patients", "adverse_events"]:
        return "clinical_trials"
    if table_name in ["lab_tests", "test_results", "biomarkers"]:
        return "laboratory"
    if table_name in ["compounds", "assay_results", "drug_targets", "compound_targets"]:
        return "drug_discovery"
    return None


# Database names that may appear as a "<database>." qualifier in incoming SQL.
_KNOWN_DATABASES = ("clinical_trials", "laboratory", "drug_discovery")


def _find_target_table(statement: "sqlparse.sql.Statement") -> str | None:
    """Return the real name of the first table following a top-level FROM, if any."""
    from_found = False
    for token in statement.tokens:
        if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'FROM':
            from_found = True
            continue
        if not from_found:
            continue
        if isinstance(token, sqlparse.sql.Identifier):
            # A single table reference directly after FROM.
            return token.get_real_name()
        if token.is_group:
            # Grouped tokens after FROM (e.g. a comma-separated list of tables):
            # use the first identifier found inside the group.
            first_identifier = next(
                (
                    sub_token
                    for sub_token in token.tokens
                    if isinstance(sub_token, sqlparse.sql.Identifier)
                ),
                None,
            )
            if first_identifier is not None:
                name = first_identifier.get_real_name()
                if name:
                    return name
    return None


def _strip_database_prefixes(statements) -> str:
    """Rebuild the SQL text, dropping "<known database>." qualifiers from identifiers."""
    pieces = []
    for statement in statements:
        leaves = list(statement.flatten())
        skip_dot = False
        for index, token in enumerate(leaves):
            if skip_dot:
                # This is the "." that followed a database qualifier just dropped.
                skip_dot = False
                continue
            follower = leaves[index + 1] if index + 1 < len(leaves) else None
            is_db_qualifier = (
                token.ttype in sqlparse.tokens.Name
                and token.value in _KNOWN_DATABASES
                and follower is not None
                and follower.ttype is sqlparse.tokens.Punctuation
                and follower.value == "."
            )
            if is_db_qualifier:
                skip_dot = True
                continue
            pieces.append(token.value)
    return "".join(pieces)


async def execute_federated_query(sql: str) -> List[Dict[str, Any]]:
    """
    Execute a SQL query against the database that owns its target table.

    The target table is the first table found after the FROM keyword of the
    first statement in ``sql``. Database qualifiers on identifiers are removed
    before execution (e.g., clinical_trials.patients → patients). Only
    identifier tokens are rewritten, so string literals, comments and
    identifiers that merely end with a database name are left untouched.

    Args:
        sql: The SQL text to run.

    Returns:
        One dict per result row, mapping column name to value.

    Raises:
        ValueError: If ``sql`` is empty, no target table can be identified,
            or the table does not belong to any known database.
        ConnectionError: If there is no engine for the resolved database.
        Exception: Any error raised while executing the query is logged and
            re-raised unchanged.
    """
    if not sql or not sql.strip():
        # sqlparse returns no statements for empty input; fail with a clear
        # message instead of an IndexError from indexing an empty tuple.
        raise ValueError("SQL query is empty.")

    statements = sqlparse.parse(sql)
    if not statements:
        raise ValueError("SQL query is empty.")

    target_table = _find_target_table(statements[0])
    if not target_table:
        raise ValueError("Could not identify a target table in the SQL query.")

    logger.info(f"Identified target table: {target_table}")

    # Determine which database this table belongs to
    db_name = _get_database_for_table(target_table)
    if not db_name:
        raise ValueError(f"Table '{target_table}' not found in any known database.")

    # Strip database prefixes from identifiers (e.g., "clinical_trials.patients" → "patients")
    sql = _strip_database_prefixes(statements)

    logger.info(f"Cleaned SQL for database '{db_name}': {sql}")

    db_engines = get_db_connections()
    engine = db_engines.get(db_name)

    if not engine:
        raise ConnectionError(f"No active connection for database '{db_name}'.")

    logger.info(f"Executing query against database: {db_name}")

    # NOTE(review): the query below runs synchronously inside this coroutine
    # (nothing is awaited), so the caller's event loop is blocked while it runs.
    try:
        with engine.connect() as connection:
            result = connection.execute(text(sql))
            return [dict(row._mapping) for row in result.fetchall()]
    except Exception as e:
        logger.error(f"Failed to execute query on {db_name}: {e}", exc_info=True)
        raise