"""Executor for registered database sources (source_type="database").

Flow per (client_id, question):
  1. Collect the distinct table names referenced by the retrieval results and
     expand them with tables they reference through foreign keys.
  2. Fetch the FULL schema for those tables from PGVector (not just the top-k
     columns that retrieval happened to return).
  3. Build a schema context string and send it to the LLM -> structured
     SQLQuery output.
  4. Validate via sqlglot (parsed in the client's dialect): exactly one
     statement, SELECT-only, no DML / SELECT ... INTO, table names grounded in
     the schema.
  5. Inject or cap the LIMIT on the outermost query, then execute on the
     user's DB via engine_scope + asyncio.to_thread.
  6. Return one QueryResult per client_id (may span multiple tables via JOINs).

Supported db_types: postgres, supabase, mysql.
Other types are skipped with a warning — they do not raise.
"""

import asyncio
from collections import defaultdict
from typing import Any

import sqlglot
import sqlglot.expressions as exp
import tiktoken
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from src.config.settings import settings
from src.database_client.database_client_service import database_client_service
from src.db.postgres.connection import _pgvector_engine
from src.middlewares.logging import get_logger
from src.models.sql_query import SQLQuery
from src.pipeline.db_pipeline import db_pipeline_service
from src.query.base import BaseExecutor, QueryResult
from src.rag.base import RetrievalResult
from src.utils.db_credential_encryption import decrypt_credentials_dict

logger = get_logger("db_executor")

_enc = tiktoken.get_encoding("cl100k_base")

# sqlglot dialect used to parse/regenerate SQL (and named in the LLM prompt)
# for each supported db_type. Supabase is served by Postgres, so it shares the
# Postgres dialect.
_SQLGLOT_DIALECTS = {"postgres": "postgres", "supabase": "postgres", "mysql": "mysql"}

_SUPPORTED_DB_TYPES = set(_SQLGLOT_DIALECTS)

_MAX_RETRIES = 3
_MAX_LIMIT = 500

_SQL_SYSTEM_PROMPT = """\
You are a SQL data analyst working with a user's database.
Generate a single SQL SELECT statement that answers the user's question.

Database dialect: {dialect}

Rules:
- ONLY reference tables and columns listed in the schema below. Do not invent names.
- Always include a LIMIT clause (max {limit}).
- Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
- Prefer explicit JOINs over subqueries when combining tables.
- For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
- For date filtering, use dialect-appropriate functions ({dialect} syntax).

Schema:
{schema}

{error_section}"""


