"""Query profiling utilities.
A watchdog-wrapped DuckDB execute plus a median-of-3 warm timer.
* :func:`execute_once_timed` runs a statement exactly once, enforcing a
hard ``timeout_s`` wall-clock budget. It is the single entry point used
by the env for agent-provided SQL so the documented query timeout
cannot be bypassed. An optional ``max_rows`` caps result-set
materialization — the fetch is aborted as soon as more than
``max_rows`` rows are observed, so a pathological ``SELECT *`` cannot
drive the server OOM before the caller's size check runs.
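* :func:`execute_once_with_columns` performs the same single watchdog-wrapped
execution but returns a :class:`TimedResult` carrying column names, rows,
elapsed milliseconds and the ``truncated`` overflow flag.
* :func:`execute_hash_timed` executes a statement once and hashes its full
result incrementally via ``fetchmany`` so correctness checks do not have
to materialize the full row set in Python memory.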
* :func:`median_of_3_warm_ms` performs one untimed warm-up then three
timed runs and returns the median milliseconds. Used by scenario
materialization to publish a stable baseline runtime.
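All of these helpers raise :class:`TimeoutError` when a single run exceeds
the budget, and :class:`QueryWatchdogEscalationError` when the worker thread
is still alive after the post-timeout interrupt and grace period;
``duckdb.Error`` propagates unchanged to the caller.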
"""
from __future__ import annotations
import contextlib
import os
import threading
import time
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from engine.verifier import canonical_row_hash
from utilities.logger import get_module_logger
if TYPE_CHECKING:
import duckdb
# Default wall-clock budget, in seconds, for a single statement run.
DEFAULT_TIMEOUT_S: float = 2.0
# Extra seconds the worker is given to exit after ``conn.interrupt()``
# before the watchdog declares an escalation.
INTERRUPT_GRACE_S: float = 0.25
# Number of watchdog escalations (leaked worker threads) tolerated before
# leaks are logged at CRITICAL rather than ERROR. Override via the
# SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS environment variable.
MAX_LEAKED_WATCHDOG_THREADS: int = int(
    os.environ.get("SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS", "3")
)

# Rows requested per ``fetchmany`` call when draining a cursor.
_FETCH_CHUNK_ROWS = 1024

_LOG = get_module_logger(__name__)

# Cumulative count of watchdog escalations: worker threads still alive after
# interrupt + grace (a normal timeout does not count). Updates happen under
# ``_watchdog_leak_lock``; read it via get_watchdog_leak_count().
_watchdog_leak_lock: threading.Lock = threading.Lock()
_watchdog_leaked_count: int = 0
def get_watchdog_leak_count() -> int:
    """Report how many watchdog escalations have happened in this process.

    Each escalation is a worker thread that was still alive after
    ``conn.interrupt()`` plus the grace period, i.e. a thread that may still
    be running in the background. A non-zero value therefore means at least
    one worker was not stopped cleanly. Monitoring should alert once this
    exceeds :data:`MAX_LEAKED_WATCHDOG_THREADS`.
    """
    leaked = _watchdog_leaked_count
    return leaked
class QueryWatchdogEscalationError(RuntimeError):
    """The DuckDB worker was still alive after interrupt + grace period.

    The connection that worker was using is no longer safe to reuse.
    """
@dataclass(frozen=True)
class TimedResult:
    """Immutable outcome of one watchdog-wrapped statement execution.

    Returned by :func:`execute_once_with_columns`. Column names are kept in
    the cursor ``description`` order, so callers can build a
    :class:`models.RunQueryResult` without running the query a second time
    just to learn its column names.

    When a ``max_rows`` cap was supplied and the query yielded strictly more
    rows than the cap, ``truncated`` is ``True`` and ``rows`` holds exactly
    ``max_rows + 1`` entries; that extra row is the proof of overflow.
    Size-limit checks should test ``truncated`` instead of comparing
    ``len(rows)`` with the cap themselves.
    """

    # Column names, in cursor ``description`` order (empty if none).
    columns: list[str]
    # Fetched rows; at most ``max_rows + 1`` entries when a cap was supplied.
    rows: list[tuple[Any, ...]]
    # Time spent on execute + fetch, in milliseconds.
    elapsed_ms: float
    # True only when a cap was supplied and the result exceeded it.
    truncated: bool = False
def _fetch_capped(
    cursor: duckdb.DuckDBPyConnection,
    max_rows: int,
) -> tuple[list[tuple[Any, ...]], bool]:
    """Read at most ``max_rows + 1`` rows from ``cursor`` in chunks.

    Returns ``(rows, truncated)``. ``truncated`` is ``True`` when the row
    one past ``max_rows`` was read. In that case the cursor may still hold
    unread rows, which are deliberately left unfetched so that an enormous
    result set is never fully materialised.
    """
    collected: list[tuple[Any, ...]] = []
    # The "+ 1" lets us observe the first row beyond the cap, which is what
    # makes overflow detectable.
    remaining = max_rows + 1
    # Requests are capped at _FETCH_CHUNK_ROWS so modest results are not
    # over-fetched, and at ``remaining`` so we never read past the one-over row.
    while remaining > 0 and (
        chunk := cursor.fetchmany(min(_FETCH_CHUNK_ROWS, remaining))
    ):
        collected.extend(chunk)
        remaining -= len(chunk)
    # If an empty chunk ended the loop, fewer than max_rows + 1 rows were
    # read, so this comparison is False in that case.
    return collected, len(collected) > max_rows
def _iter_cursor_rows(
    cursor: duckdb.DuckDBPyConnection,
) -> Iterator[tuple[Any, ...]]:
    """Lazily yield every remaining row of ``cursor``.

    Rows are pulled ``_FETCH_CHUNK_ROWS`` at a time. Iteration stops at the
    first empty ``fetchmany`` result.
    """
    while chunk := cursor.fetchmany(_FETCH_CHUNK_ROWS):
        yield from chunk
def _run_worker_with_watchdog[T](
conn: duckdb.DuckDBPyConnection,
sql: str,
timeout_s: float,
worker: Callable[[], T],
) -> T:
result_holder: dict[str, object] = {}
def runner() -> None:
try:
result_holder["result"] = worker()
except BaseException as exc: # Must forward all failures from the worker thread.
result_holder["error"] = exc
thread = threading.Thread(target=runner, daemon=True)
thread.start()
thread.join(timeout_s)
if thread.is_alive():
# DuckDB's interrupt API is connection-scoped and thread-safe;
# we ask the query to unwind and then wait *unconditionally*
# for the worker to exit before surfacing the timeout to the
# caller. If we returned while the thread were still alive, it
# would retain access to ``conn`` and its result could race
# future queries on the same connection β a previously
# observed source of flaky post-timeout behaviour. In practice
# DuckDB's interrupt releases the worker within a handful of
# milliseconds; if the engine ever fails to honour interrupt
# the process will hang here, which is the correct failure
# mode for a connection whose state is no longer safe to
# reuse.
with contextlib.suppress(Exception):
conn.interrupt()
thread.join(INTERRUPT_GRACE_S)
if thread.is_alive():
global _watchdog_leaked_count
with _watchdog_leak_lock:
_watchdog_leaked_count += 1
leak_count = _watchdog_leaked_count
log_fn = _LOG.critical if leak_count > MAX_LEAKED_WATCHDOG_THREADS else _LOG.error
log_fn(
"query watchdog failed to stop worker after %.3fs timeout + %.3fs grace"
" (cumulative leaked threads: %d)",
timeout_s,
INTERRUPT_GRACE_S,
leak_count,
)
raise QueryWatchdogEscalationError(
f"query exceeded {timeout_s}s and worker did not stop after interrupt: {sql[:120]!r}"
)
raise TimeoutError(f"query exceeded {timeout_s}s: {sql[:120]!r}")
if "error" in result_holder:
error = result_holder["error"]
assert isinstance(error, BaseException)
raise error
return cast(T, result_holder["result"])
def _run_with_watchdog(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    timeout_s: float,
    max_rows: int | None,
) -> TimedResult:
    """Execute ``sql`` once under the watchdog and package a :class:`TimedResult`.

    The timed interval covers ``conn.execute`` plus fetching. With
    ``max_rows=None`` every row is fetched via ``fetchall``. Otherwise at most
    ``max_rows + 1`` rows are read and ``truncated`` reports the overflow.
    """

    def timed_execute() -> TimedResult:
        t0 = time.perf_counter_ns()
        cur = conn.execute(sql)
        description = cur.description
        names = [col[0] for col in description] if description else []
        if max_rows is not None:
            fetched, overflow = _fetch_capped(cur, max_rows)
        else:
            fetched, overflow = cur.fetchall(), False
        elapsed_ms = (time.perf_counter_ns() - t0) / 1_000_000.0
        return TimedResult(
            columns=names,
            rows=fetched,
            elapsed_ms=elapsed_ms,
            truncated=overflow,
        )

    outcome = _run_worker_with_watchdog(conn, sql, timeout_s, timed_execute)
    assert isinstance(outcome, TimedResult)
    return outcome
def execute_once_timed(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
    max_rows: int | None = None,
) -> tuple[list[tuple[Any, ...]], float]:
    """Run ``sql`` once under the watchdog and return ``(rows, elapsed_ms)``.

    This is the slim variant for callers that need neither column names nor
    the explicit truncation flag. When ``max_rows`` is given and the result
    overflows, ``rows`` holds ``max_rows + 1`` entries, so overflow shows up
    as ``len(rows) > max_rows``.
    """
    timed = _run_with_watchdog(conn, sql, timeout_s, max_rows)
    rows, elapsed_ms = timed.rows, timed.elapsed_ms
    return rows, elapsed_ms
def execute_once_with_columns(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
    max_rows: int | None = None,
) -> TimedResult:
    """Run ``sql`` once under the watchdog and return a full :class:`TimedResult`.

    With ``max_rows`` set, fetching stops at the first row beyond the cap
    and ``TimedResult.truncated`` is ``True``. ``elapsed_ms`` then measures
    only the partial read, not how long the complete query would have taken.
    Agent-facing callers are expected to treat a truncated read as a hard
    error rather than as a performance measurement.
    """
    timed_result = _run_with_watchdog(conn, sql, timeout_s, max_rows)
    return timed_result
def execute_hash_timed(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
) -> tuple[str, float]:
    """Run ``sql`` once under the watchdog and return ``(result_hash, elapsed_ms)``.

    Rows are streamed from the cursor in ``fetchmany`` chunks straight into
    ``canonical_row_hash``, so a large result can be compared against ground
    truth without holding the whole row set in Python memory (unlike
    :func:`execute_once_timed`). The timing covers execute, fetch and hashing.
    """

    def hash_worker() -> tuple[str, float]:
        t0 = time.perf_counter_ns()
        cur = conn.execute(sql)
        digest = canonical_row_hash(_iter_cursor_rows(cur))
        return digest, (time.perf_counter_ns() - t0) / 1_000_000.0

    digest, elapsed_ms = _run_worker_with_watchdog(conn, sql, timeout_s, hash_worker)
    assert isinstance(digest, str)
    assert isinstance(elapsed_ms, float)
    return digest, elapsed_ms
def median_of_3_warm_ms(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
) -> float:
    """Return the median (ms) of three timed runs taken after one warm-up.

    The warm-up result is discarded. Each run fetches the full result and is
    individually bound by ``timeout_s``, so a call can take roughly up to four
    budgets in total before returning or raising.
    """
    _run_with_watchdog(conn, sql, timeout_s, None)  # warm-up; result discarded
    samples = sorted(
        _run_with_watchdog(conn, sql, timeout_s, None).elapsed_ms for _ in range(3)
    )
    return samples[1]
__all__ = [
    # Configuration constants
    "DEFAULT_TIMEOUT_S",
    "INTERRUPT_GRACE_S",
    "MAX_LEAKED_WATCHDOG_THREADS",
    # Types
    "QueryWatchdogEscalationError",
    "TimedResult",
    # Functions
    "execute_hash_timed",
    "execute_once_timed",
    "execute_once_with_columns",
    "get_watchdog_leak_count",
    "median_of_3_warm_ms",
]