Spaces:
Sleeping
Sleeping
"""Query profiling utilities.

A watchdog-wrapped DuckDB execute plus a median-of-3 warm timer.

* :func:`execute_once_timed` runs a statement exactly once, enforcing a
  hard ``timeout_s`` wall-clock budget. It is the single entry point used
  by the env for agent-provided SQL so the documented query timeout
  cannot be bypassed. An optional ``max_rows`` caps result-set
  materialization — the fetch is aborted as soon as more than
  ``max_rows`` rows are observed, so a pathological ``SELECT *`` cannot
  drive the server OOM before the caller's size check runs.
* :func:`execute_once_with_columns` performs the same single timed
  execution but returns a :class:`TimedResult` that also carries the
  column names and the ``truncated`` flag.
* :func:`execute_hash_timed` executes a statement once and hashes its full
  result incrementally via ``fetchmany`` so correctness checks do not have
  to materialize the full row set in Python memory.
* :func:`median_of_3_warm_ms` performs one untimed warm-up then three
  timed runs and returns the median milliseconds. Used by scenario
  materialization to publish a stable baseline runtime.

Every helper runs its statement on a watchdog thread. When a single run
exceeds ``timeout_s`` the connection is interrupted and
:class:`TimeoutError` is raised. If the worker thread is still alive
``INTERRUPT_GRACE_S`` seconds after the interrupt,
:class:`QueryWatchdogEscalationError` is raised instead. Exceptions raised
by the statement itself (e.g. ``duckdb.Error``) propagate unchanged to the
caller.
"""
| from __future__ import annotations | |
| import contextlib | |
| import os | |
| import threading | |
| import time | |
| from collections.abc import Callable, Iterator | |
| from dataclasses import dataclass | |
| from typing import TYPE_CHECKING, Any, cast | |
| from engine.verifier import canonical_row_hash | |
| from utilities.logger import get_module_logger | |
| if TYPE_CHECKING: | |
| import duckdb | |
# Default wall-clock budget, in seconds, for a single statement run.
DEFAULT_TIMEOUT_S: float = 2.0
# Seconds the watchdog waits for the worker thread to exit after calling
# ``conn.interrupt()`` before treating the run as an escalation.
INTERRUPT_GRACE_S: float = 0.25
# Watchdog escalations (leaked worker threads) are logged at ERROR until the
# cumulative count exceeds this value, and at CRITICAL afterwards. Override via
# the SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS environment variable; it is parsed
# with ``int()`` at import time, so a non-integer value makes the import fail.
MAX_LEAKED_WATCHDOG_THREADS: int = int(os.environ.get("SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS", "3"))
_LOG = get_module_logger(__name__)
# Rows requested per ``fetchmany`` call when draining a cursor.
_FETCH_CHUNK_ROWS = 1024
# Module-level counter — incremented each time a watchdog worker thread
# survives interrupt (i.e. a genuine escalation, not a normal timeout). It is
# never decremented. Updates are guarded by _watchdog_leak_lock. Callers can
# read it via get_watchdog_leak_count().
_watchdog_leak_lock: threading.Lock = threading.Lock()
_watchdog_leaked_count: int = 0
def get_watchdog_leak_count() -> int:
    """Return the cumulative number of watchdog worker threads that survived interrupt.

    A non-zero value means at least one DuckDB worker thread was still
    running when the watchdog gave up on it. The counter is never
    decremented, so a thread counted here may have finished since.
    Production monitoring should alert when this exceeds
    :data:`MAX_LEAKED_WATCHDOG_THREADS`.

    The value is read under ``_watchdog_leak_lock``, the same lock that
    guards its increments, so callers always see a consistent snapshot.
    """
    with _watchdog_leak_lock:
        return _watchdog_leaked_count
class QueryWatchdogEscalationError(RuntimeError):
    """Raised when a timed-out DuckDB worker thread survives ``conn.interrupt()``.

    The worker thread was still alive after the interrupt grace period and
    may still be using the connection, so that connection is no longer safe
    to reuse.
    """
@dataclass
class TimedResult:
    """Output of :func:`execute_once_with_columns`.

    ``columns`` preserves DuckDB's cursor ``description`` order so callers
    can emit a :class:`models.RunQueryResult` without re-executing the
    query just to recover column names.

    ``truncated`` is ``True`` when the caller supplied a ``max_rows`` cap
    and the query produced strictly more rows than that cap. In that
    case ``rows`` contains exactly ``max_rows + 1`` entries (the
    one-over read that proves overflow). Callers that care about size
    limits must branch on ``truncated`` rather than re-checking
    ``len(rows)`` against their cap.
    """

    # Column names, in cursor ``description`` order.
    columns: list[str]
    # Fetched rows (possibly capped at ``max_rows + 1``).
    rows: list[tuple[Any, ...]]
    # Wall-clock milliseconds spent executing the statement and fetching rows.
    elapsed_ms: float
    # True only when a ``max_rows`` cap was exceeded.
    truncated: bool = False
