Spaces:
No application file
No application file
| from sqlalchemy import inspect | |
| from sqlalchemy.engine import Engine | |
| from typing import Dict, Any, List | |
| import logging | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| import asyncio | |
| from .database import get_db_connections | |
| logger = logging.getLogger(__name__) | |
def _discover_single_db_schema(db_name: str, engine: Engine) -> Dict[str, Any]:
    """Describe the tables and columns reachable through one engine.

    Args:
        db_name: Label stored under ``"database_name"`` in the result.
        engine: SQLAlchemy engine to run the inspector against.

    Returns:
        A dict of the form ``{"database_name": db_name, "tables": [...]}``
        where each table entry is ``{"name": <table>, "columns": [...]}`` and
        each column entry is ``{"name": <column>, "type": <str(type)>}``.
    """
    db_inspector = inspect(engine)

    def describe_table(tbl: str) -> Dict[str, Any]:
        # Column types are stringified so the result holds only plain
        # str/dict/list values.
        column_entries = [
            {"name": col["name"], "type": str(col["type"])}
            for col in db_inspector.get_columns(tbl)
        ]
        return {"name": tbl, "columns": column_entries}

    return {
        "database_name": db_name,
        "tables": [describe_table(tbl) for tbl in db_inspector.get_table_names()],
    }
async def discover_all_schemas() -> List[Dict[str, Any]]:
    """
    Discover the schema of every connected database concurrently.

    Each engine returned by ``get_db_connections()`` is inspected with
    ``_discover_single_db_schema`` on a worker thread, so the blocking
    inspection calls do not run on the event-loop thread.

    Returns:
        One schema dict per database whose discovery succeeded, in the
        iteration order of the connection mapping. A database whose discovery
        raises is logged together with its name and omitted from the result.
    """
    db_engines = get_db_connections()
    if not db_engines:
        # Nothing to inspect; avoid spinning up a thread pool for no work.
        return []

    loop = asyncio.get_running_loop()

    with ThreadPoolExecutor() as executor:

        async def discover_one(name: str, engine: Any):
            """Run discovery for one database; return None if it fails."""
            try:
                return await loop.run_in_executor(
                    executor, _discover_single_db_schema, name, engine
                )
            except Exception as exc:
                # Keep the best-effort contract: one failing database must not
                # abort the others, but the log must say which one failed.
                logger.error(
                    "Schema discovery for database %r failed: %s",
                    name,
                    exc,
                    exc_info=True,
                )
                return None

        # gather() preserves input order, so results are deterministic rather
        # than ordered by whichever database happened to finish first.
        outcomes = await asyncio.gather(
            *(discover_one(name, engine) for name, engine in db_engines.items())
        )

    # None marks a failed discovery (successful ones always return a dict).
    return [schema for schema in outcomes if schema is not None]
def _flatten_table(database_name: str, table: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Expand one table description into one flat row per column."""
    return [
        {
            "database": database_name,
            "table": table["name"],
            "name": column["name"],
            # The type string is emitted as a single-element list; this is
            # the existing output shape and is kept unchanged.
            "type": [column["type"]],
        }
        for column in table.get("columns", [])
    ]


def _table_matches(table: Dict[str, Any], keywords: List[str]) -> bool:
    """Return True if any keyword is a substring of the table or a column name."""
    if any(keyword in table["name"].lower() for keyword in keywords):
        return True
    return any(
        keyword in column["name"].lower()
        for column in table.get("columns", [])
        for keyword in keywords
    )


async def get_relevant_schemas(query: str) -> List[Dict[str, Any]]:
    """
    Discover schemas and filter them with a simple keyword search.

    The query is lower-cased and split on whitespace. A table is selected when
    any keyword is a substring of its name or of one of its column names; a
    selected table contributes all of its columns to the result.

    Args:
        query: Free-text search string. If it is empty, ``None`` or contains
            only whitespace (i.e. yields no keywords), every column of every
            table is returned.

    Returns:
        A flat list with one dict per column:
        ``{"database": ..., "table": ..., "name": ..., "type": [...]}``.
        Keyword results are de-duplicated on (database, table, column name),
        keeping the first occurrence.
    """
    all_schemas = await discover_all_schemas()

    # Derive keywords before branching: a whitespace-only query is truthy but
    # produces no keywords, and must behave like "no query" rather than
    # silently matching nothing.
    keywords = query.lower().split() if query else []

    if not keywords:
        # No usable query: return every table and column for the UI.
        return [
            row
            for db_schema in all_schemas
            for table in db_schema.get("tables", [])
            for row in _flatten_table(db_schema["database_name"], table)
        ]

    # De-duplicate on a hashable tuple key (the row dicts themselves contain a
    # list and are therefore not hashable).
    seen = set()
    deduped: List[Dict[str, Any]] = []
    for db_schema in all_schemas:
        for table in db_schema.get("tables", []):
            if not _table_matches(table, keywords):
                continue
            for row in _flatten_table(db_schema["database_name"], table):
                key = (row["database"], row["table"], row["name"])
                if key not in seen:
                    seen.add(key)
                    deduped.append(row)
    return deduped