sofhiaazzhr's picture
[NOTICKET] Adopt verb-first skill naming
2d6406d
Raw
History Blame
10.1 kB
"""DbExecutor — runs a compiled IR against a user's external SQL database.

Pipeline:
    IR → SqlCompiler.compile() → CompiledSql(sql, params)
    → sqlglot guard (defense-in-depth: SELECT-only, no DML / DDL)
    → resolve creds (catalog.location_ref → dbclient://{client_id}
      → DatabaseClient row → owner check → Fernet decrypt)
    → asyncio.wait_for(asyncio.to_thread(_run_sync), timeout=30s)
        └ postgres-like db_types: pooled engine from user_engine_cache;
          read-only + statement_timeout are applied once per physical
          connection by UserEngineCache's connect event
        └ other db_types: legacy per-call db_pipeline_service.engine_scope(
          db_type, creds); no read-only / statement_timeout session settings
        └ conn.execute(text(sql), params)
    → QueryResult (always returned — errors populate `.error`, never raised)
"""
from __future__ import annotations
import asyncio
import time
from typing import Any
import sqlglot
import sqlglot.expressions as exp
from sqlalchemy import text
from ...catalog.models import Catalog, Source
from ...database_client.database_client_service import database_client_service
from ...database_client.engine import user_engine_cache
from ...db.postgres.connection import AsyncSessionLocal
from ...middlewares.logging import get_logger
from ...pipeline.db_pipeline import db_pipeline_service
from ...utils.db_credential_encryption import decrypt_credentials_dict
from ..compiler.sql import CompiledSql, SqlCompiler
from ..ir.models import QueryIR
from .base import BaseExecutor, QueryResult
logger = get_logger("db_executor")

# Seconds run() waits on the query worker thread before reporting a failure.
_QUERY_TIMEOUT_SECONDS: int = 30

