File size: 10,694 Bytes
5850885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
"""Query profiling utilities.

A watchdog-wrapped DuckDB execute plus a median-of-3 warm timer.

* :func:`execute_once_timed` runs a statement exactly once, enforcing a
  hard ``timeout_s`` wall-clock budget (plus a short interrupt grace
  period while the worker unwinds). It is the single entry point used
  by the env for agent-provided SQL so the documented query timeout
  cannot be bypassed. An optional ``max_rows`` caps result-set
  materialization — the fetch is aborted as soon as more than
  ``max_rows`` rows are observed, so a pathological ``SELECT *`` cannot
  drive the server OOM before the caller's size check runs.
* :func:`execute_hash_timed` executes a statement once and hashes its full
  result incrementally via ``fetchmany`` so correctness checks do not have
  to materialize the full row set in Python memory.
* :func:`median_of_3_warm_ms` performs one untimed warm-up then three
  timed runs and returns the median milliseconds. Used by scenario
  materialization to publish a stable baseline runtime.

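All helpers raise :class:`TimeoutError` when a single run exceeds the
budget, or :class:`QueryWatchdogEscalationError` when the worker thread is
still alive after the interrupt grace period; ``duckdb.Error`` raised by
the statement itself propagates unchanged to the caller.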
"""

from __future__ import annotations

import contextlib
import os
import threading
import time
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast

from engine.verifier import canonical_row_hash
from utilities.logger import get_module_logger

if TYPE_CHECKING:
    import duckdb


# Default wall-clock budget, in seconds, for one watchdog-timed execution.
DEFAULT_TIMEOUT_S: float = 2.0
# Extra seconds the watchdog waits after ``conn.interrupt()`` for the worker
# thread to exit before treating it as leaked.
INTERRUPT_GRACE_S: float = 0.25


def _env_int(name: str, default: int) -> int:
    """Return environment variable ``name`` parsed as ``int``, or ``default`` if unset.

    Raises:
        ValueError: the variable is set but is not a valid integer literal.
            The message names the variable, unlike ``int()``'s own error.
            That matters because this runs at import time, where the bare
            ``int()`` error gives no hint which setting is wrong.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"environment variable {name} must be an integer, got {raw!r}") from None


# Maximum number of watchdog escalations (leaked threads) tolerated before
# logging at CRITICAL.  Override via SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS.
MAX_LEAKED_WATCHDOG_THREADS: int = _env_int("SQL_DRIFT_MAX_LEAKED_WATCHDOG_THREADS", 3)

# Upper bound on rows requested per ``fetchmany`` call when draining a cursor.
_FETCH_CHUNK_ROWS = 1024
# Module logger; watchdog escalations are reported through it.
_LOG = get_module_logger(__name__)

# Cumulative number of watchdog escalations for this module: bumped once
# each time a worker thread is still alive after interrupt + grace period
# (an ordinary timeout, where the worker does stop, is not counted).
# Writers update it while holding _watchdog_leak_lock; external callers
# read it via get_watchdog_leak_count().
_watchdog_leaked_count: int = 0
_watchdog_leak_lock: threading.Lock = threading.Lock()


def get_watchdog_leak_count() -> int:
    """Report how many watchdog worker threads have outlived their interrupt.

    The figure accumulates for as long as this module is loaded. Anything
    above zero means at least one DuckDB worker thread could not be stopped
    and may still be running in the background. Production monitoring should
    alert once it exceeds :data:`MAX_LEAKED_WATCHDOG_THREADS`.
    """
    leaked: int = _watchdog_leaked_count
    return leaked


class QueryWatchdogEscalationError(RuntimeError):
    """A timed-out DuckDB worker was still alive after ``conn.interrupt()``.

    The background worker may still be using the connection, so the
    connection must not be reused once this is raised.
    """


@dataclass(frozen=True)
class TimedResult:
    """Immutable result of one watchdog-timed execution.

    Returned by :func:`execute_once_with_columns`. :func:`execute_once_timed`
    unpacks one of these into a ``(rows, elapsed_ms)`` tuple.

    ``columns`` preserves DuckDB's cursor ``description`` order so callers
    can emit a :class:`models.RunQueryResult` without re-executing the
    query just to recover column names.

    ``truncated`` is ``True`` when the caller supplied a ``max_rows`` cap
    and the query produced strictly more rows than that cap. In that case
    ``rows`` contains exactly ``max_rows + 1`` entries (the one-over read
    that proves overflow). Callers that enforce size limits must branch on
    ``truncated`` rather than re-checking ``len(rows)`` against their cap.
    """

    # Column names in cursor ``description`` order; empty when the cursor
    # reports no description.
    columns: list[str]
    # Rows as fetched from the cursor (capped at ``max_rows + 1`` when a cap
    # was applied).
    rows: list[tuple[Any, ...]]
    # Wall-clock time spent executing and fetching, in milliseconds.
    elapsed_ms: float
    # Whether a ``max_rows`` cap was exceeded; ``False`` when no cap applied.
    truncated: bool = False


def _fetch_capped(
    cursor: duckdb.DuckDBPyConnection,
    max_rows: int,
) -> tuple[list[tuple[Any, ...]], bool]:
    """Read rows from ``cursor`` in chunks, stopping after ``max_rows + 1``.

    Returns ``(rows, truncated)``. ``truncated`` is ``True`` only when the
    cap was exceeded. ``rows`` then holds the ``max_rows + 1`` rows read so
    far, and the cursor may still have unread rows, so an oversized result
    is never fully materialised.
    """
    # Requesting one row beyond the cap is what makes overflow detectable.
    # Chunking at _FETCH_CHUNK_ROWS avoids per-row calls without
    # over-fetching by orders of magnitude on modest results.
    remaining = max_rows + 1
    collected: list[tuple[Any, ...]] = []
    while remaining > 0:
        chunk = cursor.fetchmany(min(_FETCH_CHUNK_ROWS, remaining))
        if not chunk:
            # Cursor exhausted before the cap was crossed.
            return collected, False
        collected += chunk
        remaining -= len(chunk)
    return collected, len(collected) > max_rows


def _iter_cursor_rows(
    cursor: duckdb.DuckDBPyConnection,
) -> Iterator[tuple[Any, ...]]:
    """Lazily yield every remaining row of ``cursor``.

    Rows are pulled ``_FETCH_CHUNK_ROWS`` at a time via ``fetchmany``, so
    this generator holds at most one chunk in memory at once.
    """
    while batch := cursor.fetchmany(_FETCH_CHUNK_ROWS):
        yield from batch


def _run_worker_with_watchdog[T](
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    timeout_s: float,
    worker: Callable[[], T],
) -> T:
    result_holder: dict[str, object] = {}

    def runner() -> None:
        try:
            result_holder["result"] = worker()
        except BaseException as exc:  # Must forward all failures from the worker thread.
            result_holder["error"] = exc

    thread = threading.Thread(target=runner, daemon=True)
    thread.start()
    thread.join(timeout_s)
    if thread.is_alive():
        # DuckDB's interrupt API is connection-scoped and thread-safe;
        # we ask the query to unwind and then wait *unconditionally*
        # for the worker to exit before surfacing the timeout to the
        # caller. If we returned while the thread were still alive, it
        # would retain access to ``conn`` and its result could race
        # future queries on the same connection β€” a previously
        # observed source of flaky post-timeout behaviour. In practice
        # DuckDB's interrupt releases the worker within a handful of
        # milliseconds; if the engine ever fails to honour interrupt
        # the process will hang here, which is the correct failure
        # mode for a connection whose state is no longer safe to
        # reuse.
        with contextlib.suppress(Exception):
            conn.interrupt()
        thread.join(INTERRUPT_GRACE_S)
        if thread.is_alive():
            global _watchdog_leaked_count
            with _watchdog_leak_lock:
                _watchdog_leaked_count += 1
                leak_count = _watchdog_leaked_count
            log_fn = _LOG.critical if leak_count > MAX_LEAKED_WATCHDOG_THREADS else _LOG.error
            log_fn(
                "query watchdog failed to stop worker after %.3fs timeout + %.3fs grace"
                " (cumulative leaked threads: %d)",
                timeout_s,
                INTERRUPT_GRACE_S,
                leak_count,
            )
            raise QueryWatchdogEscalationError(
                f"query exceeded {timeout_s}s and worker did not stop after interrupt: {sql[:120]!r}"
            )
        raise TimeoutError(f"query exceeded {timeout_s}s: {sql[:120]!r}")
    if "error" in result_holder:
        error = result_holder["error"]
        assert isinstance(error, BaseException)
        raise error
    return cast(T, result_holder["result"])


def _run_with_watchdog(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    timeout_s: float,
    max_rows: int | None,
) -> TimedResult:
    """Execute ``sql`` once under the watchdog and package a :class:`TimedResult`.

    With ``max_rows=None`` the whole result is fetched with ``fetchall``.
    With an integer cap, at most ``max_rows + 1`` rows are read via
    :func:`_fetch_capped`, and ``truncated`` is set when the cap is
    exceeded. ``elapsed_ms`` covers execute plus fetch and is measured
    inside the worker thread.

    Raises:
        ValueError: ``max_rows`` is negative. This is checked before the
            query is executed.
        TimeoutError, QueryWatchdogEscalationError: budget exceeded; see
            :func:`_run_worker_with_watchdog`.
    """
    if max_rows is not None and max_rows < 0:
        # A negative cap would make _fetch_capped report truncated=True with
        # zero rows for every query, so reject it before running the query.
        raise ValueError(f"max_rows must be >= 0 or None, got {max_rows}")

    def worker() -> TimedResult:
        start = time.perf_counter_ns()
        cursor = conn.execute(sql)
        columns = [d[0] for d in cursor.description] if cursor.description else []
        if max_rows is None:
            rows = cursor.fetchall()
            truncated = False
        else:
            rows, truncated = _fetch_capped(cursor, max_rows)
        elapsed_ns = time.perf_counter_ns() - start
        return TimedResult(
            columns=columns,
            rows=rows,
            elapsed_ms=elapsed_ns / 1_000_000.0,
            truncated=truncated,
        )

    # The watchdog helper is generic over the worker's return type, so this
    # is already typed as TimedResult; no runtime assert is needed.
    return _run_worker_with_watchdog(conn, sql, timeout_s, worker)


def execute_once_timed(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
    max_rows: int | None = None,
) -> tuple[list[tuple[Any, ...]], float]:
    """Run ``sql`` once under the watchdog and return ``(rows, elapsed_ms)``.

    This is the same execution as :func:`execute_once_with_columns`, minus
    the column names and the ``truncated`` flag. When a ``max_rows`` cap is
    given, overflow is therefore visible only as
    ``len(rows) == max_rows + 1``.
    """
    timed = _run_with_watchdog(conn, sql, timeout_s, max_rows)
    rows, elapsed_ms = timed.rows, timed.elapsed_ms
    return rows, elapsed_ms


def execute_once_with_columns(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
    max_rows: int | None = None,
) -> TimedResult:
    """Run ``sql`` once under the watchdog; return columns, rows and timing.

    Supplying ``max_rows`` stops the fetch at the first row beyond the cap
    and sets :attr:`TimedResult.truncated`. In that case ``elapsed_ms``
    measures only the partial read, not how long the full query would have
    taken. A truncated read is a *hard error* in agent-facing code paths,
    never a performance sample.
    """
    result = _run_with_watchdog(conn, sql, timeout_s, max_rows=max_rows)
    return result


def execute_hash_timed(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
) -> tuple[str, float]:
    """Run ``sql`` once under the watchdog and return ``(result_hash, elapsed_ms)``.

    Unlike :func:`execute_once_timed`, the rows are never collected into a
    list. The cursor is drained in ``fetchmany`` chunks through a lazy
    iterator handed to ``canonical_row_hash``. Assuming that helper consumes
    the iterator incrementally, a large result can be compared against
    ground truth without materialising it in Python memory. ``elapsed_ms``
    covers execution plus hashing.
    """

    def hash_and_time() -> tuple[str, float]:
        t0 = time.perf_counter_ns()
        digest = canonical_row_hash(_iter_cursor_rows(conn.execute(sql)))
        return digest, (time.perf_counter_ns() - t0) / 1_000_000.0

    result_hash, elapsed_ms = _run_worker_with_watchdog(conn, sql, timeout_s, hash_and_time)
    assert isinstance(result_hash, str)
    assert isinstance(elapsed_ms, float)
    return result_hash, elapsed_ms


def median_of_3_warm_ms(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    *,
    timeout_s: float = DEFAULT_TIMEOUT_S,
) -> float:
    """Return the median of three timed runs of ``sql``, in milliseconds.

    A warm-up execution runs first and its timing is discarded. Every
    execution, including the warm-up, goes through the watchdog with
    ``timeout_s`` and fetches the full result set.
    """

    def one_run_ms() -> float:
        return _run_with_watchdog(conn, sql, timeout_s, None).elapsed_ms

    one_run_ms()  # warm-up; timing discarded
    _fastest, middle, _slowest = sorted(one_run_ms() for _ in range(3))
    return middle


# Public API of this module: the names exported by ``from <module> import *``.
# Underscore-prefixed helpers (watchdog internals, fetch helpers, the leak
# counter) are intentionally not listed.
__all__ = [
    "DEFAULT_TIMEOUT_S",
    "INTERRUPT_GRACE_S",
    "MAX_LEAKED_WATCHDOG_THREADS",
    "QueryWatchdogEscalationError",
    "TimedResult",
    "execute_hash_timed",
    "execute_once_timed",
    "execute_once_with_columns",
    "get_watchdog_leak_count",
    "median_of_3_warm_ms",
]