from sqlalchemy import inspect
from sqlalchemy.engine import Engine
from typing import Dict, Any, List
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio

from .database import get_db_connections

logger = logging.getLogger(__name__)


def _discover_single_db_schema(db_name: str, engine: Engine) -> Dict[str, Any]:
    """Discover the schema for a single database engine.

    This performs blocking inspector calls, so ``discover_all_schemas`` runs
    it in a worker thread rather than on the event loop.

    Args:
        db_name: Label stored under ``"database_name"`` in the result.
        engine: Engine to inspect.

    Returns:
        ``{"database_name": db_name, "tables": [...]}`` where each table is
        ``{"name": ..., "columns": [{"name": ..., "type": ...}, ...]}`` and
        every column type has been converted with ``str()``.
    """
    inspector = inspect(engine)
    tables = [
        {
            "name": table_name,
            "columns": [
                {"name": col["name"], "type": str(col["type"])}
                for col in inspector.get_columns(table_name)
            ],
        }
        for table_name in inspector.get_table_names()
    ]
    return {"database_name": db_name, "tables": tables}


async def discover_all_schemas() -> List[Dict[str, Any]]:
    """Discover the full schema for all connected databases in parallel.

    Each database is inspected in its own thread-pool task. A database whose
    discovery raises is logged (with its name) and left out of the result, so
    one broken connection does not hide the others.

    Returns:
        One schema dict per successfully inspected database, in the iteration
        order of the mapping returned by ``get_db_connections()``.
    """
    db_engines = get_db_connections()
    if not db_engines:
        # Nothing to inspect; avoid spinning up a thread pool for no work.
        return []

    loop = asyncio.get_running_loop()
    names = list(db_engines)

    with ThreadPoolExecutor() as executor:
        futures = [
            loop.run_in_executor(
                executor, _discover_single_db_schema, name, db_engines[name]
            )
            for name in names
        ]
        # return_exceptions=True keeps the name <-> outcome pairing intact so
        # a failure can be attributed to the database that caused it.
        outcomes = await asyncio.gather(*futures, return_exceptions=True)

    all_schemas = []
    for name, outcome in zip(names, outcomes):
        if isinstance(outcome, Exception):
            logger.error(
                "Schema discovery for database %r failed: %s",
                name,
                outcome,
                exc_info=outcome,
            )
        elif isinstance(outcome, BaseException):
            # Cancellation and similar signals must not be swallowed.
            raise outcome
        else:
            all_schemas.append(outcome)
    return all_schemas


def _flatten_table(db_name: str, table: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return one flat row per column of ``table``.

    The column type is wrapped in a single-element list, matching the shape
    this module has always returned.
    """
    return [
        {
            "database": db_name,
            "table": table["name"],
            "name": col["name"],
            "type": [col["type"]],
        }
        for col in table.get("columns", [])
    ]


def _table_matches(table: Dict[str, Any], keywords: List[str]) -> bool:
    """Return True if any keyword is a substring of the table or a column name."""
    candidate_names = [table["name"]]
    candidate_names.extend(col["name"] for col in table.get("columns", []))
    return any(
        keyword in candidate.lower()
        for candidate in candidate_names
        for keyword in keywords
    )


async def get_relevant_schemas(query: str) -> List[Dict[str, Any]]:
    """Discover schemas and filter them with a simple keyword search.

    The query is lower-cased and split on whitespace. A table is relevant when
    any keyword is a substring of its name or of one of its column names; all
    columns of a relevant table are returned, not only the matching ones.

    Args:
        query: Free-text search string. If it is empty, ``None`` or contains
            only whitespace, the full schema is returned.

    Returns:
        A flat list of ``{"database", "table", "name", "type"}`` rows, one
        per column. Keyword results are de-duplicated on
        ``(database, table, name)``.
    """
    all_schemas = await discover_all_schemas()

    # Branch on the parsed keywords rather than the raw string so that a
    # whitespace-only query behaves like "no query" instead of matching nothing.
    keywords = query.lower().split() if query else []

    if not keywords:
        flat_list = []
        for db in all_schemas:
            for tbl in db.get("tables", []):
                flat_list.extend(_flatten_table(db["database_name"], tbl))
        return flat_list

    relevant_schemas = []
    for db_schema in all_schemas:
        for table in db_schema.get("tables", []):
            if _table_matches(table, keywords):
                relevant_schemas.extend(
                    _flatten_table(db_schema["database_name"], table)
                )

    # De-duplicate on a hashable tuple key; the rows themselves contain a
    # list ("type") and therefore cannot be put in a set directly.
    seen = set()
    deduped = []
    for row in relevant_schemas:
        key = (row["database"], row["table"], row["name"])
        if key not in seen:
            seen.add(key)
            deduped.append(row)
    return deduped