# Scheme that every schema source's ``location_ref`` must start with; the
# remainder of the string is the DatabaseClient id.
_DBCLIENT_PREFIX: str = "dbclient://"
class DbExecutor(BaseExecutor):
    """Executes compiled SQL on the user's registered DB.

    Constructed once per query with the user's catalog. The catalog is the
    source of truth for identifiers; the executor never touches the user's
    DB metadata at execution time.
    """

    def __init__(self, catalog: Catalog) -> None:
        self._catalog = catalog
        self._compiler = SqlCompiler(catalog)

    async def run(self, ir: QueryIR) -> QueryResult:
        """Compile ``ir``, execute it on the source's database, and wrap the rows.

        Steps: resolve the ``schema`` source from the catalog, compile and
        guard the SQL, load the source's ``DatabaseClient`` (its owner must be
        the catalog's user), decrypt its credentials, then run the query in a
        worker thread bounded by ``_QUERY_TIMEOUT_SECONDS``.

        Returns:
            A ``QueryResult``. Any ``Exception`` raised along the way is logged
            and reported through ``QueryResult.error`` (a non-empty string)
            instead of being raised.
        """
        started = time.perf_counter()
        # Pre-seeded so the error path can still report whatever was resolved.
        table_name = ""
        source_name = ""
        try:
            source = self._find_source(ir.source_id)
            source_name = source.name
            table_name = next(
                (t.name for t in source.tables if t.table_id == ir.table_id), ""
            )
            if source.source_type != "schema":
                raise ValueError(
                    f"DbExecutor cannot run on source_type={source.source_type!r}; "
                    "expected 'schema'"
                )

            compiled = self._compiler.compile(ir)
            self._sqlglot_guard(compiled.sql)

            client_id = self._parse_client_id(source.location_ref)
            client = await self._fetch_client(client_id)
            if client.user_id != self._catalog.user_id:
                raise PermissionError(
                    f"DatabaseClient {client_id!r} owner mismatch "
                    f"(client.user_id != catalog.user_id)"
                )
            creds = decrypt_credentials_dict(client.credentials)

            # NOTE: wait_for stops *awaiting* after the timeout, but a worker
            # thread that is already running cannot be interrupted. On the
            # pooled path the per-connection statement_timeout (see _run_sync)
            # is what bounds the query server-side; the legacy path has none.
            columns, rows = await asyncio.wait_for(
                asyncio.to_thread(
                    self._run_sync, client_id, client.db_type, creds, compiled
                ),
                timeout=_QUERY_TIMEOUT_SECONDS,
            )

            # The compiler bounded the SQL to `row_cap` (+1 when the IR was
            # unbounded). More than row_cap rows means the result was truncated.
            truncated = len(rows) > compiled.row_cap
            capped = rows[:compiled.row_cap]
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            logger.info(
                "db query complete",
                source_id=ir.source_id,
                rows=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                columns=columns,
                rows=capped,
                row_count=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )
        except Exception as e:
            # str(e) is "" for some exceptions (notably the TimeoutError from
            # asyncio.wait_for); an empty `.error` would make a failed query
            # look like a successful empty result.
            error = self._describe_error(e)
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            logger.error(
                "db executor failed",
                source_id=ir.source_id,
                error=error,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                elapsed_ms=elapsed_ms,
                error=error,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return a non-empty, human-readable message for ``exc``.

        Uses ``str(exc)`` when it has content. A message-less TimeoutError
        gets an explicit timeout message. Anything else falls back to the
        exception's class name.
        """
        message = str(exc)
        if message:
            return message
        if isinstance(exc, asyncio.TimeoutError):
            # asyncio.wait_for raises TimeoutError without a message.
            return f"query timed out after {_QUERY_TIMEOUT_SECONDS}s"
        return type(exc).__name__

    def _find_source(self, source_id: str) -> Source:
        """Return the catalog source with ``source_id``.

        Raises:
            ValueError: if the catalog has no such source.
        """
        for s in self._catalog.sources:
            if s.source_id == source_id:
                return s
        raise ValueError(f"source_id {source_id!r} not in catalog")

    @staticmethod
    def _parse_client_id(location_ref: str) -> str:
        """Extract the DatabaseClient id from a ``dbclient://{client_id}`` ref.

        Raises:
            ValueError: if the ref is not a string with the ``dbclient://``
                prefix, or the id after the prefix is empty.
        """
        if not isinstance(location_ref, str) or not location_ref.startswith(
            _DBCLIENT_PREFIX
        ):
            raise ValueError(
                f"DbExecutor expects 'dbclient://...' location_ref, got {location_ref!r}"
            )
        client_id = location_ref[len(_DBCLIENT_PREFIX):]
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")
        return client_id

    @staticmethod
    async def _fetch_client(client_id: str) -> Any:
        """Load the DatabaseClient row for ``client_id`` in a short-lived session.

        Raises:
            ValueError: if the client does not exist or its status is not
                ``"active"``.
        """
        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
        if client is None:
            raise ValueError(f"DatabaseClient {client_id!r} not found")
        if client.status != "active":
            raise ValueError(
                f"DatabaseClient {client_id!r} is not active "
                f"(status={client.status!r})"
            )
        return client

    @staticmethod
    def _sqlglot_guard(sql: str) -> None:
        """Defense-in-depth: ensure the compiled SQL is exactly one SELECT.

        The compiler is already deterministic and only constructs SELECTs from
        validated IR, but this guard catches any future bug that could leak
        DML/DDL through.

        Raises:
            ValueError: if the SQL fails to parse, contains zero or several
                statements, is not a SELECT, or contains a forbidden node.
        """
        try:
            # Parse *every* statement: parse_one() only returns the first, so
            # "SELECT ...; DROP TABLE ..." would otherwise pass the check.
            statements = [
                stmt for stmt in sqlglot.parse(sql, read="postgres") if stmt is not None
            ]
        except sqlglot.errors.ParseError as e:
            raise ValueError(f"compiled SQL failed to parse: {e}") from e
        if len(statements) != 1:
            raise ValueError(
                f"compiled SQL must contain exactly one statement "
                f"(got {len(statements)})"
            )
        parsed = statements[0]
        if not isinstance(parsed, exp.Select):
            raise ValueError(
                f"compiled SQL is not a SELECT (got {type(parsed).__name__})"
            )
        # Nested DML (e.g. inside a CTE) and `SELECT ... INTO` (which creates
        # a table in Postgres) can still hide inside a top-level SELECT.
        forbidden = (
            exp.Insert,
            exp.Update,
            exp.Delete,
            exp.Drop,
            exp.Alter,
            exp.Create,
            exp.Into,
        )
        for node in parsed.find_all(*forbidden):
            raise ValueError(
                f"compiled SQL contains forbidden DML/DDL: {type(node).__name__}"
            )

    @staticmethod
    def _execute(conn: Any, compiled: CompiledSql) -> tuple[list[str], list[dict]]:
        """Run ``compiled`` on an open connection; return (column names, rows as dicts)."""
        result = conn.execute(text(compiled.sql), compiled.params)
        return list(result.keys()), [dict(row) for row in result.mappings()]

    @staticmethod
    def _run_sync(
        client_id: str, db_type: str, creds: dict, compiled: CompiledSql
    ) -> tuple[list[str], list[dict]]:
        """Blocking query execution; meant to run in a worker thread."""
        engine = user_engine_cache.get_engine(client_id, db_type, creds)
        if engine is not None:
            # Pooled, reused engine (postgres-like). Read-only + statement_timeout
            # are set once per physical connection (connect event in UserEngineCache),
            # so no per-query SET round-trips and no dispose — the connection returns
            # to the pool warm for the next query.
            with engine.connect() as conn:
                return DbExecutor._execute(conn, compiled)
        # Legacy per-call path for non-postgres db_types (connect once, dispose).
        # These never set read-only/timeout before, so behavior is unchanged.
        with db_pipeline_service.engine_scope(db_type, creds) as eng:
            with eng.connect() as conn:
                return DbExecutor._execute(conn, compiled)

    # ------------------------------------------------------------------
    # Speculative pre-connect (DB3)
    # ------------------------------------------------------------------

    @classmethod
    async def prewarm(cls, catalog: Catalog, user_id: str) -> None:
        """Best-effort: warm pooled engines for the catalog's schema sources.

        Called at slow-path entry so the TCP+TLS+auth handshake overlaps the ~4s
        Planner LLM call — by the time `retrieve_data` runs, the connection is
        already established. Warming is an optimization, never a requirement, so
        this never raises and per-source failures are swallowed (and logged).

        Args:
            catalog: Catalog whose ``schema`` sources should be warmed.
            user_id: Only clients owned by this user are warmed.
        """
        # Several sources may point at the same DatabaseClient; warm each once.
        attempted: set[str] = set()
        for source in catalog.sources:
            if source.source_type != "schema":
                continue
            try:
                client_id = cls._parse_client_id(source.location_ref)
                if client_id in attempted:
                    continue
                attempted.add(client_id)
                client = await cls._fetch_client(client_id)
                if client.user_id != user_id:
                    continue
                creds = decrypt_credentials_dict(client.credentials)
                await asyncio.to_thread(cls._warm_sync, client_id, client.db_type, creds)
            except Exception as exc:  # noqa: BLE001 — best-effort warming
                logger.info(
                    "prewarm skipped",
                    source_id=source.source_id,
                    error=str(exc) or type(exc).__name__,
                )

    @staticmethod
    def _warm_sync(client_id: str, db_type: str, creds: dict) -> None:
        """Open and release one pooled connection; no-op for non-pooled db_types."""
        engine = user_engine_cache.get_engine(client_id, db_type, creds)
        if engine is not None:
            # Open + return a pooled physical connection: forces the handshake and
            # runs the connect-event session SETs, leaving the pool warm.
            with engine.connect():
                pass