File size: 3,827 Bytes
9d411a7
86cbe3c
 
 
9d411a7
a0eb181
86cbe3c
9d411a7
86cbe3c
9d411a7
86cbe3c
9d411a7
 
86cbe3c
 
9d411a7
86cbe3c
 
 
 
9d411a7
 
 
 
 
 
86cbe3c
a0eb181
9d411a7
a0eb181
9d411a7
 
 
a0eb181
86cbe3c
9d411a7
a0eb181
 
 
 
 
 
 
 
9d411a7
a0eb181
 
9d411a7
a0eb181
 
 
 
 
 
 
 
 
 
9d411a7
 
a0eb181
 
 
 
 
 
 
 
 
 
 
 
9d411a7
a0eb181
9d411a7
 
 
 
a0eb181
9d411a7
a0eb181
9d411a7
 
 
a0eb181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from sqlalchemy import inspect
from sqlalchemy.engine import Engine
from typing import Dict, Any, List
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio

from .database import get_db_connections

logger = logging.getLogger(__name__)

def _discover_single_db_schema(db_name: str, engine: Engine) -> Dict[str, Any]:
    """Describe the tables and columns of one database.

    This is a synchronous function; ``discover_all_schemas`` runs it on a
    worker thread.

    Args:
        db_name: Label stored under ``"database_name"`` in the result.
        engine: SQLAlchemy engine to introspect.

    Returns:
        ``{"database_name": db_name, "tables": [...]}`` where each table entry
        is ``{"name": <table name>, "columns": [{"name": ..., "type": ...}]}``
        and each column type is rendered with ``str()``. Tables are those
        returned by ``inspector.get_table_names()`` called with no schema
        argument.
    """
    inspector = inspect(engine)

    def describe(table: str) -> Dict[str, Any]:
        # One table entry, with every column's type stringified.
        return {
            "name": table,
            "columns": [
                {"name": column["name"], "type": str(column["type"])}
                for column in inspector.get_columns(table)
            ],
        }

    return {
        "database_name": db_name,
        "tables": [describe(table) for table in inspector.get_table_names()],
    }

async def discover_all_schemas() -> List[Dict[str, Any]]:
    """
    Discover the schema of every connected database in parallel.

    Each engine from ``get_db_connections()`` is introspected by the
    synchronous ``_discover_single_db_schema`` on a worker thread, so the
    event loop is not blocked while the databases respond.

    Returns:
        One schema dict per database whose discovery succeeded, in the same
        order that ``get_db_connections()`` lists them. A database whose
        discovery raises an ``Exception`` is logged with its name and
        traceback and then omitted. One bad database therefore does not fail
        the whole call.
    """
    db_engines = get_db_connections()
    named_engines = list(db_engines.items())
    loop = asyncio.get_running_loop()

    executor = ThreadPoolExecutor()
    try:
        futures = [
            loop.run_in_executor(executor, _discover_single_db_schema, name, eng)
            for name, eng in named_engines
        ]
        # gather() keeps results aligned with ``named_engines``. as_completed()
        # did not, which made the output order nondeterministic and meant a
        # failure could not be attributed to its database.
        results = await asyncio.gather(*futures, return_exceptions=True)
    finally:
        # On cancellation, do not block the event loop waiting for in-flight
        # worker threads. On the normal path all work has already finished.
        executor.shutdown(wait=False)

    all_schemas = []
    for (name, _), result in zip(named_engines, results):
        if isinstance(result, Exception):
            logger.error(
                "Schema discovery for database %r failed: %s",
                name,
                result,
                exc_info=result,
            )
        elif isinstance(result, BaseException):
            # Only ordinary exceptions are swallowed. Things like
            # CancelledError still propagate, as they did before.
            raise result
        else:
            all_schemas.append(result)
    return all_schemas

def _column_rows(db_schema: Dict[str, Any], table: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten one table of a discovered schema into per-column result rows."""
    return [
        {
            "database": db_schema["database_name"],
            "table": table["name"],
            "name": col["name"],
            # Wrapped in a one-element list, as before. Presumably this is the
            # shape the UI expects; confirm before changing it.
            "type": [col["type"]],
        }
        for col in table.get("columns", [])
    ]

def _table_matches(table: Dict[str, Any], keywords: List[str]) -> bool:
    """Return True if any lowercase keyword occurs in the table name or any column name."""
    if any(keyword in table["name"].lower() for keyword in keywords):
        return True
    return any(
        keyword in col["name"].lower()
        for col in table.get("columns", [])
        for keyword in keywords
    )

async def get_relevant_schemas(query: str) -> List[Dict[str, Any]]:
    """
    Return column rows for the tables that match a keyword query.

    Schemas are rediscovered via ``discover_all_schemas()`` on every call.
    The query is lowercased and split on whitespace. A table matches when any
    keyword is a substring of its name or of one of its column names, also
    compared in lowercase. When a table matches, every one of its columns is
    returned.

    If ``query`` is empty, ``None`` or whitespace only, there are no keywords
    to filter on, and every column of every table is returned.

    Args:
        query: Free-text search string.

    Returns:
        A flat list of dicts with keys ``"database"``, ``"table"``, ``"name"``
        and ``"type"``, one dict per column. ``"type"`` is a single-element
        list holding the column's type string. Filtered results are unique per
        (database, table, column name).
    """
    all_schemas = await discover_all_schemas()

    # A whitespace-only query yields no keywords. Treat it like an empty query
    # instead of silently matching nothing.
    keywords = query.lower().split() if query else []

    if not keywords:
        # No filter: return a flat list of all tables and columns for the UI.
        return [
            row
            for db_schema in all_schemas
            for table in db_schema.get("tables", [])
            for row in _column_rows(db_schema, table)
        ]

    seen = set()
    relevant = []
    for db_schema in all_schemas:
        for table in db_schema.get("tables", []):
            if not _table_matches(table, keywords):
                continue
            for row in _column_rows(db_schema, table):
                # Rows hold a list, so they are not hashable. Deduplicate on
                # the identifying (database, table, column) tuple instead.
                key = (row["database"], row["table"], row["name"])
                if key not in seen:
                    seen.add(key)
                    relevant.append(row)
    return relevant