def _fetch_capped(
    cursor: duckdb.DuckDBPyConnection,
    max_rows: int,
    chunk_rows: int | None = None,
) -> tuple[list[tuple[Any, ...]], bool]:
    """Drain at most ``max_rows + 1`` rows from ``cursor`` via ``fetchmany``.

    Args:
        cursor: Result of ``conn.execute``; only its ``fetchmany`` method is
            used.
        max_rows: Non-negative row cap.
        chunk_rows: Rows requested per ``fetchmany`` call. ``None`` (the
            default) uses the module-level ``_FETCH_CHUNK_ROWS``.

    Returns:
        ``(rows, truncated)``. When ``truncated`` is ``True``, ``rows`` holds
        ``max_rows + 1`` entries and the cursor may still have unread rows.
        Reading stops on the first row past the cap, so the caller can
        signal overflow without materialising the rest of a potentially
        enormous result set.

    Raises:
        ValueError: If ``max_rows`` is negative or ``chunk_rows`` is not
            positive. Without the ``max_rows`` check, a negative cap would
            return ``([], True)`` and claim truncation of an empty read.
    """
    if max_rows < 0:
        raise ValueError(f"max_rows must be >= 0, got {max_rows}")
    chunk = _FETCH_CHUNK_ROWS if chunk_rows is None else chunk_rows
    if chunk <= 0:
        raise ValueError(f"chunk_rows must be > 0, got {chunk}")

    rows: list[tuple[Any, ...]] = []
    # Read one row past the cap: that extra row is what makes overflow detectable.
    budget = max_rows + 1
    while budget > 0:
        # Never request more than the remaining budget, so at most
        # max_rows + 1 rows are ever materialised, even for huge results.
        batch = cursor.fetchmany(min(chunk, budget))
        if not batch:
            return rows, False
        rows.extend(batch)
        budget -= len(batch)
    return rows, len(rows) > max_rows
def _iter_cursor_rows(
    cursor: duckdb.DuckDBPyConnection,
) -> Iterator[tuple[Any, ...]]:
    """Yield every remaining row of ``cursor``, ``_FETCH_CHUNK_ROWS`` at a time.

    Iteration ends at the first empty ``fetchmany`` batch.
    """
    while chunk := cursor.fetchmany(_FETCH_CHUNK_ROWS):
        yield from chunk
| def _run_worker_with_watchdog[T]( | |
| conn: duckdb.DuckDBPyConnection, | |
| sql: str, | |
| timeout_s: float, | |
| worker: Callable[[], T], | |
| ) -> T: | |
| result_holder: dict[str, object] = {} | |
| def runner() -> None: | |
| try: | |
| result_holder["result"] = worker() | |
| except BaseException as exc: # Must forward all failures from the worker thread. | |
| result_holder["error"] = exc | |
| thread = threading.Thread(target=runner, daemon=True) | |
| thread.start() | |
| thread.join(timeout_s) | |
| if thread.is_alive(): | |
| # DuckDB's interrupt API is connection-scoped and thread-safe; | |
| # we ask the query to unwind and then wait *unconditionally* | |
| # for the worker to exit before surfacing the timeout to the | |
| # caller. If we returned while the thread were still alive, it | |
| # would retain access to ``conn`` and its result could race | |
| # future queries on the same connection — a previously | |
| # observed source of flaky post-timeout behaviour. In practice | |
| # DuckDB's interrupt releases the worker within a handful of | |
| # milliseconds; if the engine ever fails to honour interrupt | |
| # the process will hang here, which is the correct failure | |
| # mode for a connection whose state is no longer safe to | |
| # reuse. | |
| with contextlib.suppress(Exception): | |
| conn.interrupt() | |
| thread.join(INTERRUPT_GRACE_S) | |
| if thread.is_alive(): | |
| global _watchdog_leaked_count | |
| with _watchdog_leak_lock: | |
| _watchdog_leaked_count += 1 | |
| leak_count = _watchdog_leaked_count | |
| log_fn = _LOG.critical if leak_count > MAX_LEAKED_WATCHDOG_THREADS else _LOG.error | |
| log_fn( | |
| "query watchdog failed to stop worker after %.3fs timeout + %.3fs grace" | |
| " (cumulative leaked threads: %d)", | |
| timeout_s, | |
| INTERRUPT_GRACE_S, | |
| leak_count, | |
| ) | |
| raise QueryWatchdogEscalationError( | |
| f"query exceeded {timeout_s}s and worker did not stop after interrupt: {sql[:120]!r}" | |
| ) | |
| raise TimeoutError(f"query exceeded {timeout_s}s: {sql[:120]!r}") | |
| if "error" in result_holder: | |
| error = result_holder["error"] | |
| assert isinstance(error, BaseException) | |
| raise error | |
| return cast(T, result_holder["result"]) | |
def _run_with_watchdog(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    timeout_s: float,
    max_rows: int | None,
) -> TimedResult:
    """Execute ``sql`` once under the watchdog and package a :class:`TimedResult`.

    The timer covers both ``conn.execute`` and the row fetch. With
    ``max_rows=None`` every row is fetched via ``fetchall``. Otherwise at
    most ``max_rows + 1`` rows are read by :func:`_fetch_capped`, and the
    ``truncated`` flag reports whether the cap was exceeded.
    """

    def _timed_fetch() -> TimedResult:
        t0 = time.perf_counter_ns()
        cur = conn.execute(sql)
        desc = cur.description
        names = [] if not desc else [col[0] for col in desc]
        if max_rows is not None:
            fetched, overflow = _fetch_capped(cur, max_rows)
        else:
            fetched, overflow = cur.fetchall(), False
        duration_ms = (time.perf_counter_ns() - t0) / 1_000_000.0
        return TimedResult(
            columns=names,
            rows=fetched,
            elapsed_ms=duration_ms,
            truncated=overflow,
        )

    outcome = _run_worker_with_watchdog(conn, sql, timeout_s, _timed_fetch)
    assert isinstance(outcome, TimedResult)
    return outcome
