"""DbExecutor — runs a compiled IR against a user's external SQL database.

Pipeline::

    IR → SqlCompiler.compile() → CompiledSql(sql, params)
      ↓ sqlglot guard (defense-in-depth: exactly one SELECT statement,
        no DML / DDL anywhere inside it)
      ↓ resolve creds (source.location_ref → dbclient://{client_id}
        → DatabaseClient row → Fernet decrypt)
      ↓ asyncio.wait_for(asyncio.to_thread(_run_sync), 30s)
          ├ postgres-like db_types: pooled engine from user_engine_cache;
          │   read-only + statement_timeout are applied once per physical
          │   connection by a connect event in UserEngineCache
          └ other db_types: db_pipeline_service.engine_scope(db_type, creds),
              a per-call engine with no read-only / timeout session settings
          → conn.execute(text(sql), params)
      ↓ QueryResult (always returned — errors populate `.error`, never raised)
"""

from __future__ import annotations

import asyncio
import time
from typing import Any

import sqlglot
import sqlglot.expressions as exp
from sqlalchemy import text

from ...catalog.models import Catalog, Source
from ...database_client.database_client_service import database_client_service
from ...database_client.engine import user_engine_cache
from ...db.postgres.connection import AsyncSessionLocal
from ...middlewares.logging import get_logger
from ...pipeline.db_pipeline import db_pipeline_service
from ...utils.db_credential_encryption import decrypt_credentials_dict
from ..compiler.sql import CompiledSql, SqlCompiler
from ..ir.models import QueryIR
from .base import BaseExecutor, QueryResult

logger = get_logger("db_executor")

# Client-side ceiling for one query, enforced with asyncio.wait_for in `run`.
_QUERY_TIMEOUT_SECONDS = 30
_DBCLIENT_PREFIX = "dbclient://"

# sqlglot node types that must not appear anywhere in the compiled statement
# tree. Nested DML covers data-modifying CTEs (`WITH x AS (DELETE ...) SELECT`);
# `Into` covers `SELECT ... INTO new_table`, which creates a table in Postgres;
# `Command` is sqlglot's fallback node for statements it cannot model.
_FORBIDDEN_NODES = (
    exp.Insert,
    exp.Update,
    exp.Delete,
    exp.Merge,
    exp.Drop,
    exp.Alter,
    exp.Create,
    exp.TruncateTable,
    exp.Into,
    exp.Command,
)


