Rifqi Hafizuddin
[KM-564] fix source, now shows name instead of id. added diff retrieval vs catalog
96598f8 | """DbExecutor — runs a compiled IR against a user's external SQL database. | |
| Pipeline: | |
| IR → SqlCompiler.compile() → CompiledSql(sql, params) | |
| ↓ | |
| sqlglot guard (defense-in-depth: SELECT-only, no DML / DDL) | |
| ↓ | |
| resolve creds (source.location_ref → dbclient://{client_id} → DatabaseClient | |
| row → Fernet decrypt) | |
| ↓ | |
| asyncio.to_thread(_run_sync) | |
| └ db_pipeline_service.engine_scope(db_type, creds) | |
| └ session-level: default_transaction_read_only + statement_timeout=30s | |
| (postgres / supabase only) | |
| └ engine.execute(text(sql), params) | |
| ↓ | |
| QueryResult (always returned — errors populate `.error`, never raised) | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import time | |
| from typing import Any | |
| import sqlglot | |
| import sqlglot.expressions as exp | |
| from sqlalchemy import text | |
| from ...catalog.models import Catalog, Source | |
| from ...database_client.database_client_service import database_client_service | |
| from ...db.postgres.connection import AsyncSessionLocal | |
| from ...middlewares.logging import get_logger | |
| from ...pipeline.db_pipeline import db_pipeline_service | |
| from ...utils.db_credential_encryption import decrypt_credentials_dict | |
| from ..compiler.sql import CompiledSql, SqlCompiler | |
| from ..ir.models import QueryIR | |
| from .base import BaseExecutor, QueryResult | |
logger = get_logger("db_executor")

# Wall-clock budget for a single query, in seconds. Used for the asyncio
# wait in DbExecutor.run and, on postgres-like backends, converted to
# milliseconds for the session's statement_timeout.
_QUERY_TIMEOUT_SECONDS: int = 30

# Maximum number of rows handed back to callers, applied even when the
# compiled SQL already carries its own LIMIT (belt-and-suspenders).
_ROW_HARD_CAP: int = 10000

# Scheme prefix that Source.location_ref must carry for this executor.
_DBCLIENT_PREFIX = "dbclient://"

# db_type values that receive postgres-specific session settings.
_POSTGRES_LIKE = frozenset(("postgres", "supabase"))
class DbExecutor(BaseExecutor):
    """Executes compiled SQL on the user's registered DB.

    Constructed once per query with the user's catalog. The catalog is the
    source of truth for identifiers; the executor never touches the user's
    DB metadata at execution time.
    """

    def __init__(self, catalog: Catalog) -> None:
        self._catalog = catalog
        self._compiler = SqlCompiler(catalog)

    async def run(self, ir: QueryIR) -> QueryResult:
        """Compile *ir*, execute it on the source's database, wrap the outcome.

        Never raises for ordinary failures: any ``Exception`` (unknown source,
        wrong source type, guard rejection, missing/inactive/foreign
        DatabaseClient, driver error, timeout) is logged and reported through
        ``QueryResult.error``, which is always a non-empty string on failure.
        ``table_name`` / ``source_name`` are filled in on both paths as far as
        they were resolved before any failure.
        """
        started = time.perf_counter()
        # Pre-initialised so the error path can report whatever was resolved
        # before the failure.
        table_name = ""
        source_name = ""
        try:
            source = self._find_source(ir.source_id)
            source_name = source.name
            table_name = next(
                (t.name for t in source.tables if t.table_id == ir.table_id), ""
            )
            if source.source_type != "schema":
                raise ValueError(
                    f"DbExecutor cannot run on source_type={source.source_type!r}; "
                    "expected 'schema'"
                )

            compiled = self._compiler.compile(ir)
            # Validate the SQL before touching credentials or the network.
            self._sqlglot_guard(compiled.sql)
            db_type, creds = await self._load_credentials(source.location_ref)

            try:
                columns, rows = await asyncio.wait_for(
                    asyncio.to_thread(self._run_sync, db_type, creds, compiled),
                    timeout=_QUERY_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError as e:
                # str(asyncio.TimeoutError()) is "", which would leave
                # QueryResult.error empty (falsy) for a timed-out query.
                # NOTE: wait_for cannot stop the worker thread. On postgres-like
                # backends statement_timeout ends the query server-side; other
                # backends keep running until the driver returns.
                raise TimeoutError(
                    f"query exceeded {_QUERY_TIMEOUT_SECONDS}s timeout"
                ) from e

            # _run_sync fetches at most _ROW_HARD_CAP + 1 rows; the extra row
            # exists only to signal truncation.
            truncated = len(rows) > _ROW_HARD_CAP
            capped = rows[:_ROW_HARD_CAP]
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            logger.info(
                "db query complete",
                source_id=ir.source_id,
                rows=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                columns=columns,
                rows=capped,
                row_count=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )
        except Exception as e:
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            # Some exceptions stringify to ""; fall back to the type name so
            # callers testing `result.error` never mistake a failure for success.
            error = str(e) or type(e).__name__
            logger.error(
                "db executor failed",
                source_id=ir.source_id,
                error=error,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                elapsed_ms=elapsed_ms,
                error=error,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _find_source(self, source_id: str) -> Source:
        """Return the catalog source whose ``source_id`` matches.

        Raises:
            ValueError: if no source in the catalog has that id.
        """
        for s in self._catalog.sources:
            if s.source_id == source_id:
                return s
        raise ValueError(f"source_id {source_id!r} not in catalog")

    async def _load_credentials(self, location_ref: str) -> tuple[str, dict]:
        """Resolve ``dbclient://{client_id}`` to ``(db_type, decrypted_creds)``.

        The ownership check runs before decryption, so credentials of a client
        owned by another user are never decrypted.

        Raises:
            ValueError: malformed location_ref, or client missing / not active.
            PermissionError: client.user_id differs from catalog.user_id.
        """
        client_id = self._parse_client_id(location_ref)
        client = await self._fetch_client(client_id)
        if client.user_id != self._catalog.user_id:
            raise PermissionError(
                f"DatabaseClient {client_id!r} owner mismatch "
                "(client.user_id != catalog.user_id)"
            )
        return client.db_type, decrypt_credentials_dict(client.credentials)

    @staticmethod
    def _parse_client_id(location_ref: str) -> str:
        """Extract ``client_id`` from a ``dbclient://{client_id}`` reference.

        Raises:
            ValueError: if the ref is empty/None, lacks the prefix, or has
                nothing after it.
        """
        # `not location_ref` also turns a None ref into a clean ValueError
        # instead of an AttributeError from .startswith.
        if not location_ref or not location_ref.startswith(_DBCLIENT_PREFIX):
            raise ValueError(
                f"DbExecutor expects 'dbclient://...' location_ref, got {location_ref!r}"
            )
        client_id = location_ref[len(_DBCLIENT_PREFIX):]
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")
        return client_id

    @staticmethod
    async def _fetch_client(client_id: str) -> Any:
        """Load the DatabaseClient row for *client_id* and require it be active.

        Raises:
            ValueError: if the row does not exist or ``status != "active"``.
        """
        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
            if client is None:
                raise ValueError(f"DatabaseClient {client_id!r} not found")
            if client.status != "active":
                raise ValueError(
                    f"DatabaseClient {client_id!r} is not active "
                    f"(status={client.status!r})"
                )
            return client

    @staticmethod
    def _sqlglot_guard(sql: str) -> None:
        """Defense-in-depth: ensure the compiled SQL is exactly one SELECT.

        The compiler is already deterministic and only constructs SELECTs from
        validated IR, but this guard catches any future bug that could leak
        DML/DDL through.

        Raises:
            ValueError: if the SQL fails to parse, does not contain exactly one
                statement, is not a SELECT, or contains a forbidden node.
        """
        try:
            # sqlglot.parse rather than parse_one: parse_one may return only
            # the first statement, letting "SELECT 1; DROP TABLE t" through.
            statements = [
                stmt for stmt in sqlglot.parse(sql, read="postgres") if stmt is not None
            ]
        except sqlglot.errors.SqlglotError as e:  # ParseError and TokenError
            raise ValueError(f"compiled SQL failed to parse: {e}") from e
        # NOTE(review): always parsed with the postgres dialect regardless of
        # db_type — confirm the compiler's output is postgres-parseable for
        # every supported backend.
        if len(statements) != 1:
            raise ValueError(
                "compiled SQL must contain exactly one statement "
                f"(got {len(statements)})"
            )
        parsed = statements[0]
        if not isinstance(parsed, exp.Select):
            raise ValueError(
                f"compiled SQL is not a SELECT (got {type(parsed).__name__})"
            )
        # exp.Into catches `SELECT ... INTO new_table`, which parses as a
        # Select but creates a table on postgres.
        forbidden = (
            exp.Insert,
            exp.Update,
            exp.Delete,
            exp.Drop,
            exp.Alter,
            exp.Create,
            exp.Into,
        )
        node = parsed.find(*forbidden)
        if node is not None:
            raise ValueError(
                f"compiled SQL contains forbidden DML/DDL: {type(node).__name__}"
            )

    @staticmethod
    def _run_sync(
        db_type: str, creds: dict, compiled: CompiledSql
    ) -> tuple[list[str], list[dict]]:
        """Execute *compiled* synchronously (called via asyncio.to_thread).

        Returns ``(columns, rows)`` with at most ``_ROW_HARD_CAP + 1`` rows.
        The one extra row lets the caller detect truncation without
        materialising the entire result set as dicts.
        """
        with db_pipeline_service.engine_scope(db_type, creds) as engine:
            with engine.connect() as conn:
                if db_type in _POSTGRES_LIKE:
                    # In the default non-autocommit mode the first execute()
                    # opens a transaction, and default_transaction_read_only
                    # only applies to transactions started *after* it is set.
                    # So also mark the current transaction read-only. It must
                    # precede any query in the transaction. The session default
                    # is kept for autocommit configurations.
                    conn.execute(text("SET TRANSACTION READ ONLY"))
                    conn.execute(text("SET default_transaction_read_only = on"))
                    # Per-statement timeout in milliseconds.
                    conn.execute(
                        text(f"SET statement_timeout = {_QUERY_TIMEOUT_SECONDS * 1000}")
                    )
                result = conn.execute(text(compiled.sql), compiled.params)
                columns = list(result.keys())
                rows = [
                    dict(row)
                    for row in result.mappings().fetchmany(_ROW_HARD_CAP + 1)
                ]
                return columns, rows