ishaq101's picture
feat/Planner Agent (#2)
81e5fe7
Raw
History Blame
7.04 kB
"""UserEngineCache — pooled, reused SQLAlchemy engines for users' external DBs.
The query path (`DbExecutor`) previously built a fresh engine and tore it down on
EVERY query (`db_pipeline_service.engine_scope`), paying a full TCP+TLS+auth
handshake per call (~6-8s measured, dominating slow-path latency). That helper's
connect-once-then-dispose semantics are correct for the *ingestion* pipeline
(infrequent, one connection per run) but wrong for the query path (frequent,
latency-sensitive, repeated to the same DB).
This module caches one pooled engine per external DB so connections stay warm
across queries. Scope: **postgres / supabase only** (the measured case and the
`schema` source type). Other db_types fall back to the legacy per-call path in
`DbExecutor`, so nothing regresses.
Safety / multi-tenancy:
- Key = client_id + a hash of the decrypted credentials, so a credential rotation
produces a new key (the stale engine idle-evicts) — a cached engine never serves
rotated creds.
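- Read-only + statement_timeout are applied once per physical connection, via
ordinary `SET` commands issued from a pool `connect` event (not libpq startup
`options`, which some transaction poolers reject), so they cost zero per-query
round-trips. This is best-effort hardening: the authoritative read-only
guarantee is the SELECT-only compiler plus the DML guard.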
- The caller still re-fetches the DatabaseClient row every query and re-checks
ownership + `active` status — caching the engine never bypasses authorization.
- Bounded LRU + idle TTL cap memory / file descriptors / connections held on the
user's DB. `invalidate(client_id)` disposes eagerly on client update/delete.
"""
from __future__ import annotations
import hashlib
import json
import threading
import time
from collections import OrderedDict
from sqlalchemy import URL, create_engine, event
from sqlalchemy.engine import Engine
from src.middlewares.logging import get_logger
# Module-level logger shared by everything in this file, created under the name
# "user_engine_cache".
logger = get_logger("user_engine_cache")
# db_type values that are served from the cache.
_POSTGRES_LIKE = frozenset(("postgres", "supabase"))

# Value used for the session's statement_timeout, in milliseconds (30 seconds).
_STATEMENT_TIMEOUT_MS = 30 * 1000

# Per-engine pool settings. Kept intentionally tiny: each engine targets one
# user's external DB (which often has a low max_connections) and many engines
# are cached at once. pool_pre_ping discards dead connections, while the recycle
# interval caps connection age so a serverless user DB can still autosuspend
# between bursts of queries.
_POOL_SIZE, _MAX_OVERFLOW = 1, 2
_POOL_RECYCLE_SECONDS = 5 * 60

# Limits on the cache as a whole, across all users.
_MAX_ENGINES = 50
_IDLE_TTL_SECONDS = 10 * 60
def _creds_fingerprint(credentials: dict) -> str:
    """Return a short, deterministic fingerprint of a credentials mapping.

    The mapping is serialized to JSON with sorted keys (values JSON cannot
    encode natively are passed through ``str``), hashed with SHA-256, and
    truncated to the first 16 hexadecimal characters.
    """
    canonical = json.dumps(credentials, sort_keys=True, default=str).encode("utf-8")
    digest = hashlib.sha256(canonical)
    return digest.hexdigest()[:16]
class UserEngineCache:
    """Process-wide cache of pooled engines for users' external Postgres DBs.

    Entries live in an ``OrderedDict`` keyed by ``"<client_id>:<creds fingerprint>"``
    and ordered from least- to most-recently used. The cache is bounded two ways:
    entries idle longer than ``_IDLE_TTL_SECONDS`` are disposed on the next
    lookup, and the least-recently-used entries are disposed once more than
    ``_MAX_ENGINES`` are held.

    Thread-safe: `DbExecutor` runs sync DB work in `asyncio.to_thread` worker
    threads, so concurrent requests can hit this from multiple threads. All
    access to the entry map happens under a single lock.
    """

    def __init__(self) -> None:
        # key -> (engine, last_used_monotonic); oldest use first.
        self._engines: OrderedDict[str, tuple[Engine, float]] = OrderedDict()
        self._lock = threading.Lock()

    def get_engine(self, client_id: str, db_type: str, credentials: dict) -> Engine | None:
        """Return a pooled engine for (client_id, creds), or None if unsupported.

        None means "not a postgres-like DB" — the caller should use its legacy
        per-call path for those (rare, unmeasured) db_types.

        Args:
            client_id: Identifier of the external DB client; first half of the key.
            db_type: Source type; only values in ``_POSTGRES_LIKE`` are cached.
            credentials: Decrypted connection settings. Their fingerprint is the
                second half of the key, so changed credentials get a new engine.

        Returns:
            The cached (or freshly built) engine, or ``None`` for other db_types.
        """
        if db_type not in _POSTGRES_LIKE:
            return None
        key = f"{client_id}:{_creds_fingerprint(credentials)}"
        now = time.monotonic()
        with self._lock:
            # Drop expired entries first so a stale engine is never handed out.
            self._evict_idle(now)
            entry = self._engines.get(key)
            if entry is not None:
                # Hit: refresh the timestamp and mark as most recently used.
                self._engines[key] = (entry[0], now)
                self._engines.move_to_end(key)
                return entry[0]
            engine = self._build_engine(credentials)
            self._engines[key] = (engine, now)
            self._engines.move_to_end(key)
            self._evict_overflow()
            logger.info("user engine created", client_id=client_id, cached=len(self._engines))
            return engine

    def invalidate(self, client_id: str) -> None:
        """Dispose + drop every cached engine for a client (creds rotated/deleted).

        Only keys whose client part equals ``client_id`` exactly are affected.
        """
        with self._lock:
            # Keys are "<client_id>:<fingerprint>" and the fingerprint is pure hex,
            # so splitting at the LAST colon recovers the exact client_id. A plain
            # prefix test would also match a different client whose id happens to
            # start with "<client_id>:".
            stale = [k for k in self._engines if k.rpartition(":")[0] == client_id]
            for k in stale:
                engine, _ = self._engines.pop(k)
                engine.dispose()
        if stale:
            logger.info("user engine invalidated", client_id=client_id, disposed=len(stale))

    # ------------------------------------------------------------------
    @staticmethod
    def _build_engine(credentials: dict) -> Engine:
        """Create a small pooled psycopg2 engine for one external Postgres DB.

        ``credentials`` must provide ``username``, ``password``, ``host``,
        ``port`` and ``database``; a truthy ``ssl_mode`` is forwarded as the
        ``sslmode`` URL query parameter.
        """
        # Mirrors db_pipeline_service.connect()'s postgres URL shape, plus a real pool.
        query = {"sslmode": credentials["ssl_mode"]} if credentials.get("ssl_mode") else {}
        url = URL.create(
            drivername="postgresql+psycopg2",
            username=credentials["username"],
            password=credentials["password"],
            host=credentials["host"],
            port=credentials["port"],
            database=credentials["database"],
            query=query,
        )
        engine = create_engine(
            url,
            pool_size=_POOL_SIZE,
            max_overflow=_MAX_OVERFLOW,
            pool_recycle=_POOL_RECYCLE_SECONDS,
            pool_pre_ping=True,
        )

        # Apply read-only + statement_timeout once per PHYSICAL connection via a
        # connect event (not per query, so the pooling latency win stays). These are
        # ordinary SET commands, NOT libpq startup `options` — Neon's transaction
        # pooler rejects `default_transaction_read_only` as a startup parameter but
        # accepts it as a SET.
        @event.listens_for(engine, "connect")
        def _init_session(dbapi_conn, _record):  # noqa: ANN001
            UserEngineCache._apply_session_guards(dbapi_conn)

        return engine

    @staticmethod
    def _apply_session_guards(dbapi_conn) -> None:  # noqa: ANN001
        """Best-effort: SET statement_timeout and read-only on a new DBAPI connection.

        The SETs are issued with DBAPI autocommit switched on. Without that,
        psycopg2 opens an implicit transaction for the first statement, and
        PostgreSQL reverts session-level SETs when that transaction is rolled
        back — which the pool does when the connection is returned. A failing SET
        would additionally leave the brand-new connection in an aborted
        transaction. The previous autocommit value is restored afterwards.

        Failures are logged and swallowed: the authoritative read-only guarantee
        is the compiler (SELECT-only) + the sqlglot DML guard, and
        statement_timeout is backed by the executor's asyncio.wait_for, so a
        failure here must not break the connection.
        """
        try:
            previous_autocommit = dbapi_conn.autocommit
            dbapi_conn.autocommit = True
            try:
                cur = dbapi_conn.cursor()
                try:
                    cur.execute(f"SET statement_timeout = {_STATEMENT_TIMEOUT_MS}")
                    cur.execute("SET default_transaction_read_only = on")
                finally:
                    cur.close()
            finally:
                dbapi_conn.autocommit = previous_autocommit
        except Exception as exc:  # noqa: BLE001 — best-effort session hardening
            logger.warning("session init SET failed", error=str(exc))

    def _evict_idle(self, now: float) -> None:
        """Dispose every entry idle for more than ``_IDLE_TTL_SECONDS``.

        Must be called with ``self._lock`` held.
        """
        stale = [k for k, (_, ts) in self._engines.items() if now - ts > _IDLE_TTL_SECONDS]
        for k in stale:
            engine, _ = self._engines.pop(k)
            engine.dispose()

    def _evict_overflow(self) -> None:
        """Dispose least-recently-used entries until at most ``_MAX_ENGINES`` remain.

        Must be called with ``self._lock`` held.
        """
        while len(self._engines) > _MAX_ENGINES:
            _, (engine, _) = self._engines.popitem(last=False)  # LRU = oldest end
            engine.dispose()
# Process-wide singleton consumed by DbExecutor. Created at import time, so every
# importer of this module shares the same instance and the same cached engines.
user_engine_cache = UserEngineCache()