class DbExecutor(BaseExecutor):
    """Answer questions over registered databases by generating and running SQL."""

    def __init__(self) -> None:
        self._llm = AzureChatOpenAI(
            azure_deployment=settings.azureai_deployment_name_4o,
            openai_api_version=settings.azureai_api_version_4o,
            azure_endpoint=settings.azureai_endpoint_url_4o,
            api_key=settings.azureai_api_key_4o,
            temperature=0,
        )
        self._prompt = ChatPromptTemplate.from_messages([
            ("system", _SQL_SYSTEM_PROMPT),
            ("human", "{question}"),
        ])
        self._chain = self._prompt | self._llm.with_structured_output(SQLQuery)

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        """Run one SQL generation + execution pass per database client.

        Only results with ``source_type == "database"`` are considered; they are
        grouped by ``metadata["database_client_id"]``. Failures for one client
        are logged and skipped so other clients still return results.
        """
        db_results = [r for r in results if r.source_type == "database"]
        if not db_results:
            return []

        # Group by client_id — one SQL generation + execution pass per client
        by_client: dict[str, list[RetrievalResult]] = defaultdict(list)
        for r in db_results:
            client_id = r.metadata.get("database_client_id", "")
            if client_id:
                by_client[client_id].append(r)
            else:
                logger.warning("db result missing database_client_id, skipping")

        query_results: list[QueryResult] = []
        for client_id, client_results in by_client.items():
            try:
                qr = await self._execute_for_client(
                    client_id, client_results, user_id, db, question, limit
                )
                if qr:
                    query_results.append(qr)
            except Exception as e:
                # Deliberate best-effort: one failing client must not drop the others.
                logger.error("db executor failed for client", client_id=client_id, error=str(e))

        return query_results

    # ------------------------------------------------------------------
    # Per-client execution
    # ------------------------------------------------------------------

    async def _execute_for_client(
        self,
        client_id: str,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int,
    ) -> QueryResult | None:
        """Generate, validate and run SQL against a single database client.

        Returns None (after logging) when the client is missing, not owned by
        ``user_id``, of an unsupported type, has no usable schema, or SQL
        generation fails on every attempt.
        """
        client = await database_client_service.get(db, client_id)
        if not client:
            logger.warning("database client not found", client_id=client_id)
            return None
        if client.user_id != user_id:
            logger.warning("client ownership mismatch", client_id=client_id)
            return None
        if client.db_type not in _SUPPORTED_DB_TYPES:
            logger.warning("unsupported db_type for query execution", db_type=client.db_type)
            return None

        # Distinct table names from retrieval results, expanded via FK relationships
        retrieved_tables = {
            table_name
            for r in results
            if (table_name := (r.metadata.get("data") or {}).get("table_name"))
        }
        if not retrieved_tables:
            # Without table names the schema lookup would build an empty IN () clause.
            logger.warning("no table names in retrieval results", client_id=client_id)
            return None
        table_names = await self._expand_with_fk_tables(client_id, user_id, sorted(retrieved_tables))

        full_schema = await self._fetch_full_schema(client_id, table_names, user_id)
        if not full_schema:
            logger.warning("no schema found in vector store", client_id=client_id, tables=table_names)
            return None

        schema_ctx = self._build_schema_context(full_schema)
        # Clamp to [1, _MAX_LIMIT]: a non-positive LIMIT is useless or invalid SQL.
        capped_limit = max(1, min(limit, _MAX_LIMIT))
        dialect = _SQLGLOT_DIALECTS[client.db_type]

        validated_sql = await self._generate_sql(question, schema_ctx, full_schema, dialect, capped_limit)
        if not validated_sql:
            logger.error("sql generation failed after retries", client_id=client_id)
            return None

        # Execute on user's DB (sync engine, so run it off the event loop)
        creds = decrypt_credentials_dict(client.credentials)
        with db_pipeline_service.engine_scope(client.db_type, creds) as engine:
            rows = await asyncio.to_thread(self._run_sql, engine, validated_sql)

        # Column types keyed by bare column name; same-named columns across
        # tables collapse to whichever table is iterated last.
        column_types = {
            col["name"]: col["type"]
            for cols in full_schema.values()
            for col in cols
        }
        columns = list(rows[0].keys()) if rows else []

        return QueryResult(
            source_type="database",
            source_id=client_id,
            table_or_file=", ".join(table_names),
            columns=columns,
            rows=rows,
            row_count=len(rows),
            metadata={
                "db_type": client.db_type,
                "client_name": client.name,
                "sql": validated_sql,
                "column_types": {c: column_types.get(c, "unknown") for c in columns},
            },
        )

    async def _generate_sql(
        self,
        question: str,
        schema_ctx: str,
        full_schema: dict[str, list[dict[str, Any]]],
        dialect: str,
        limit: int,
    ) -> str | None:
        """Ask the LLM for SQL, validate it, and retry with error feedback.

        Returns the validated, LIMIT-capped SQL, or None if all
        ``_MAX_RETRIES`` attempts fail (LLM error, parse error or guardrail hit).
        """
        prev_error = ""
        prev_reasoning = ""

        for attempt in range(1, _MAX_RETRIES + 1):
            if prev_error:
                error_section = (
                    f"Previous attempt reasoning: {prev_reasoning}\n"
                    f"Previous attempt failed: {prev_error}\n"
                    "Fix the issue above."
                )
            else:
                error_section = ""

            try:
                # Approximate input size: ignores the fixed system-prompt boilerplate.
                prompt_text = schema_ctx + error_section + question
                input_tokens = len(_enc.encode(prompt_text))
                logger.info("sql generation input tokens", attempt=attempt, tokens=input_tokens)

                result: SQLQuery = await self._chain.ainvoke({
                    "schema": schema_ctx,
                    "dialect": dialect,
                    "limit": limit,
                    "error_section": error_section,
                    "question": question,
                })
                sql = result.sql.strip()
                reasoning = result.reasoning or ""

                validation_error = self._validate(sql, full_schema, limit, dialect)
                if validation_error:
                    prev_error = validation_error
                    prev_reasoning = reasoning
                    logger.warning("sql validation failed", attempt=attempt, error=validation_error)
                    continue

                validated_sql = self._enforce_limit(sql, limit, dialect)
            except Exception as e:
                prev_error = str(e)
                logger.warning("sql generation error", attempt=attempt, error=prev_error)
                continue

            # Logging happens outside the try so a logging hiccup cannot
            # discard an already-validated query and trigger a retry.
            output_tokens = len(_enc.encode(sql)) + len(_enc.encode(reasoning))
            logger.info(
                "sql generated",
                attempt=attempt,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
                reasoning=reasoning,
            )
            return validated_sql

        return None

    # ------------------------------------------------------------------
    # Schema helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _table_in_clause(table_names: list[str]) -> tuple[str, dict[str, Any]]:
        """Build ``:t0, :t1, ...`` bind placeholders and their parameter dict."""
        params: dict[str, Any] = {f"t{i}": name for i, name in enumerate(table_names)}
        placeholders = ", ".join(f":{key}" for key in params)
        return placeholders, params

    async def _expand_with_fk_tables(
        self,
        client_id: str,
        user_id: str,
        table_names: list[str],
    ) -> list[str]:
        """Expand table_names with any tables FK-referenced by the retrieved tables.

        Prevents SQL generation failures when a required table (e.g. orders)
        wasn't returned by retrieval but is referenced via FK from a table that
        was (e.g. order_items.order_id -> orders.id). Only one FK hop is
        followed. Returns a sorted list so downstream output is deterministic.
        """
        if not table_names:
            return table_names

        placeholders, params = self._table_in_clause(table_names)
        params.update(user_id=user_id, client_id=client_id)
        sql = text(f"""
            SELECT DISTINCT lpe.cmetadata->'data'->>'foreign_key' AS fk
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
              AND lpe.cmetadata->'data'->>'foreign_key' IS NOT NULL
        """)

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, params)
            rows = result.fetchall()

        expanded = set(table_names)
        for row in rows:
            fk = row.fk  # format: "referred_table.referred_column"
            if fk:
                expanded.add(fk.split(".")[0])

        if expanded != set(table_names):
            logger.info(
                "expanded tables via FK",
                original=sorted(table_names),
                expanded=sorted(expanded),
            )

        return sorted(expanded)

    async def _fetch_full_schema(
        self,
        client_id: str,
        table_names: list[str],
        user_id: str,
    ) -> dict[str, list[dict[str, Any]]]:
        """Fetch ALL column chunks for the given tables from PGVector.

        Returns {table_name: [{"name": ..., "type": ..., "is_primary_key": ...,
        "foreign_key": ..., "content": ...}]}; an empty dict when
        ``table_names`` is empty (an empty ``IN ()`` is invalid SQL).
        """
        if not table_names:
            return {}

        placeholders, params = self._table_in_clause(table_names)
        params.update(user_id=user_id, client_id=client_id)
        sql = text(f"""
            SELECT lpe.cmetadata, lpe.document
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
            ORDER BY lpe.cmetadata->'data'->>'table_name',
                     lpe.cmetadata->'data'->>'column_name'
        """)

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, params)
            rows = result.fetchall()

        schema: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for row in rows:
            data = row.cmetadata.get("data") or {}
            table = data.get("table_name")
            if table:
                schema[table].append({
                    "name": data.get("column_name", ""),
                    "type": data.get("column_type", ""),
                    "is_primary_key": data.get("is_primary_key", False),
                    "foreign_key": data.get("foreign_key"),
                    "content": row.document or "",  # chunk text includes top values / samples
                })

        return dict(schema)

    def _build_schema_context(self, schema: dict[str, list[dict[str, Any]]]) -> str:
        """Render the schema as the plain-text block embedded in the LLM prompt."""
        lines: list[str] = []
        for table, columns in schema.items():
            lines.append(f"Table: {table}")
            for col in columns:
                flags = []
                if col["is_primary_key"]:
                    flags.append("PRIMARY KEY")
                if col["foreign_key"]:
                    flags.append(f"FK -> {col['foreign_key']}")
                flag_str = f" [{', '.join(flags)}]" if flags else ""
                lines.append(f"  - {col['name']} {col['type']}{flag_str}")

                # Include sample/top-values line from chunk content if present
                for line in (col["content"] or "").splitlines():
                    if line.startswith(("Top values:", "Sample values:")):
                        lines.append(f"    {line}")
                        break
            lines.append("")

        return "\n".join(lines).strip()

    # ------------------------------------------------------------------
    # Guardrails
    # ------------------------------------------------------------------

    def _validate(
        self,
        sql: str,
        schema: dict[str, list[dict]],
        limit: int,
        dialect: str | None = None,
    ) -> str:
        """Return an error string if validation fails, empty string if OK.

        ``dialect`` is the sqlglot dialect to parse with (None = sqlglot's
        default). ``limit`` is not checked here; the LIMIT is injected/capped
        afterwards by ``_enforce_limit``.
        """
        # Layer 1: sqlglot parse — exactly one statement, SELECT-only.
        # (parse_one would silently ignore any statements after the first.)
        try:
            statements = [s for s in sqlglot.parse(sql, read=dialect) if s is not None]
        except sqlglot.errors.SqlglotError as e:
            return f"SQL parse error: {e}"
        if len(statements) != 1:
            return f"Exactly one SQL statement is allowed. Got: {len(statements)}"
        parsed = statements[0]

        if not isinstance(parsed, exp.Select):
            return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"

        # Check for DML anywhere in the AST (including writeable CTEs)
        dml = parsed.find(exp.Insert, exp.Update, exp.Delete)
        if dml is not None:
            return f"DML ({type(dml).__name__}) is not allowed."

        # SELECT ... INTO writes data (creates a table on Postgres).
        if parsed.find(exp.Into) is not None:
            return "SELECT ... INTO is not allowed."

        # Layer 2: schema grounding — table names. CTE names defined by the
        # query itself are valid FROM targets, not invented tables.
        known_tables = {t.lower() for t in schema}
        cte_names = {cte.alias.lower() for cte in parsed.find_all(exp.CTE) if cte.alias}
        for tbl in parsed.find_all(exp.Table):
            name = tbl.name.lower()
            if name and name not in known_tables and name not in cte_names:
                return f"Unknown table '{tbl.name}'. Only use tables from the schema."

        # Layer 3: LIMIT enforcement (inject if missing — done before execution)
        return ""

    # ------------------------------------------------------------------
    # SQL execution
    # ------------------------------------------------------------------

    def _enforce_limit(self, sql: str, limit: int, dialect: str | None = None) -> str:
        """Inject or cap the outermost LIMIT using sqlglot AST manipulation.

        Only the top-level query's own LIMIT is inspected: searching the whole
        tree could match a LIMIT inside a subquery/CTE and leave the outer
        query unbounded. A LIMIT that is not an integer literal (e.g.
        ``LIMIT ALL``) is replaced with ``limit``. SQL is parsed and re-emitted
        in ``dialect`` (None = sqlglot's default) so dialect-specific syntax
        such as MySQL backtick identifiers survives the round trip.
        """
        parsed = sqlglot.parse_one(sql, read=dialect)
        existing = parsed.args.get("limit")
        if isinstance(existing, exp.Limit):
            limit_expr = existing.expression
            try:
                current = int(limit_expr.this) if isinstance(limit_expr, exp.Literal) else None
            except (TypeError, ValueError):
                current = None
            if current is None or current > limit:
                existing.set("expression", exp.Literal.number(limit))
        else:
            parsed = parsed.limit(limit)
        return parsed.sql(dialect=dialect)

    def _run_sql(self, engine: Any, sql: str) -> list[dict]:
        """Execute ``sql`` synchronously and return rows as plain dicts."""
        # Ensure the user DB connection is a read-only credential — sqlglot validation alone is not sufficient.
        with engine.connect() as conn:
            result = conn.execute(text(sql))
            return [dict(row) for row in result.mappings()]


db_executor = DbExecutor()