class DbExecutor(BaseExecutor):
    """Executes compiled SQL on the user's registered DB.

    Constructed once per query with the user's catalog. The catalog is the
    source of truth for identifiers; the executor never touches the user's
    DB metadata at execution time.
    """

    def __init__(self, catalog: Catalog) -> None:
        self._catalog = catalog
        self._compiler = SqlCompiler(catalog)

    async def run(self, ir: QueryIR) -> QueryResult:
        """Compile, guard and execute ``ir`` against its source's database.

        Never raises. Any failure is logged and returned as a QueryResult whose
        ``error`` is a non-empty message. Failures include:

        - an unknown source or a non-'schema' source type;
        - a guard rejection;
        - an inactive or foreign client;
        - a driver error;
        - a timeout.
        """
        started = time.perf_counter()
        table_name = ""
        source_name = ""
        try:
            source = self._find_source(ir.source_id)
            source_name = source.name
            table_name = next(
                (t.name for t in source.tables if t.table_id == ir.table_id), ""
            )
            if source.source_type != "schema":
                raise ValueError(
                    f"DbExecutor cannot run on source_type={source.source_type!r}; "
                    "expected 'schema'"
                )

            compiled = self._compiler.compile(ir)
            self._sqlglot_guard(compiled.sql)

            client_id = self._parse_client_id(source.location_ref)
            client = await self._fetch_client(client_id)
            # The client must belong to the same user as the catalog.
            if client.user_id != self._catalog.user_id:
                raise PermissionError(
                    f"DatabaseClient {client_id!r} owner mismatch "
                    "(client.user_id != catalog.user_id)"
                )
            creds = decrypt_credentials_dict(client.credentials)

            # wait_for stops *awaiting* the worker thread on timeout. The thread
            # itself cannot be cancelled and runs until the driver returns.
            columns, rows = await asyncio.wait_for(
                asyncio.to_thread(
                    self._run_sync, client_id, client.db_type, creds, compiled
                ),
                timeout=_QUERY_TIMEOUT_SECONDS,
            )

            # The compiler bounded the SQL to `row_cap` (+1 when the IR was
            # unbounded). More than row_cap rows means the result was truncated.
            truncated = len(rows) > compiled.row_cap
            capped = rows[: compiled.row_cap]
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            logger.info(
                "db query complete",
                source_id=ir.source_id,
                rows=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                columns=columns,
                rows=capped,
                row_count=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )
        except Exception as e:
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            error = self._describe_error(e)
            logger.error(
                "db executor failed",
                source_id=ir.source_id,
                error=error,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                elapsed_ms=elapsed_ms,
                error=error,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return a non-empty, human-readable message for ``exc``.

        Some exceptions stringify to "" — notably the TimeoutError raised by
        asyncio.wait_for. That would leave QueryResult.error as an empty
        (falsy) string even though the query failed.
        """
        message = str(exc)
        if message:
            return message
        if isinstance(exc, (asyncio.TimeoutError, TimeoutError)):
            return f"query timed out after {_QUERY_TIMEOUT_SECONDS}s"
        return type(exc).__name__

    def _find_source(self, source_id: str) -> Source:
        """Return the catalog source with ``source_id``.

        Raises ValueError if the catalog has no such source.
        """
        for s in self._catalog.sources:
            if s.source_id == source_id:
                return s
        raise ValueError(f"source_id {source_id!r} not in catalog")

    @staticmethod
    def _parse_client_id(location_ref: str) -> str:
        """Extract the client id from a ``dbclient://{client_id}`` reference.

        Raises ValueError if the prefix is missing or the id is empty.
        """
        if not location_ref.startswith(_DBCLIENT_PREFIX):
            raise ValueError(
                f"DbExecutor expects 'dbclient://...' location_ref, got {location_ref!r}"
            )
        client_id = location_ref[len(_DBCLIENT_PREFIX):]
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")
        return client_id

    @staticmethod
    async def _fetch_client(client_id: str) -> Any:
        """Load the DatabaseClient row in a short-lived session.

        Raises ValueError if the client is missing or its status is not 'active'.
        """
        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
        if client is None:
            raise ValueError(f"DatabaseClient {client_id!r} not found")
        if client.status != "active":
            raise ValueError(
                f"DatabaseClient {client_id!r} is not active "
                f"(status={client.status!r})"
            )
        return client

    @staticmethod
    def _sqlglot_guard(sql: str) -> None:
        """Defense-in-depth: ensure the compiled SQL is a single SELECT statement.

        The compiler is already deterministic and only constructs SELECTs from
        validated IR. This guard exists to catch any future bug that could leak
        DML/DDL through.

        Every statement is parsed, not just the first. ``sqlglot.parse_one``
        yields only the first statement of a ``;``-separated script, so
        ``SELECT 1; DROP TABLE t`` would slip past a parse_one-based check.

        Raises:
            ValueError: in any of these cases:
                - the SQL fails to tokenize or parse;
                - it is not exactly one statement;
                - that statement is not a SELECT;
                - it contains a forbidden node anywhere in its tree.
        """
        try:
            statements = [
                stmt
                for stmt in sqlglot.parse(sql, read="postgres")
                if stmt is not None
            ]
        except sqlglot.errors.SqlglotError as e:
            # SqlglotError is the base of both ParseError and TokenError.
            raise ValueError(f"compiled SQL failed to parse: {e}") from e
        if len(statements) != 1:
            raise ValueError(
                f"compiled SQL must be exactly one statement (got {len(statements)})"
            )
        parsed = statements[0]
        if not isinstance(parsed, exp.Select):
            raise ValueError(
                f"compiled SQL is not a SELECT (got {type(parsed).__name__})"
            )
        offending = next(parsed.find_all(*_FORBIDDEN_NODES), None)
        if offending is not None:
            raise ValueError(
                f"compiled SQL contains forbidden DML/DDL: {type(offending).__name__}"
            )

    @staticmethod
    def _run_sync(
        client_id: str, db_type: str, creds: dict, compiled: CompiledSql
    ) -> tuple[list[str], list[dict]]:
        """Blocking execution; ``run`` calls it through asyncio.to_thread.

        Returns ``(column names, rows as dicts)``.
        """
        engine = user_engine_cache.get_engine(client_id, db_type, creds)
        if engine is not None:
            # Pooled, reused engine (postgres-like). Read-only + statement_timeout
            # are set once per physical connection (connect event in UserEngineCache).
            # That means no per-query SET round-trips and no dispose: the
            # connection returns to the pool warm for the next query.
            with engine.connect() as conn:
                return DbExecutor._fetch_all(conn, compiled)

        # Legacy per-call path for non-postgres db_types (connect once, dispose).
        # These never set read-only/timeout before, so behavior is unchanged.
        with db_pipeline_service.engine_scope(db_type, creds) as eng:
            with eng.connect() as conn:
                return DbExecutor._fetch_all(conn, compiled)

    @staticmethod
    def _fetch_all(conn: Any, compiled: CompiledSql) -> tuple[list[str], list[dict]]:
        """Execute ``compiled`` on ``conn``; return column names and all rows as dicts."""
        result = conn.execute(text(compiled.sql), compiled.params)
        return list(result.keys()), [dict(row) for row in result.mappings()]

    # ------------------------------------------------------------------
    # Speculative pre-connect (DB3)
    # ------------------------------------------------------------------

    @classmethod
    async def prewarm(cls, catalog: Catalog, user_id: str) -> None:
        """Best-effort: warm pooled engines for the catalog's schema sources.

        Called at slow-path entry so the TCP+TLS+auth handshake overlaps the
        ~4s Planner LLM call. By the time `retrieve_data` runs, the connection
        is already established.

        Warming is an optimization, never a requirement. This method never
        raises, and per-source failures are logged at info level and skipped.
        Clients not owned by ``user_id`` are skipped silently.
        """
        for source in catalog.sources:
            if source.source_type != "schema":
                continue
            try:
                client_id = cls._parse_client_id(source.location_ref)
                client = await cls._fetch_client(client_id)
                if client.user_id != user_id:
                    continue
                creds = decrypt_credentials_dict(client.credentials)
                await asyncio.to_thread(cls._warm_sync, client_id, client.db_type, creds)
            except Exception as exc:  # noqa: BLE001 — best-effort warming
                logger.info("prewarm skipped", source_id=source.source_id, error=str(exc))

    @staticmethod
    def _warm_sync(client_id: str, db_type: str, creds: dict) -> None:
        """Open and return one pooled connection.

        Only db_types for which user_engine_cache returns an engine are warmed;
        it is a no-op for the others.
        """
        engine = user_engine_cache.get_engine(client_id, db_type, creds)
        if engine is not None:
            # Open + return a pooled physical connection: forces the handshake and
            # runs the connect-event session SETs, leaving the pool warm.
            with engine.connect():
                pass