"""Executor for registered database sources (source_type="database"). Flow per (client_id, question): 1. Collect all relevant (table_name, column_name) pairs from retrieval results. 2. Fetch the FULL schema for those tables from PGVector (not just top-k columns). 3. Build a schema context string and send to LLM → structured SQLQuery output. 4. Validate via sqlglot: SELECT-only, schema-grounded, LIMIT enforced. 5. Execute on the user's DB via engine_scope + asyncio.to_thread. 6. Return QueryResult per client_id (may span multiple tables via JOINs). Supported db_types: postgres, supabase, mysql. Other types are skipped with a warning — they do not raise. """ import asyncio from collections import defaultdict from typing import Any import sqlglot import sqlglot.expressions as exp import tiktoken from langchain_core.prompts import ChatPromptTemplate from langchain_openai import AzureChatOpenAI from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession from src.config.settings import settings from src.database_client.database_client_service import database_client_service from src.db.postgres.connection import _pgvector_engine from src.middlewares.logging import get_logger from src.models.sql_query import SQLQuery from src.pipeline.db_pipeline import db_pipeline_service from src.query.base import BaseExecutor, QueryResult from src.rag.base import RetrievalResult from src.utils.db_credential_encryption import decrypt_credentials_dict logger = get_logger("db_executor") _enc = tiktoken.get_encoding("cl100k_base") _SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"} _MAX_RETRIES = 3 _MAX_LIMIT = 500 _FK_EXPANSION_MAX_TABLES = 5 _SQL_SYSTEM_PROMPT = """\ You are a SQL data analyst working with a user's database. Generate a single SQL SELECT statement that answers the user's question. Database dialect: {dialect} Rules: - ONLY reference tables and columns listed in the schema below. Do not invent names. - Always include a LIMIT clause (max {limit}). - Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL. - Prefer explicit JOINs over subqueries when combining tables. - For aggregations, always alias the result column (e.g. COUNT(*) AS order_count). - For date filtering, use dialect-appropriate functions ({dialect} syntax). 

Schema:
{schema}

{error_section}"""


class DbExecutor(BaseExecutor):
    def __init__(self) -> None:
        self._llm = AzureChatOpenAI(
            azure_deployment=settings.azureai_deployment_name_4o,
            openai_api_version=settings.azureai_api_version_4o,
            azure_endpoint=settings.azureai_endpoint_url_4o,
            api_key=settings.azureai_api_key_4o,
            temperature=0,
        )
        self._prompt = ChatPromptTemplate.from_messages([
            ("system", _SQL_SYSTEM_PROMPT),
            ("human", "{question}"),
        ])
        self._chain = self._prompt | self._llm.with_structured_output(SQLQuery)

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------
    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        db_results = [r for r in results if r.source_type == "database"]
        if not db_results:
            return []

        # Group by client_id — one SQL generation + execution pass per client
        by_client: dict[str, list[RetrievalResult]] = defaultdict(list)
        for r in db_results:
            client_id = r.metadata.get("database_client_id", "")
            if client_id:
                by_client[client_id].append(r)
            else:
                logger.warning("db result missing database_client_id, skipping")

        query_results: list[QueryResult] = []
        for client_id, client_results in by_client.items():
            try:
                qr = await self._execute_for_client(client_id, client_results, user_id, db, question, limit)
                if qr:
                    query_results.append(qr)
            except Exception as e:
                logger.error("db executor failed for client", client_id=client_id, error=str(e))
        return query_results

    # ------------------------------------------------------------------
    # Per-client execution
    # ------------------------------------------------------------------
    async def _execute_for_client(
        self,
        client_id: str,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int,
    ) -> QueryResult | None:
        client = await database_client_service.get(db, client_id)
        if not client:
            logger.warning("database client not found", client_id=client_id)
            return None
        if client.user_id != user_id:
            logger.warning("client ownership mismatch", client_id=client_id)
            return None
        if client.db_type not in _SUPPORTED_DB_TYPES:
            logger.warning("unsupported db_type for query execution", db_type=client.db_type)
            return None

        # Hit tables = tables retrieval pointed at directly. Get full per-column
        # schema for these. Related tables (one FK hop away, both directions) are
        # fetched separately in abbreviated form to give the LLM enough context
        # to JOIN without paying the per-column profile token cost.
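        #
        # Illustrative sketch of the context string this hands to the LLM. Table
        # names and values here are hypothetical; the layout mirrors what
        # _build_schema_context / _build_related_schema_block produce below:
        #
        #   Table: orders
        #     - id integer [PRIMARY KEY]
        #     - customer_id integer [FK -> customers.id]
        #     - total numeric
        #       Top values: 19.99 (12), 4.50 (9)
        #
        #   Related tables (one hop via FK, abbreviated — use for JOINs only):
        #   - customers (1200 rows)
        #     Primary key: id
        #     Foreign keys: (none)
        #     Columns: id, name, created_at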
        hit_tables = list({
            r.metadata.get("data", {}).get("table_name")
            for r in results
            if r.metadata.get("data", {}).get("table_name")
        })
        if not hit_tables:
            logger.warning("no table_name on any retrieval result", client_id=client_id)
            return None

        full_schema = await self._fetch_full_schema(client_id, hit_tables, user_id)
        if not full_schema:
            logger.warning("no schema found in vector store", client_id=client_id, tables=hit_tables)
            return None

        related_tables = await self._find_related_tables(client_id, user_id, hit_tables)
        related_schema = (
            await self._fetch_abbreviated_schema(client_id, user_id, related_tables)
            if related_tables
            else {}
        )

        schema_ctx = self._build_schema_context(full_schema, related_schema)
        capped_limit = min(limit, _MAX_LIMIT)
        dialect = client.db_type

        # SQL generation with retry
        validated_sql: str | None = None
        prev_error: str = ""
        prev_reasoning: str = ""
        for attempt in range(_MAX_RETRIES):
            if prev_error:
                error_section = (
                    f"Previous attempt reasoning: {prev_reasoning}\n"
                    f"Previous attempt failed: {prev_error}\n"
                    "Fix the issue above."
                )
            else:
                error_section = ""
            try:
                prompt_text = schema_ctx + error_section + question
                input_tokens = len(_enc.encode(prompt_text))
                logger.info("sql generation input tokens", attempt=attempt + 1, tokens=input_tokens)

                result: SQLQuery = await self._chain.ainvoke({
                    "schema": schema_ctx,
                    "dialect": dialect,
                    "limit": capped_limit,
                    "error_section": error_section,
                    "question": question,
                })
                sql = result.sql.strip()

                allowed_tables = set(full_schema) | set(related_schema)
                column_map: dict[str, set[str]] = {
                    t: {c["name"] for c in cols} for t, cols in full_schema.items()
                }
                for t, info in related_schema.items():
                    column_map[t] = set(info.get("column_names") or [])

                validation_error = self._validate(sql, allowed_tables, capped_limit, column_map)
                if validation_error:
                    prev_error = validation_error
                    prev_reasoning = result.reasoning
                    logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
                    continue

                validated_sql = self._enforce_limit(sql, capped_limit)
                output_tokens = len(_enc.encode(result.sql)) + len(_enc.encode(result.reasoning))
                logger.info(
                    "sql generated",
                    attempt=attempt + 1,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_tokens=input_tokens + output_tokens,
                    reasoning=result.reasoning,
                )
                break
            except Exception as e:
                prev_error = str(e)
                logger.warning("sql generation error", attempt=attempt + 1, error=prev_error)

        if not validated_sql:
            logger.error("sql generation failed after retries", client_id=client_id)
            return None

        # Execute on user's DB
        creds = decrypt_credentials_dict(client.credentials)
        with db_pipeline_service.engine_scope(client.db_type, creds) as engine:
            rows = await asyncio.to_thread(self._run_sql, engine, validated_sql)

        column_types = {
            col["name"]: col["type"]
            for cols in full_schema.values()
            for col in cols
        }
        columns = list(rows[0].keys()) if rows else []
        return QueryResult(
            source_type="database",
            source_id=client_id,
            table_or_file=", ".join(hit_tables),
            columns=columns,
            rows=rows,
            row_count=len(rows),
            metadata={
                "db_type": client.db_type,
                "client_name": client.name,
                "sql": validated_sql,
                "column_types": {c: column_types.get(c, "unknown") for c in columns},
            },
        )

    # ------------------------------------------------------------------
    # Schema helpers
    # ------------------------------------------------------------------
    async def _find_related_tables(
        self,
        client_id: str,
        user_id: str,
        hit_tables: list[str],
    ) -> list[str]:
        """One-hop FK neighbours of `hit_tables`, both directions, excluding hits.

        Prefers chunk_level='table' rows; if none exist for the client (legacy
        ingest predating Phase 1), falls back to aggregating from column-chunk
        metadata. Returns [] when no FK metadata is available.

        Capped at _FK_EXPANSION_MAX_TABLES, ranked by edge count desc then table
        name asc. A warning is logged when the cap kicks in.
        """
        if not hit_tables:
            return []
        hit_set = set(hit_tables)
        # edge_counts[related_table] = number of FK edges connecting it to the hit set
        edge_counts: dict[str, int] = defaultdict(int)

        # ---- Primary path: table-level chunks ----
        sql = text("""
            SELECT lpe.cmetadata
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->>'chunk_level' = 'table'
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"user_id": user_id, "client_id": client_id})
            table_rows = result.fetchall()

        if table_rows:
            for row in table_rows:
                data = row.cmetadata.get("data", {})
                table = data.get("table_name")
                fks = data.get("foreign_keys") or []
                if not table:
                    continue
                if table in hit_set:
                    # Outgoing: this hit's FKs point at related tables
                    for fk in fks:
                        target = fk.get("target_table")
                        if target and target not in hit_set:
                            edge_counts[target] += 1
                else:
                    # Incoming: this non-hit table's FKs point into the hit set
                    for fk in fks:
                        target = fk.get("target_table")
                        if target in hit_set:
                            edge_counts[table] += 1
        else:
            # ---- Fallback: aggregate from column chunks ----
            sql = text("""
                SELECT lpe.cmetadata->'data'->>'table_name' AS src_table,
                       lpe.cmetadata->'data'->>'foreign_key' AS fk
                FROM langchain_pg_embedding lpe
                JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
                WHERE lpc.name = 'document_embeddings'
                  AND lpe.cmetadata->>'user_id' = :user_id
                  AND lpe.cmetadata->>'source_type' = 'database'
                  AND lpe.cmetadata->>'database_client_id' = :client_id
                  AND lpe.cmetadata->>'chunk_level' = 'column'
                  AND lpe.cmetadata->'data'->>'foreign_key' IS NOT NULL
            """)
            async with _pgvector_engine.connect() as conn:
                result = await conn.execute(sql, {"user_id": user_id, "client_id": client_id})
                col_rows = result.fetchall()
            for row in col_rows:
                src = row.src_table
                fk = row.fk
                if not src or not fk:
                    continue
                target = fk.split(".", 1)[0]
                if src in hit_set and target and target not in hit_set:
                    edge_counts[target] += 1
                elif src not in hit_set and target in hit_set:
                    edge_counts[src] += 1

        if not edge_counts:
            return []

        ranked = sorted(edge_counts.items(), key=lambda kv: (-kv[1], kv[0]))
        if len(ranked) > _FK_EXPANSION_MAX_TABLES:
            logger.warning(
                "fk expansion cap hit",
                client_id=client_id,
                total=len(ranked),
                cap=_FK_EXPANSION_MAX_TABLES,
                dropped=[t for t, _ in ranked[_FK_EXPANSION_MAX_TABLES:]],
            )
            ranked = ranked[:_FK_EXPANSION_MAX_TABLES]
        related = [t for t, _ in ranked]
        logger.info("fk-related tables", hit=sorted(hit_set), related=related)
        return related

    async def _fetch_abbreviated_schema(
        self,
        client_id: str,
        user_id: str,
        table_names: list[str],
    ) -> dict[str, dict[str, Any]]:
        """Abbreviated schema: name, row_count, PK, FKs, column names — no profiles.

        Prefers chunk_level='table' rows. Falls back to aggregating column-chunk
        metadata when table chunks are missing for a given table_name.

        Returns {table_name: {"row_count": int|None, "primary_key": [str],
        "foreign_keys": [{column, target_table, target_column}],
        "column_names": [str]}}.
""" if not table_names: return {} placeholders = ", ".join(f":t{i}" for i in range(len(table_names))) params: dict[str, Any] = {"user_id": user_id, "client_id": client_id} for i, name in enumerate(table_names): params[f"t{i}"] = name # Primary path: one row per table from chunk_level='table' sql_table = text(f""" SELECT lpe.cmetadata FROM langchain_pg_embedding lpe JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid WHERE lpc.name = 'document_embeddings' AND lpe.cmetadata->>'user_id' = :user_id AND lpe.cmetadata->>'source_type' = 'database' AND lpe.cmetadata->>'database_client_id' = :client_id AND lpe.cmetadata->>'chunk_level' = 'table' AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders}) """) async with _pgvector_engine.connect() as conn: result = await conn.execute(sql_table, params) t_rows = result.fetchall() out: dict[str, dict[str, Any]] = {} for row in t_rows: data = row.cmetadata.get("data", {}) tname = data.get("table_name") if not tname: continue out[tname] = { "row_count": data.get("row_count"), "primary_key": list(data.get("primary_key") or []), "foreign_keys": list(data.get("foreign_keys") or []), "column_names": list(data.get("column_names") or []), } # Fallback for tables with no table-chunk: aggregate column chunks missing = [t for t in table_names if t not in out] if missing: placeholders_m = ", ".join(f":m{i}" for i in range(len(missing))) params_m: dict[str, Any] = {"user_id": user_id, "client_id": client_id} for i, name in enumerate(missing): params_m[f"m{i}"] = name sql_col = text(f""" SELECT lpe.cmetadata FROM langchain_pg_embedding lpe JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid WHERE lpc.name = 'document_embeddings' AND lpe.cmetadata->>'user_id' = :user_id AND lpe.cmetadata->>'source_type' = 'database' AND lpe.cmetadata->>'database_client_id' = :client_id AND lpe.cmetadata->>'chunk_level' = 'column' AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders_m}) ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name' """) async with _pgvector_engine.connect() as conn: result = await conn.execute(sql_col, params_m) c_rows = result.fetchall() agg: dict[str, dict[str, Any]] = { t: {"row_count": None, "primary_key": [], "foreign_keys": [], "column_names": []} for t in missing } for row in c_rows: data = row.cmetadata.get("data", {}) tname = data.get("table_name") cname = data.get("column_name") if not tname or tname not in agg or not cname: continue bucket = agg[tname] bucket["column_names"].append(cname) if data.get("is_primary_key"): bucket["primary_key"].append(cname) fk = data.get("foreign_key") if fk: target_table, _, target_col = fk.partition(".") bucket["foreign_keys"].append({ "column": cname, "target_table": target_table, "target_column": target_col, }) for t, v in agg.items(): if v["column_names"]: out[t] = v return out async def _fetch_full_schema( self, client_id: str, table_names: list[str], user_id: str, ) -> dict[str, list[dict[str, Any]]]: """Fetch ALL column chunks for the given tables from PGVector. 
Returns {table_name: [{"name": ..., "type": ..., "is_primary_key": ..., "foreign_key": ..., "content": ...}]} """ placeholders = ", ".join(f":t{i}" for i in range(len(table_names))) sql = text(f""" SELECT lpe.cmetadata, lpe.document FROM langchain_pg_embedding lpe JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid WHERE lpc.name = 'document_embeddings' AND lpe.cmetadata->>'user_id' = :user_id AND lpe.cmetadata->>'source_type' = 'database' AND lpe.cmetadata->>'chunk_level' = 'column' AND lpe.cmetadata->>'database_client_id' = :client_id AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders}) ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name' """) params: dict[str, Any] = {"user_id": user_id, "client_id": client_id} for i, name in enumerate(table_names): params[f"t{i}"] = name async with _pgvector_engine.connect() as conn: result = await conn.execute(sql, params) rows = result.fetchall() schema: dict[str, list[dict[str, Any]]] = defaultdict(list) for row in rows: data = row.cmetadata.get("data", {}) table = data.get("table_name") if table: schema[table].append({ "name": data.get("column_name", ""), "type": data.get("column_type", ""), "is_primary_key": data.get("is_primary_key", False), "foreign_key": data.get("foreign_key"), "content": row.document, # chunk text includes top values / samples }) return dict(schema) def _build_schema_context( self, schema: dict[str, list[dict[str, Any]]], related_schema: dict[str, dict[str, Any]] | None = None, ) -> str: lines: list[str] = [] for table, columns in schema.items(): lines.append(f"Table: {table}") for col in columns: flags = [] if col["is_primary_key"]: flags.append("PRIMARY KEY") if col["foreign_key"]: flags.append(f"FK -> {col['foreign_key']}") flag_str = f" [{', '.join(flags)}]" if flags else "" lines.append(f" - {col['name']} {col['type']}{flag_str}") # Include sample/top-values line from chunk content if present for line in col["content"].splitlines(): if line.startswith(("Top values:", "Sample values:")): lines.append(f" {line}") break lines.append("") related_block = self._build_related_schema_block(related_schema or {}) if related_block: lines.append(related_block) return "\n".join(lines).strip() def _build_related_schema_block(self, related_schema: dict[str, dict[str, Any]]) -> str: """Format the abbreviated FK-related-tables section. Empty string when no related.""" if not related_schema: return "" lines: list[str] = ["Related tables (one hop via FK, abbreviated — use for JOINs only):"] for table, info in related_schema.items(): row_count = info.get("row_count") header = f"- {table} ({row_count} rows)" if row_count is not None else f"- {table}" lines.append(header) pk = info.get("primary_key") or [] lines.append(f" Primary key: {', '.join(pk) if pk else '(none)'}") fks = info.get("foreign_keys") or [] if fks: fk_strs = [ f"{fk.get('column')} -> {fk.get('target_table')}.{fk.get('target_column')}" for fk in fks ] lines.append(f" Foreign keys: {', '.join(fk_strs)}") else: lines.append(" Foreign keys: (none)") cols = info.get("column_names") or [] lines.append(f" Columns: {', '.join(cols)}") return "\n".join(lines) # ------------------------------------------------------------------ # Guardrails # ------------------------------------------------------------------ def _validate( self, sql: str, allowed_tables: set[str], limit: int, column_map: dict[str, set[str]] | None = None, ) -> str: """Return an error string if validation fails, empty string if OK. 

        `allowed_tables` is the union of hit-table names and FK-related table
        names — both are legal targets for SELECT/JOIN.

        `column_map` maps table_name → set of valid column names. When provided,
        any qualified table.column reference not found in the map triggers a
        retry with an informative error so the LLM can self-correct without
        hallucinating.
        """
        # Layer 1: sqlglot parse + SELECT-only check
        try:
            parsed = sqlglot.parse_one(sql)
        except sqlglot.errors.ParseError as e:
            return f"SQL parse error: {e}"
        if not isinstance(parsed, exp.Select):
            return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"

        # Check for DML anywhere in the AST (including writeable CTEs)
        for node in parsed.find_all((exp.Insert, exp.Update, exp.Delete)):
            return f"DML ({type(node).__name__}) is not allowed."

        # Layer 2: schema grounding — table names
        known_tables = {t.lower() for t in allowed_tables}
        alias_to_table: dict[str, str] = {}
        for tbl in parsed.find_all(exp.Table):
            name = tbl.name.lower()
            if name and name not in known_tables:
                return f"Unknown table '{tbl.name}'. Only use tables from the schema."
            alias = (tbl.alias or tbl.name).lower()
            alias_to_table[alias] = name

        # Layer 3: column grounding — qualified references only (table.column)
        if column_map:
            normalized_map = {t.lower(): {c.lower() for c in cols} for t, cols in column_map.items()}
            for col_node in parsed.find_all(exp.Column):
                tbl_ref = col_node.table
                if not tbl_ref:
                    continue  # unqualified — skip, can't resolve without full alias tracking
                tbl_name = alias_to_table.get(tbl_ref.lower(), tbl_ref.lower())
                col_name = col_node.name.lower()
                if tbl_name in normalized_map and col_name not in normalized_map[tbl_name]:
                    available = ", ".join(sorted(normalized_map[tbl_name]))
                    return (
                        f"Column '{col_node.name}' does not exist on table '{tbl_name}'. "
                        f"Available columns: {available}."
                    )

        # Layer 4: LIMIT enforcement (inject if missing — done before execution)
        return ""

    # ------------------------------------------------------------------
    # SQL execution
    # ------------------------------------------------------------------
    def _enforce_limit(self, sql: str, limit: int) -> str:
        """Inject or cap LIMIT using sqlglot AST manipulation."""
        parsed = sqlglot.parse_one(sql)
        existing = parsed.find(exp.Limit)
        if existing:
            current = int(existing.expression.this)
            if current > limit:
                return parsed.limit(limit).sql()
        else:
            return parsed.limit(limit).sql()
        return parsed.sql()

    def _run_sql(self, engine: Any, sql: str) -> list[dict]:
        # Ensure the user DB connection is a read-only credential — sqlglot
        # validation alone is not sufficient.
        with engine.connect() as conn:
            result = conn.execute(text(sql))
            return [dict(row) for row in result.mappings()]


db_executor = DbExecutor()
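
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not exercised by this module). Only
# db_executor.execute() is real; the retriever and session-factory names below
# are assumptions standing in for whatever the query layer actually uses. It
# shows the expected call shape: retrieval results (database hits are filtered
# internally), the owning user_id, an AsyncSession for client lookup, and the
# natural-language question.
#
#     results = await retriever.retrieve(question, user_id=user_id)   # hypothetical retriever
#     async with async_session_factory() as db:                       # hypothetical session factory
#         query_results = await db_executor.execute(
#             results,
#             user_id=user_id,
#             db=db,
#             question=question,
#             limit=100,
#         )
#     for qr in query_results:
#         print(qr.metadata["sql"], qr.row_count)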