| """Executor for registered database sources (source_type="database"). |
| |
| Flow per (client_id, question): |
| 1. Collect all relevant (table_name, column_name) pairs from retrieval results. |
| 2. Fetch the FULL schema for those tables from PGVector (not just top-k columns). |
| 3. Build a schema context string and send to LLM β structured SQLQuery output. |
| 4. Validate via sqlglot: SELECT-only, schema-grounded, LIMIT enforced. |
| 5. Execute on the user's DB via engine_scope + asyncio.to_thread. |
| 6. Return QueryResult per client_id (may span multiple tables via JOINs). |
| |
| Supported db_types: postgres, supabase, mysql. |
| Other types are skipped with a warning β they do not raise. |
| """ |
|
|
import asyncio
from collections import defaultdict
from typing import Any

import sqlglot
import sqlglot.expressions as exp
import tiktoken
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from src.config.settings import settings
from src.database_client.database_client_service import database_client_service
from src.db.postgres.connection import _pgvector_engine
from src.middlewares.logging import get_logger
from src.models.sql_query import SQLQuery
from src.pipeline.db_pipeline import db_pipeline_service
from src.query.base import BaseExecutor, QueryResult
from src.rag.base import RetrievalResult
from src.utils.db_credential_encryption import decrypt_credentials_dict
|
|
logger = get_logger("db_executor")

_enc = tiktoken.get_encoding("cl100k_base")

_SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"}
_MAX_RETRIES = 3
_MAX_LIMIT = 500
_FK_EXPANSION_MAX_TABLES = 5
|
|
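# Note: {dialect}, {limit}, {schema} and {error_section} below are template
# variables filled in per attempt by _execute_for_client; the user's question
# arrives separately as the human message of the prompt.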
_SQL_SYSTEM_PROMPT = """\
You are a SQL data analyst working with a user's database.
Generate a single SQL SELECT statement that answers the user's question.

Database dialect: {dialect}

Rules:
- ONLY reference tables and columns listed in the schema below. Do not invent names.
- Always include a LIMIT clause (max {limit}).
- Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
- Prefer explicit JOINs over subqueries when combining tables.
- For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
- For date filtering, use dialect-appropriate functions ({dialect} syntax).

Schema:
{schema}

{error_section}"""


class DbExecutor(BaseExecutor):
    def __init__(self) -> None:
        self._llm = AzureChatOpenAI(
            azure_deployment=settings.azureai_deployment_name_4o,
            openai_api_version=settings.azureai_api_version_4o,
            azure_endpoint=settings.azureai_endpoint_url_4o,
            api_key=settings.azureai_api_key_4o,
            temperature=0,
        )
        self._prompt = ChatPromptTemplate.from_messages([
            ("system", _SQL_SYSTEM_PROMPT),
            ("human", "{question}"),
        ])
        # The chain emits an SQLQuery (structured output) whose `.sql` and
        # `.reasoning` fields drive the generate-validate retry loop below.
        self._chain = self._prompt | self._llm.with_structured_output(SQLQuery)
|
    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        db_results = [r for r in results if r.source_type == "database"]
        if not db_results:
            return []

        # Group hits by owning database client; one SQL query is generated per client.
        by_client: dict[str, list[RetrievalResult]] = defaultdict(list)
        for r in db_results:
            client_id = r.metadata.get("database_client_id", "")
            if client_id:
                by_client[client_id].append(r)
            else:
                logger.warning("db result missing database_client_id, skipping")

        query_results: list[QueryResult] = []
        for client_id, client_results in by_client.items():
            try:
                qr = await self._execute_for_client(client_id, client_results, user_id, db, question, limit)
                if qr:
                    query_results.append(qr)
            except Exception as e:
                logger.error("db executor failed for client", client_id=client_id, error=str(e))

        return query_results
|
    async def _execute_for_client(
        self,
        client_id: str,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int,
    ) -> QueryResult | None:
        client = await database_client_service.get(db, client_id)
        if not client:
            logger.warning("database client not found", client_id=client_id)
            return None
        if client.user_id != user_id:
            logger.warning("client ownership mismatch", client_id=client_id)
            return None
        if client.db_type not in _SUPPORTED_DB_TYPES:
            logger.warning("unsupported db_type for query execution", db_type=client.db_type)
            return None

        # Deduplicate table names across all retrieval hits for this client.
        hit_tables = list({
            r.metadata.get("data", {}).get("table_name")
            for r in results
            if r.metadata.get("data", {}).get("table_name")
        })
        if not hit_tables:
            logger.warning("no table_name on any retrieval result", client_id=client_id)
            return None

        full_schema = await self._fetch_full_schema(client_id, hit_tables, user_id)
        if not full_schema:
            logger.warning("no schema found in vector store", client_id=client_id, tables=hit_tables)
            return None

        related_tables = await self._find_related_tables(client_id, user_id, hit_tables)
        related_schema = (
            await self._fetch_abbreviated_schema(client_id, user_id, related_tables)
            if related_tables else {}
        )

        schema_ctx = self._build_schema_context(full_schema, related_schema)
        capped_limit = min(limit, _MAX_LIMIT)
        dialect = client.db_type

        # Generate-and-validate loop: validation errors (plus the model's own
        # reasoning) are fed back into the next attempt so the LLM can self-correct.
        validated_sql: str | None = None
        prev_error: str = ""
        prev_reasoning: str = ""
        for attempt in range(_MAX_RETRIES):
            if prev_error:
                error_section = (
                    f"Previous attempt reasoning: {prev_reasoning}\n"
                    f"Previous attempt failed: {prev_error}\n"
                    "Fix the issue above."
                )
            else:
                error_section = ""
            try:
                # Rough token estimate: counts only the variable prompt parts,
                # not the fixed system-prompt template text.
                prompt_text = schema_ctx + error_section + question
                input_tokens = len(_enc.encode(prompt_text))
                logger.info("sql generation input tokens", attempt=attempt + 1, tokens=input_tokens)

                result: SQLQuery = await self._chain.ainvoke({
                    "schema": schema_ctx,
                    "dialect": dialect,
                    "limit": capped_limit,
                    "error_section": error_section,
                    "question": question,
                })
                sql = result.sql.strip()
                allowed_tables = set(full_schema) | set(related_schema)
                column_map: dict[str, set[str]] = {
                    t: {c["name"] for c in cols} for t, cols in full_schema.items()
                }
                for t, info in related_schema.items():
                    column_map[t] = set(info.get("column_names") or [])
                validation_error = self._validate(sql, allowed_tables, capped_limit, column_map)
                if validation_error:
                    prev_error = validation_error
                    prev_reasoning = result.reasoning
                    logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
                    continue
                validated_sql = self._enforce_limit(sql, capped_limit)
                output_tokens = len(_enc.encode(result.sql)) + len(_enc.encode(result.reasoning))
                logger.info(
                    "sql generated",
                    attempt=attempt + 1,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_tokens=input_tokens + output_tokens,
                    reasoning=result.reasoning,
                )
                break
            except Exception as e:
                prev_error = str(e)
                prev_reasoning = ""  # no structured reasoning available for a hard failure
                logger.warning("sql generation error", attempt=attempt + 1, error=prev_error)

        if not validated_sql:
            logger.error("sql generation failed after retries", client_id=client_id)
            return None

        # Decrypt stored credentials and run the validated SQL on the user's
        # database in a worker thread (the pipeline engine is synchronous).
        creds = decrypt_credentials_dict(client.credentials)
        with db_pipeline_service.engine_scope(client.db_type, creds) as engine:
            rows = await asyncio.to_thread(self._run_sql, engine, validated_sql)

        # Note: if two hit tables share a column name, the last table wins here.
        column_types = {
            col["name"]: col["type"]
            for cols in full_schema.values()
            for col in cols
        }
        columns = list(rows[0].keys()) if rows else []

        return QueryResult(
            source_type="database",
            source_id=client_id,
            table_or_file=", ".join(hit_tables),
            columns=columns,
            rows=rows,
            row_count=len(rows),
            metadata={
                "db_type": client.db_type,
                "client_name": client.name,
                "sql": validated_sql,
                "column_types": {c: column_types.get(c, "unknown") for c in columns},
            },
        )
|
    async def _find_related_tables(
        self,
        client_id: str,
        user_id: str,
        hit_tables: list[str],
    ) -> list[str]:
        """One-hop FK neighbours of `hit_tables`, both directions, excluding hits.

        Prefers chunk_level='table' rows; if none exist for the client (legacy
        ingest predating Phase 1), falls back to aggregating from column-chunk
        metadata. Returns [] when no FK metadata is available.

        Capped at _FK_EXPANSION_MAX_TABLES, ranked by edge count desc then
        table name asc. A warning is logged when the cap kicks in.
        """
        if not hit_tables:
            return []

        hit_set = set(hit_tables)
        edge_counts: dict[str, int] = defaultdict(int)

        # Preferred path: table-level chunks carry aggregated FK metadata.
        sql = text("""
            SELECT lpe.cmetadata
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->>'chunk_level' = 'table'
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"user_id": user_id, "client_id": client_id})
            table_rows = result.fetchall()

        if table_rows:
            for row in table_rows:
                data = row.cmetadata.get("data", {})
                table = data.get("table_name")
                fks = data.get("foreign_keys") or []
                if not table:
                    continue
                if table in hit_set:
                    # Outgoing edge: a hit table references a non-hit target.
                    for fk in fks:
                        target = fk.get("target_table")
                        if target and target not in hit_set:
                            edge_counts[target] += 1
                else:
                    # Incoming edge: a non-hit table references a hit table.
                    for fk in fks:
                        target = fk.get("target_table")
                        if target in hit_set:
                            edge_counts[table] += 1
        else:
            # Legacy fallback: derive FK edges from column-level chunks.
            sql = text("""
                SELECT lpe.cmetadata->'data'->>'table_name' AS src_table,
                       lpe.cmetadata->'data'->>'foreign_key' AS fk
                FROM langchain_pg_embedding lpe
                JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
                WHERE lpc.name = 'document_embeddings'
                  AND lpe.cmetadata->>'user_id' = :user_id
                  AND lpe.cmetadata->>'source_type' = 'database'
                  AND lpe.cmetadata->>'database_client_id' = :client_id
                  AND lpe.cmetadata->>'chunk_level' = 'column'
                  AND lpe.cmetadata->'data'->>'foreign_key' IS NOT NULL
            """)
            async with _pgvector_engine.connect() as conn:
                result = await conn.execute(sql, {"user_id": user_id, "client_id": client_id})
                col_rows = result.fetchall()

            for row in col_rows:
                src = row.src_table
                fk = row.fk
                if not src or not fk:
                    continue
                target = fk.split(".", 1)[0]
                if src in hit_set and target and target not in hit_set:
                    edge_counts[target] += 1
                elif src not in hit_set and target in hit_set:
                    edge_counts[src] += 1

        if not edge_counts:
            return []

        ranked = sorted(edge_counts.items(), key=lambda kv: (-kv[1], kv[0]))
        if len(ranked) > _FK_EXPANSION_MAX_TABLES:
            logger.warning(
                "fk expansion cap hit",
                client_id=client_id,
                total=len(ranked),
                cap=_FK_EXPANSION_MAX_TABLES,
                dropped=[t for t, _ in ranked[_FK_EXPANSION_MAX_TABLES:]],
            )
            ranked = ranked[:_FK_EXPANSION_MAX_TABLES]

        related = [t for t, _ in ranked]
        logger.info("fk-related tables", hit=sorted(hit_set), related=related)
        return related
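    # Illustrative ranking (hypothetical counts): edge_counts == {"orders": 2,
    # "payments": 2, "users": 1} yields ["orders", "payments", "users"]: edge
    # count desc, then table name asc, with the cap applied afterwards.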
|
|
    async def _fetch_abbreviated_schema(
        self,
        client_id: str,
        user_id: str,
        table_names: list[str],
    ) -> dict[str, dict[str, Any]]:
        """Abbreviated schema: name, row_count, PK, FKs, column names; no profiles.

        Prefers chunk_level='table' rows. Falls back to aggregating column-chunk
        metadata when table chunks are missing for a given table_name.

        Returns {table_name: {"row_count": int|None, "primary_key": [str],
                              "foreign_keys": [{column, target_table, target_column}],
                              "column_names": [str]}}.
        """
        if not table_names:
            return {}

        placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
        params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
        for i, name in enumerate(table_names):
            params[f"t{i}"] = name

        # Preferred path: table-level chunks already aggregate the metadata.
        sql_table = text(f"""
            SELECT lpe.cmetadata
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->>'chunk_level' = 'table'
              AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql_table, params)
            t_rows = result.fetchall()

        out: dict[str, dict[str, Any]] = {}
        for row in t_rows:
            data = row.cmetadata.get("data", {})
            tname = data.get("table_name")
            if not tname:
                continue
            out[tname] = {
                "row_count": data.get("row_count"),
                "primary_key": list(data.get("primary_key") or []),
                "foreign_keys": list(data.get("foreign_keys") or []),
                "column_names": list(data.get("column_names") or []),
            }

        # Fallback: aggregate column-level chunks for tables with no table chunk.
        missing = [t for t in table_names if t not in out]
        if missing:
            placeholders_m = ", ".join(f":m{i}" for i in range(len(missing)))
            params_m: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
            for i, name in enumerate(missing):
                params_m[f"m{i}"] = name
            sql_col = text(f"""
                SELECT lpe.cmetadata
                FROM langchain_pg_embedding lpe
                JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
                WHERE lpc.name = 'document_embeddings'
                  AND lpe.cmetadata->>'user_id' = :user_id
                  AND lpe.cmetadata->>'source_type' = 'database'
                  AND lpe.cmetadata->>'database_client_id' = :client_id
                  AND lpe.cmetadata->>'chunk_level' = 'column'
                  AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders_m})
                ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name'
            """)
            async with _pgvector_engine.connect() as conn:
                result = await conn.execute(sql_col, params_m)
                c_rows = result.fetchall()

            agg: dict[str, dict[str, Any]] = {
                t: {"row_count": None, "primary_key": [], "foreign_keys": [], "column_names": []}
                for t in missing
            }
            for row in c_rows:
                data = row.cmetadata.get("data", {})
                tname = data.get("table_name")
                cname = data.get("column_name")
                if not tname or tname not in agg or not cname:
                    continue
                bucket = agg[tname]
                bucket["column_names"].append(cname)
                if data.get("is_primary_key"):
                    bucket["primary_key"].append(cname)
                fk = data.get("foreign_key")
                if fk:
                    # foreign_key is stored as "target_table.target_column".
                    target_table, _, target_col = fk.partition(".")
                    bucket["foreign_keys"].append({
                        "column": cname,
                        "target_table": target_table,
                        "target_column": target_col,
                    })
            for t, v in agg.items():
                if v["column_names"]:
                    out[t] = v

        return out
|
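    # Illustrative column-chunk cmetadata shape, as assumed by the reader below
    # (field names taken from the lookups in _fetch_full_schema; values hypothetical):
    #   {"user_id": "...", "source_type": "database", "database_client_id": "...",
    #    "chunk_level": "column",
    #    "data": {"table_name": "orders", "column_name": "user_id",
    #             "column_type": "integer", "is_primary_key": False,
    #             "foreign_key": "users.id"}}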
|
    async def _fetch_full_schema(
        self,
        client_id: str,
        table_names: list[str],
        user_id: str,
    ) -> dict[str, list[dict[str, Any]]]:
        """Fetch ALL column chunks for the given tables from PGVector.

        Returns {table_name: [{"name": ..., "type": ..., "is_primary_key": ...,
                               "foreign_key": ..., "content": ...}]}
        """
        placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
        sql = text(f"""
            SELECT lpe.cmetadata, lpe.document
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'chunk_level' = 'column'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
            ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name'
        """)

        params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
        for i, name in enumerate(table_names):
            params[f"t{i}"] = name

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, params)
            rows = result.fetchall()

        schema: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for row in rows:
            data = row.cmetadata.get("data", {})
            table = data.get("table_name")
            if table:
                schema[table].append({
                    "name": data.get("column_name", ""),
                    "type": data.get("column_type", ""),
                    "is_primary_key": data.get("is_primary_key", False),
                    "foreign_key": data.get("foreign_key"),
                    "content": row.document,
                })
        return dict(schema)
|
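    # Illustrative schema-context output for a hypothetical "orders" table:
    #   Table: orders
    #    - id integer [PRIMARY KEY]
    #    - user_id integer [FK -> users.id]
    #    Top values: ...
    # followed, when present, by the related-tables block built below.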
|
    def _build_schema_context(
        self,
        schema: dict[str, list[dict[str, Any]]],
        related_schema: dict[str, dict[str, Any]] | None = None,
    ) -> str:
        lines: list[str] = []
        for table, columns in schema.items():
            lines.append(f"Table: {table}")
            for col in columns:
                flags = []
                if col["is_primary_key"]:
                    flags.append("PRIMARY KEY")
                if col["foreign_key"]:
                    flags.append(f"FK -> {col['foreign_key']}")
                flag_str = f" [{', '.join(flags)}]" if flags else ""
                lines.append(f" - {col['name']} {col['type']}{flag_str}")
                # Keep at most one profile line per column to bound prompt size.
                for line in col["content"].splitlines():
                    if line.startswith(("Top values:", "Sample values:")):
                        lines.append(f" {line}")
                        break
            lines.append("")

        related_block = self._build_related_schema_block(related_schema or {})
        if related_block:
            lines.append(related_block)

        return "\n".join(lines).strip()
|
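    # Illustrative entry for a hypothetical related table:
    #   - users (1240 rows)
    #    Primary key: id
    #    Foreign keys: (none)
    #    Columns: id, email, created_at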
|
    def _build_related_schema_block(self, related_schema: dict[str, dict[str, Any]]) -> str:
        """Format the abbreviated FK-related-tables section; empty string when there are none."""
        if not related_schema:
            return ""
        lines: list[str] = ["Related tables (one hop via FK, abbreviated; use for JOINs only):"]
        for table, info in related_schema.items():
            row_count = info.get("row_count")
            header = f"- {table} ({row_count} rows)" if row_count is not None else f"- {table}"
            lines.append(header)
            pk = info.get("primary_key") or []
            lines.append(f" Primary key: {', '.join(pk) if pk else '(none)'}")
            fks = info.get("foreign_keys") or []
            if fks:
                fk_strs = [
                    f"{fk.get('column')} -> {fk.get('target_table')}.{fk.get('target_column')}"
                    for fk in fks
                ]
                lines.append(f" Foreign keys: {', '.join(fk_strs)}")
            else:
                lines.append(" Foreign keys: (none)")
            cols = info.get("column_names") or []
            lines.append(f" Columns: {', '.join(cols)}")
        return "\n".join(lines)
|
    def _validate(
        self,
        sql: str,
        allowed_tables: set[str],
        limit: int,
        column_map: dict[str, set[str]] | None = None,
    ) -> str:
        """Return an error string if validation fails, empty string if OK.

        `allowed_tables` is the union of hit-table names and FK-related table
        names; both are legal targets for SELECT/JOIN.

        `column_map` maps table_name -> set of valid column names. When provided,
        any qualified table.column reference not found in the map triggers a retry
        with an informative error so the LLM can self-correct without hallucinating.
        """
        # Parse first; anything sqlglot cannot parse is rejected outright.
        try:
            parsed = sqlglot.parse_one(sql)
        except sqlglot.errors.ParseError as e:
            return f"SQL parse error: {e}"

        if not isinstance(parsed, exp.Select):
            return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"

        # Reject DML even when it is nested somewhere inside the SELECT.
        for node in parsed.find_all((exp.Insert, exp.Update, exp.Delete)):
            return f"DML ({type(node).__name__}) is not allowed."

        # Table grounding: every referenced table must come from the schema.
        known_tables = {t.lower() for t in allowed_tables}
        alias_to_table: dict[str, str] = {}
        for tbl in parsed.find_all(exp.Table):
            name = tbl.name.lower()
            if name and name not in known_tables:
                return f"Unknown table '{tbl.name}'. Only use tables from the schema."
            alias = (tbl.alias or tbl.name).lower()
            alias_to_table[alias] = name

        # Column grounding: qualified table.column refs must exist on their table.
        if column_map:
            normalized_map = {t.lower(): {c.lower() for c in cols} for t, cols in column_map.items()}
            for col_node in parsed.find_all(exp.Column):
                tbl_ref = col_node.table
                if not tbl_ref:
                    continue
                tbl_name = alias_to_table.get(tbl_ref.lower(), tbl_ref.lower())
                col_name = col_node.name.lower()
                if tbl_name in normalized_map and col_name not in normalized_map[tbl_name]:
                    available = ", ".join(sorted(normalized_map[tbl_name]))
                    return (
                        f"Column '{col_node.name}' does not exist on table '{tbl_name}'. "
                        f"Available columns: {available}."
                    )

        # All checks passed.
        return ""
|
|
    def _enforce_limit(self, sql: str, limit: int) -> str:
        """Inject or cap LIMIT using sqlglot AST manipulation."""
        parsed = sqlglot.parse_one(sql)
        existing = parsed.find(exp.Limit)
        if existing:
            current = int(existing.expression.this)
            if current > limit:
                return parsed.limit(limit).sql()
        else:
            return parsed.limit(limit).sql()
        return parsed.sql()
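    # Illustrative behaviour with limit=100: "SELECT * FROM t" gains "LIMIT 100",
    # "... LIMIT 1000" is capped to "LIMIT 100", and "... LIMIT 10" is left as-is.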
|
|
    def _run_sql(self, engine: Any, sql: str) -> list[dict]:
        # Synchronous by design: called via asyncio.to_thread with a sync engine.
        with engine.connect() as conn:
            result = conn.execute(text(sql))
            return [dict(row) for row in result.mappings()]
|
|
db_executor = DbExecutor()
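# Typical call site (hypothetical wiring; `results` comes from RAG retrieval and
# `db` is an active AsyncSession):
#
#     query_results = await db_executor.execute(results, user_id, db, question)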
|
|