sql-drift-env / engine /profiler.py
visheshrathi's picture
Upload folder using huggingface_hub
5850885 verified
"""Query profiling utilities.
A watchdog-wrapped DuckDB execute plus a median-of-3 warm timer.
* :func:`execute_once_timed` runs a statement exactly once, enforcing a
hard ``timeout_s`` wall-clock budget. It is the single entry point used
by the env for agent-provided SQL so the documented query timeout
cannot be bypassed. An optional ``max_rows`` caps result-set
materialization — the fetch is aborted as soon as more than
``max_rows`` rows are observed, so a pathological ``SELECT *`` cannot
drive the server OOM before the caller's size check runs.
* :func:`execute_hash_timed` executes a statement once and hashes its full
result incrementally via ``fetchmany`` so correctness checks do not have
to materialize the full row set in Python memory.
* :func:`median_of_3_warm_ms` performs one untimed warm-up then three
timed runs and returns the median milliseconds. Used by scenario
materialization to publish a stable baseline runtime.
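All of these helpers raise :class:`TimeoutError` when a single run exceeds
the budget, or :class:`QueryWatchdogEscalationError` when the worker thread
is still alive after the interrupt grace period; ``duckdb.Error``
propagates unchanged to the caller.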
"""
from __future__ import annotations
import contextlib
import os
import threading
import time
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from engine.verifier import canonical_row_hash
from utilities.logger import get_module_logger
if TYPE_CHECKING:
import duckdb
# Default wall-clock budget (seconds) used when a caller does not pass
# ``timeout_s`` to one of the public helpers.
DEFAULT_TIMEOUT_S: float = 2.0
# Extra time (seconds) the watchdog waits for the worker thread to exit
# after calling ``conn.interrupt()`` before treating it as an escalation.
INTERRUPT_GRACE_S: float = 0.25
# Maximum number of watchdog escalations (leaked threads) tolerated before
# logging at CRITICAL. Override via SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS.
# NOTE: the override is parsed with ``int()`` at import time, so a
# non-numeric value raises ``ValueError`` when this module is imported.
MAX_LEAKED_WATCHDOG_THREADS: int = int(os.environ.get("SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS", "3"))
_LOG = get_module_logger(__name__)
# Batch size handed to ``cursor.fetchmany`` when draining a result set.
_FETCH_CHUNK_ROWS = 1024
# Module-level counter — incremented each time a worker thread is still
# alive after interrupt plus the grace period (i.e. a genuine escalation,
# not a normal timeout). Increments are serialised by _watchdog_leak_lock.
# Callers can read this via get_watchdog_leak_count().
_watchdog_leak_lock: threading.Lock = threading.Lock()
_watchdog_leaked_count: int = 0
def get_watchdog_leak_count() -> int:
    """Report how many watchdog escalations have happened so far.

    The value is the running total of worker threads that were still alive
    after the interrupt grace period. Anything above zero means a DuckDB
    worker was not stopped cleanly; production monitoring should alert once
    the total exceeds :data:`MAX_LEAKED_WATCHDOG_THREADS`.
    """
    leaked_so_far = _watchdog_leaked_count
    return leaked_so_far
class QueryWatchdogEscalationError(RuntimeError):
    """Raised when a DuckDB worker thread outlives the interrupt grace period.

    The worker may still be using the connection at that point, so the
    connection must be treated as unsafe for any further queries.
    """
@dataclass(frozen=True)
class TimedResult:
    """Immutable record of a single timed query execution.

    Returned by :func:`execute_once_with_columns` so callers can build a
    :class:`models.RunQueryResult` without running the query a second time
    just to learn the column names.

    When a ``max_rows`` cap was supplied and the query yielded strictly
    more rows than the cap, ``truncated`` is ``True`` and ``rows`` holds
    exactly ``max_rows + 1`` entries — the extra row is the proof of
    overflow. Size-limit checks should therefore branch on ``truncated``
    instead of comparing ``len(rows)`` with their own cap.
    """

    # Column names, in the order given by the cursor's ``description``.
    columns: list[str]
    # Fetched rows; may include the one-over row when ``truncated`` is set.
    rows: list[tuple[Any, ...]]
    # Wall-clock duration of execute + fetch, in milliseconds.
    elapsed_ms: float
    # True only when a ``max_rows`` cap was exceeded.
    truncated: bool = False
def _fetch_capped(
    cursor: duckdb.DuckDBPyConnection,
    max_rows: int,
) -> tuple[list[tuple[Any, ...]], bool]:
    """Read no more than ``max_rows + 1`` rows from ``cursor`` in batches.

    Returns a ``(rows, truncated)`` pair. ``truncated`` is ``True`` when
    the read stopped because the one-over row was reached, in which case
    the cursor may still hold unread rows; the remainder of a potentially
    huge result set is never materialised.
    """
    # Reading one row beyond the cap is what makes overflow observable.
    # Batches are bounded by _FETCH_CHUNK_ROWS so modest results are not
    # over-fetched by orders of magnitude.
    collected: list[tuple[Any, ...]] = []
    remaining = max_rows + 1
    while remaining > 0:
        chunk = cursor.fetchmany(min(_FETCH_CHUNK_ROWS, remaining))
        if not chunk:
            # Cursor ran dry before the one-over row: the result fits.
            return collected, False
        collected += chunk
        remaining -= len(chunk)
    overflowed = len(collected) > max_rows
    return collected, overflowed
def _iter_cursor_rows(
    cursor: duckdb.DuckDBPyConnection,
) -> Iterator[tuple[Any, ...]]:
    """Lazily yield every row of ``cursor``, pulling one batch at a time.

    Rows are fetched with ``fetchmany`` in batches of ``_FETCH_CHUNK_ROWS``;
    iteration ends at the first empty batch.
    """
    while chunk := cursor.fetchmany(_FETCH_CHUNK_ROWS):
        yield from chunk
def _run_worker_with_watchdog[T](
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    timeout_s: float,
    worker: Callable[[], T],
) -> T:
    """Run ``worker`` on a daemon thread under a wall-clock budget.

    Args:
        conn: Connection the worker is using; only touched here to call
            ``conn.interrupt()`` when the budget is exceeded.
        sql: Statement text, used solely to build error messages (first
            120 characters).
        timeout_s: Seconds to wait for the worker before interrupting.
        worker: Zero-argument callable executed on the worker thread.

    Returns:
        Whatever ``worker`` returned.

    Raises:
        TimeoutError: The worker exceeded ``timeout_s`` but exited within
            ``INTERRUPT_GRACE_S`` of the interrupt. Any result or error the
            worker recorded in that window is discarded.
        QueryWatchdogEscalationError: The worker was still alive after the
            interrupt and the grace period.
        BaseException: Any exception raised by ``worker`` when it finished
            inside the budget is re-raised in the calling thread.
    """
    # Slot shared with the worker thread: it stores either "result" or "error".
    result_holder: dict[str, object] = {}

    def runner() -> None:
        try:
            result_holder["result"] = worker()
        except BaseException as exc:  # Must forward all failures from the worker thread.
            result_holder["error"] = exc

    thread = threading.Thread(target=runner, daemon=True)
    thread.start()
    thread.join(timeout_s)
    if thread.is_alive():
        # The worker is still running past the budget. Ask DuckDB to
        # cancel the in-flight query on this connection (failures of
        # the interrupt call itself are deliberately ignored — it is
        # best effort), then give the worker up to INTERRUPT_GRACE_S
        # to unwind.
        #
        # If the worker exits within the grace period the caller gets
        # a plain TimeoutError. If it is still alive afterwards the
        # wait is NOT extended: the leak is counted, logged (ERROR, or
        # CRITICAL once the cumulative count exceeds
        # MAX_LEAKED_WATCHDOG_THREADS) and surfaced as
        # QueryWatchdogEscalationError, because the surviving thread
        # still holds ``conn`` and could race later queries on the
        # same connection.
        with contextlib.suppress(Exception):
            conn.interrupt()
        thread.join(INTERRUPT_GRACE_S)
        if thread.is_alive():
            global _watchdog_leaked_count
            # Increment and snapshot under the lock so the logged count
            # matches this escalation even with concurrent escalations.
            with _watchdog_leak_lock:
                _watchdog_leaked_count += 1
                leak_count = _watchdog_leaked_count
            log_fn = _LOG.critical if leak_count > MAX_LEAKED_WATCHDOG_THREADS else _LOG.error
            log_fn(
                "query watchdog failed to stop worker after %.3fs timeout + %.3fs grace"
                " (cumulative leaked threads: %d)",
                timeout_s,
                INTERRUPT_GRACE_S,
                leak_count,
            )
            raise QueryWatchdogEscalationError(
                f"query exceeded {timeout_s}s and worker did not stop after interrupt: {sql[:120]!r}"
            )
        raise TimeoutError(f"query exceeded {timeout_s}s: {sql[:120]!r}")
    # Worker finished inside the budget: surface its failure or its result.
    if "error" in result_holder:
        error = result_holder["error"]
        assert isinstance(error, BaseException)
        raise error
    return cast(T, result_holder["result"])
def _run_with_watchdog(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    timeout_s: float,
    max_rows: int | None,
) -> TimedResult:
    """Execute ``sql`` once under the watchdog and package the outcome.

    The timed region spans ``conn.execute`` plus the row fetch. When
    ``max_rows`` is ``None`` every row is fetched; otherwise the fetch is
    capped via :func:`_fetch_capped` and the overflow flag is recorded in
    ``TimedResult.truncated``.
    """

    def worker() -> TimedResult:
        began_ns = time.perf_counter_ns()
        cursor = conn.execute(sql)
        description = cursor.description
        column_names = [entry[0] for entry in description] if description else []
        if max_rows is not None:
            fetched, overflowed = _fetch_capped(cursor, max_rows)
        else:
            fetched, overflowed = cursor.fetchall(), False
        duration_ms = (time.perf_counter_ns() - began_ns) / 1_000_000.0
        return TimedResult(
            columns=column_names,
            rows=fetched,
            elapsed_ms=duration_ms,
            truncated=overflowed,
        )

    outcome = _run_worker_with_watchdog(conn, sql, timeout_s, worker)
    assert isinstance(outcome, TimedResult)
    return outcome
def execute_once_timed(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
    max_rows: int | None = None,
) -> tuple[list[tuple[Any, ...]], float]:
    """Run ``sql`` exactly once and return ``(rows, elapsed_ms)``.

    Convenience wrapper for callers that need neither the column names
    nor the truncation flag; both are dropped from the underlying
    :class:`TimedResult`.
    """
    timed = _run_with_watchdog(conn, sql, timeout_s, max_rows)
    return (timed.rows, timed.elapsed_ms)
def execute_once_with_columns(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
    max_rows: int | None = None,
) -> TimedResult:
    """Run ``sql`` exactly once and return the full :class:`TimedResult`.

    If ``max_rows`` is given, fetching stops at the first row beyond the
    cap and ``TimedResult.truncated`` is set. In that situation the
    reported elapsed milliseconds cover only the partial scan rather than
    the time the query would have needed to finish — agent-facing code
    paths treat a truncated read as a *hard error*, not as a performance
    measurement.
    """
    timed = _run_with_watchdog(conn, sql, timeout_s, max_rows)
    return timed
def execute_hash_timed(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
) -> tuple[str, float]:
    """Run ``sql`` exactly once and return ``(result_hash, elapsed_ms)``.

    In contrast to :func:`execute_once_timed`, rows are streamed out of
    the cursor with ``fetchmany`` and fed to the hash incrementally, so a
    large final result can be checked against ground truth without
    holding the whole row set in Python memory.
    """

    def worker() -> tuple[str, float]:
        began_ns = time.perf_counter_ns()
        row_stream = _iter_cursor_rows(conn.execute(sql))
        digest = canonical_row_hash(row_stream)
        # Timing stops only after the hash has consumed the whole stream.
        return digest, (time.perf_counter_ns() - began_ns) / 1_000_000.0

    result_hash, elapsed_ms = _run_worker_with_watchdog(conn, sql, timeout_s, worker)
    assert isinstance(result_hash, str)
    assert isinstance(elapsed_ms, float)
    return result_hash, elapsed_ms
def median_of_3_warm_ms(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
) -> float:
    """Return the median runtime, in milliseconds, of three warm runs.

    One untimed execution warms the cache first; each of the four runs is
    individually subject to ``timeout_s``.
    """
    # Warm-up run: its timing is discarded on purpose.
    _run_with_watchdog(conn, sql, timeout_s, None)
    samples = sorted(
        _run_with_watchdog(conn, sql, timeout_s, None).elapsed_ms for _ in range(3)
    )
    # Middle element of three sorted samples is the median.
    return samples[1]
# Public API of this module. Underscore-prefixed helpers are intentionally
# left out.
__all__ = [
    "DEFAULT_TIMEOUT_S",
    "INTERRUPT_GRACE_S",
    "MAX_LEAKED_WATCHDOG_THREADS",
    "QueryWatchdogEscalationError",
    "TimedResult",
    "execute_hash_timed",
    "execute_once_timed",
    "execute_once_with_columns",
    "get_watchdog_leak_count",
    "median_of_3_warm_ms",
]