# agent-mcp-sql / mcp/core/discovery.py
# ohmygaugh — "demo working" (a0eb181)
from sqlalchemy import inspect
from sqlalchemy.engine import Engine
from typing import Dict, Any, List
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
from .database import get_db_connections
# Module-level logger; used below to report schema discovery failures.
logger = logging.getLogger(__name__)
def _discover_single_db_schema(db_name: str, engine: Engine) -> Dict[str, Any]:
    """Describe the tables and columns reachable through one engine.

    Args:
        db_name: Label stored under ``"database_name"`` in the result.
        engine: SQLAlchemy engine to introspect via ``inspect``.

    Returns:
        ``{"database_name": db_name, "tables": [...]}``. Each table entry is
        ``{"name": <table name>, "columns": [{"name": ..., "type": ...}]}``.
        Column types are rendered with ``str()`` so the result holds only
        plain strings.
    """
    insp = inspect(engine)
    tables = [
        {
            "name": tbl,
            "columns": [
                {"name": column["name"], "type": str(column["type"])}
                for column in insp.get_columns(tbl)
            ],
        }
        for tbl in insp.get_table_names()
    ]
    return {"database_name": db_name, "tables": tables}
async def discover_all_schemas() -> List[Dict[str, Any]]:
    """
    Discovers the full schema for all connected databases in parallel.

    Each engine returned by ``get_db_connections()`` is introspected in a
    worker thread. Introspection is blocking, so running it in threads keeps
    it off the event loop.

    Returns:
        One schema dict per database whose discovery succeeded. The list
        follows the iteration order of ``get_db_connections()``. Databases
        whose discovery raised an ``Exception`` are logged and skipped.

    Raises:
        Any non-``Exception`` ``BaseException`` raised by a worker (e.g.
        cancellation) is re-raised rather than swallowed.
    """
    db_engines = get_db_connections()
    names = list(db_engines)
    loop = asyncio.get_running_loop()

    with ThreadPoolExecutor() as executor:
        futures = [
            loop.run_in_executor(
                executor, _discover_single_db_schema, name, db_engines[name]
            )
            for name in names
        ]
        # gather() preserves input order, unlike as_completed(). Each result
        # can therefore be matched back to its database name for logging.
        # Awaiting inside the `with` means shutdown() finds no pending work
        # and does not block the event loop.
        results = await asyncio.gather(*futures, return_exceptions=True)

    all_schemas = []
    for name, result in zip(names, results):
        if isinstance(result, Exception):
            logger.error(
                "Schema discovery for database %r failed: %s",
                name,
                result,
                exc_info=result,
            )
        elif isinstance(result, BaseException):
            # The original only caught Exception; keep other errors fatal.
            raise result
        else:
            all_schemas.append(result)
    return all_schemas
def _flatten_table(db_name: str, table: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return one flat column record per column of ``table``.

    Each record has the keys ``database``, ``table``, ``name`` and ``type``.
    ``type`` is wrapped in a single-element list, the shape emitted by both
    the full-listing and keyword-search paths.
    """
    return [
        {
            "database": db_name,
            "table": table["name"],
            "name": col["name"],
            "type": [col["type"]],
        }
        for col in table.get("columns", [])
    ]


def _table_matches(table: Dict[str, Any], keywords: List[str]) -> bool:
    """Return True if any keyword is a substring of the table name or a column name.

    ``keywords`` must already be lower-case; names are lower-cased here.
    """
    table_name = table["name"].lower()
    if any(kw in table_name for kw in keywords):
        return True
    # A single matching column is enough; any() short-circuits.
    return any(
        kw in col["name"].lower()
        for col in table.get("columns", [])
        for kw in keywords
    )


async def get_relevant_schemas(query: str) -> List[Dict[str, Any]]:
    """
    Discovers schemas and performs a simple keyword search.

    Args:
        query: Whitespace-separated keywords, matched case-insensitively as
            substrings of table and column names. An empty, ``None`` or
            whitespace-only query returns the full schema.

    Returns:
        A flat list of column records
        (``{"database", "table", "name", "type"}``). When a table matches,
        all of its columns are returned. Keyword results are deduplicated on
        ``(database, table, column)``, keeping first-seen order.
    """
    all_schemas = await discover_all_schemas()
    # A whitespace-only query produces no keywords. Treat it like "no query"
    # instead of filtering out every table.
    keywords = query.lower().split() if query else []

    if not keywords:
        # No filter: return every column of every table, for the UI.
        return [
            record
            for db in all_schemas
            for tbl in db.get("tables", [])
            for record in _flatten_table(db["database_name"], tbl)
        ]

    # Deduplicate on a (database, table, column) tuple key.
    seen = set()
    deduped = []
    for db in all_schemas:
        for tbl in db.get("tables", []):
            if not _table_matches(tbl, keywords):
                continue
            for record in _flatten_table(db["database_name"], tbl):
                key = (record["database"], record["table"], record["name"])
                if key not in seen:
                    seen.add(key)
                    deduped.append(record)
    return deduped