| | from sqlalchemy import inspect |
| | from sqlalchemy.engine import Engine |
| | from typing import Dict, Any, List |
| | import logging |
| | from concurrent.futures import ThreadPoolExecutor, as_completed |
| | import asyncio |
| |
|
| | from .database import get_db_connections |
| |
|
# Module-level logger named after this module, so its output can be
# configured or filtered by module name.
logger = logging.getLogger(name=__name__)
| |
|
def _discover_single_db_schema(db_name: str, engine: Engine) -> Dict[str, Any]:
    """Introspect one database and describe its tables and their columns.

    Args:
        db_name: Label stored under ``"database_name"`` in the result.
        engine: SQLAlchemy engine to introspect (blocking calls).

    Returns:
        ``{"database_name": db_name, "tables": [...]}`` where each table entry
        is ``{"name": <table name>, "columns": [{"name": ..., "type": ...}]}``
        and each column ``"type"`` is the column type rendered with ``str()``.
    """
    insp = inspect(engine)
    tables = []
    # get_table_names() is called without a schema argument, so presumably
    # only the connection's default schema is covered — confirm if needed.
    for tbl in insp.get_table_names():
        cols = [
            {"name": col_info['name'], "type": str(col_info['type'])}
            for col_info in insp.get_columns(tbl)
        ]
        tables.append({"name": tbl, "columns": cols})
    return {"database_name": db_name, "tables": tables}
| |
|
async def discover_all_schemas() -> List[Dict[str, Any]]:
    """Discover the schema of every connected database concurrently.

    Each ``(name, engine)`` pair from ``get_db_connections()`` is introspected
    by ``_discover_single_db_schema`` on a worker thread, so the blocking
    SQLAlchemy inspection calls do not run on the event loop itself.

    Returns:
        One schema dict (as built by ``_discover_single_db_schema``) per
        database whose discovery succeeded, in the same order in which
        ``get_db_connections()`` yields them. A database whose discovery
        raises an ``Exception`` is logged with its name and traceback and is
        left out of the result instead of failing the whole call.

    Raises:
        Any ``BaseException`` that is not an ``Exception`` (for example
        ``KeyboardInterrupt``) raised by a discovery task is re-raised.
    """
    # NOTE(review): get_db_connections() runs on the event loop thread; if it
    # opens connections itself it will block the loop — confirm it is cheap.
    db_items = list(get_db_connections().items())
    loop = asyncio.get_running_loop()

    with ThreadPoolExecutor() as executor:
        futures = [
            loop.run_in_executor(executor, _discover_single_db_schema, name, engine)
            for name, engine in db_items
        ]
        # gather() keeps results aligned with db_items, so a failure can be
        # attributed to its database and the output order is deterministic.
        results = await asyncio.gather(*futures, return_exceptions=True)

    all_schemas: List[Dict[str, Any]] = []
    for (name, _), result in zip(db_items, results):
        if isinstance(result, BaseException):
            if not isinstance(result, Exception):
                # Only ordinary errors are swallowed; anything else propagates.
                raise result
            logger.error(
                "Schema discovery for database %r failed: %s",
                name,
                result,
                exc_info=result,
            )
            continue
        all_schemas.append(result)
    return all_schemas
| |
|
def _flatten_table_columns(db_name: str, table: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn one discovered table into one flat result row per column."""
    return [
        {
            "database": db_name,
            "table": table["name"],
            "name": col["name"],
            # NOTE(review): the type string is wrapped in a one-element list;
            # kept as-is because callers may rely on this shape — confirm.
            "type": [col["type"]],
        }
        for col in table.get("columns", [])
    ]


def _table_matches(table: Dict[str, Any], keywords: List[str]) -> bool:
    """Return True if any keyword is a substring of the table name or a column name."""
    candidate_names = [table["name"]]
    candidate_names.extend(col["name"] for col in table.get("columns", []))
    return any(kw in n.lower() for n in candidate_names for kw in keywords)


async def get_relevant_schemas(query: str) -> List[Dict[str, Any]]:
    """Return flattened column rows, optionally filtered by keywords.

    Schemas are rediscovered on every call via ``discover_all_schemas()``.
    Every result row has the form
    ``{"database": ..., "table": ..., "name": <column>, "type": [<type str>]}``.

    Args:
        query: Whitespace-separated keywords. Each keyword is lower-cased and
            matched as a substring of the lower-cased table and column names.
            When a table matches, all of its columns are returned. An empty,
            ``None``, or whitespace-only query contains no keywords and
            returns the rows for every column of every table.

    Returns:
        Flat list of column rows. Duplicate (database, table, column)
        triples are removed, keeping the first occurrence.
    """
    all_schemas = await discover_all_schemas()
    # split() discards empty strings, so a whitespace-only query is treated
    # like no query at all instead of matching nothing.
    keywords = query.lower().split() if query else []

    seen = set()
    rows: List[Dict[str, Any]] = []
    for db_schema in all_schemas:
        for table in db_schema.get("tables", []):
            if keywords and not _table_matches(table, keywords):
                continue
            for row in _flatten_table_columns(db_schema["database_name"], table):
                key = (row["database"], row["table"], row["name"])
                if key in seen:
                    continue
                seen.add(key)
                rows.append(row)
    return rows
| |
|