File size: 3,827 Bytes
9d411a7 86cbe3c 9d411a7 a0eb181 86cbe3c 9d411a7 86cbe3c 9d411a7 86cbe3c 9d411a7 86cbe3c 9d411a7 86cbe3c 9d411a7 86cbe3c a0eb181 9d411a7 a0eb181 9d411a7 a0eb181 86cbe3c 9d411a7 a0eb181 9d411a7 a0eb181 9d411a7 a0eb181 9d411a7 a0eb181 9d411a7 a0eb181 9d411a7 a0eb181 9d411a7 a0eb181 9d411a7 a0eb181 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 | from sqlalchemy import inspect
from sqlalchemy.engine import Engine
from typing import Dict, Any, List
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
from .database import get_db_connections
# Logger for this module; schema-discovery failures are reported through it.
logger = logging.getLogger(name=__name__)
def _discover_single_db_schema(db_name: str, engine: Engine) -> Dict[str, Any]:
    """Inspect one engine and describe every table it exposes.

    Args:
        db_name: Label stored under ``"database_name"`` in the result.
        engine: SQLAlchemy engine to introspect.

    Returns:
        ``{"database_name": db_name, "tables": [...]}`` where each table entry
        is ``{"name": <table>, "columns": [{"name": ..., "type": ...}, ...]}``
        and each column type has been converted with ``str()``.
    """
    schema_inspector = inspect(engine)

    def _describe(tbl: str) -> Dict[str, Any]:
        # Stringify the column type so the result holds only strings.
        cols = [
            {"name": column["name"], "type": str(column["type"])}
            for column in schema_inspector.get_columns(tbl)
        ]
        return {"name": tbl, "columns": cols}

    return {
        "database_name": db_name,
        "tables": [_describe(tbl) for tbl in schema_inspector.get_table_names()],
    }
async def discover_all_schemas() -> List[Dict[str, Any]]:
    """
    Discover the schema of every connected database in parallel.

    Each engine returned by ``get_db_connections()`` is introspected by
    ``_discover_single_db_schema`` on a worker thread, so the blocking
    inspector calls do not run on the event loop.

    Returns:
        One schema dict per database whose discovery succeeded, in the
        iteration order of ``get_db_connections()``. A database whose
        discovery raises an ``Exception`` is logged (with its name and the
        traceback) and left out of the result; other ``BaseException``
        subclasses propagate to the caller.
    """
    db_engines = get_db_connections()
    entries = list(db_engines.items())
    loop = asyncio.get_running_loop()
    executor = ThreadPoolExecutor()
    try:
        futures = [
            loop.run_in_executor(executor, _discover_single_db_schema, name, eng)
            for name, eng in entries
        ]
        # gather() returns outcomes positionally, so each one can be tied back
        # to its database name. as_completed() yields in finish order, which
        # lost that link and made the output order nondeterministic.
        results = await asyncio.gather(*futures, return_exceptions=True)
    finally:
        # A `with` block would call shutdown(wait=True), a blocking join
        # inside a coroutine. On the normal path every job has already
        # finished. If we were cancelled, let the remaining introspections
        # finish in the background instead of stalling the event loop.
        executor.shutdown(wait=False)

    all_schemas: List[Dict[str, Any]] = []
    for (name, _engine), result in zip(entries, results):
        if isinstance(result, Exception):
            logger.error(
                "Schema discovery for database %r failed: %s",
                name,
                result,
                exc_info=result,
            )
        elif isinstance(result, BaseException):
            # Only Exception subclasses are treated as per-database failures.
            raise result
        else:
            all_schemas.append(result)
    return all_schemas
def _flatten_table(database_name: str, table: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return one row per column of *table*, tagged with its database and table name."""
    return [
        {
            "database": database_name,
            "table": table["name"],
            "name": col["name"],
            # One-element list, as emitted by both original code paths.
            # NOTE(review): confirm the UI really needs a list here.
            "type": [col["type"]],
        }
        for col in table.get("columns", [])
    ]


def _table_matches(table: Dict[str, Any], keywords: List[str]) -> bool:
    """Return True if any (lower-cased) keyword is a substring of the table name or of a column name."""
    if any(kw in table["name"].lower() for kw in keywords):
        return True
    return any(
        kw in col["name"].lower()
        for col in table.get("columns", [])
        for kw in keywords
    )


async def get_relevant_schemas(query: str) -> List[Dict[str, Any]]:
    """
    Return column rows for all connected databases, optionally keyword-filtered.

    Schemas are discovered afresh via ``discover_all_schemas()`` on every call.
    Each returned row has the shape
    ``{"database": ..., "table": ..., "name": <column>, "type": [<type str>]}``.

    Args:
        query: Whitespace-separated keywords, matched case-insensitively as
            substrings. A table is included (with all of its columns) when any
            keyword occurs in the table name or in one of its column names.
            An empty or whitespace-only query (or ``None``) returns every
            column of every table.

    Returns:
        The rows in database -> table -> column order. Filtered results are
        de-duplicated on ``(database, table, column)``, keeping the first
        occurrence.
    """
    all_schemas = await discover_all_schemas()

    # A whitespace-only query has no keywords. Previously it matched nothing
    # and returned an empty list; treat it like "no query" instead.
    keywords = query.lower().split() if query else []
    if not keywords:
        return [
            row
            for db in all_schemas
            for tbl in db.get("tables", [])
            for row in _flatten_table(db["database_name"], tbl)
        ]

    relevant_rows = [
        row
        for db in all_schemas
        for tbl in db.get("tables", [])
        if _table_matches(tbl, keywords)
        for row in _flatten_table(db["database_name"], tbl)
    ]

    # Rows contain a list ("type") and are unhashable, so dedupe on a tuple
    # key. dict preserves insertion order, and setdefault keeps the first row.
    unique: Dict[tuple, Dict[str, Any]] = {}
    for row in relevant_rows:
        unique.setdefault((row["database"], row["table"], row["name"]), row)
    return list(unique.values())
|