def execute_once_timed(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
    max_rows: int | None = None,
) -> tuple[list[tuple[Any, ...]], float]:
    """Run ``sql`` once under the watchdog and return ``(rows, elapsed_ms)``.

    A thin wrapper over the same watchdog path used by
    :func:`execute_once_with_columns`, for callers that need neither column
    names nor the truncation flag. Because the ``truncated`` flag is
    dropped here, an overflowing capped read shows up only as
    ``len(rows) > max_rows``: ``rows`` then holds ``max_rows + 1`` entries.
    """
    timed = _run_with_watchdog(conn, sql, timeout_s, max_rows)
    rows, elapsed_ms = timed.rows, timed.elapsed_ms
    return rows, elapsed_ms
def execute_once_with_columns(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
    max_rows: int | None = None,
) -> TimedResult:
    """Run ``sql`` once under the watchdog and return a :class:`TimedResult`.

    The result carries the column names (cursor ``description`` order), the
    fetched rows, the elapsed milliseconds and the ``truncated`` flag.

    When ``max_rows`` is supplied, the fetch aborts at the first row past
    the cap and ``TimedResult.truncated`` is set. In that case the elapsed
    milliseconds reflect the partial scan, not the query's would-be
    completion time. A truncated read is a *hard error* in agent-facing
    code paths, not a performance measurement.
    """
    timed_result = _run_with_watchdog(conn, sql, timeout_s, max_rows)
    return timed_result
def execute_hash_timed(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
) -> tuple[str, float]:
    """Run ``sql`` once under the watchdog; return ``(result_hash, elapsed_ms)``.

    Unlike :func:`execute_once_timed`, rows are streamed from the cursor via
    ``fetchmany`` and hashed incrementally by ``canonical_row_hash``. Callers
    can therefore compare a large final result to ground truth without
    materialising the full row set in Python memory. There is no row cap.
    The elapsed time spans execution plus hashing.
    """

    def _hash_rows() -> tuple[str, float]:
        t0 = time.perf_counter_ns()
        cur = conn.execute(sql)
        digest = canonical_row_hash(_iter_cursor_rows(cur))
        return digest, (time.perf_counter_ns() - t0) / 1_000_000.0

    digest, elapsed_ms = _run_worker_with_watchdog(conn, sql, timeout_s, _hash_rows)
    assert isinstance(digest, str)
    assert isinstance(elapsed_ms, float)
    return digest, elapsed_ms
def median_of_3_warm_ms(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
) -> float:
    """Run ``sql`` once untimed as a warm-up, then three timed runs; return the median ms.

    Every run fetches all rows and gets its own ``timeout_s`` budget.
    Timeouts and statement errors from any run propagate to the caller.
    """
    # Warm-up run: its timing is discarded.
    _run_with_watchdog(conn, sql, timeout_s, None)
    samples = sorted(
        _run_with_watchdog(conn, sql, timeout_s, None).elapsed_ms
        for _ in range(3)
    )
    # The middle of three sorted samples is the median.
    return samples[1]
# Public API: tunable constants, the escalation exception, the result type,
# the leak-count accessor and the timed-execution helpers. Underscore-prefixed
# helpers are intentionally not exported.
__all__ = [
    "DEFAULT_TIMEOUT_S",
    "INTERRUPT_GRACE_S",
    "MAX_LEAKED_WATCHDOG_THREADS",
    "QueryWatchdogEscalationError",
    "TimedResult",
    "execute_hash_timed",
    "execute_once_timed",
    "execute_once_with_columns",
    "get_watchdog_leak_count",
    "median_of_3_warm_ms",
]