File size: 9,539 Bytes
656f91e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""Gradio-free streaming runner for one agent turn (F004).

The streaming sibling of ``evaluation.policies.evaluate()``: it drives the same
Protocol-driven ``select_action`` -> ``step`` loop against a real
``SQLEnvironment``, but **yields one frozen ``Step`` per action** instead of
accumulating aggregate metrics. The Gradio app (``app.py``) maps each ``Step`` to
a chat accordion, the result table, and the shown-SQL panel.

This module is **dependency-light by contract**: it imports NO gradio and NO
heavy training/serving deps (``torch``/``transformers``/``trl``). It is the engine
the demo rides on and is unit-testable headless against a scripted policy.

Append-only (ADR 0006): the runner never rewrites a prior ``Step``.
"""

from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass, field
import re
import sqlite3
import time
from urllib.parse import quote

try:
    from ..evaluation.policies import Policy
    from .sql_environment import SQLEnvironment, resolve_db_path
except ImportError:  # pragma: no cover - Docker/flat-layout fallback
    try:
        from evaluation.policies import Policy  # type: ignore[no-redef]
        from server.sql_environment import (  # type: ignore[no-redef]
            SQLEnvironment,
            resolve_db_path,
        )
    except ImportError:
        from sql_env.evaluation.policies import Policy  # type: ignore[no-redef]
        from sql_env.server.sql_environment import (  # type: ignore[no-redef]
            SQLEnvironment,
            resolve_db_path,
        )

# Captures the first run of word characters after optional leading whitespace
# (the explicit ``\n\r\t`` are already covered by ``\s``). ``reexecute_select``
# upper-cases group(1) and admits only SELECT/WITH; SQL that starts with a
# non-word character (e.g. a ``--`` comment or ``(``) does not match and is
# therefore rejected.
_FIRST_KEYWORD_PATTERN = re.compile(r"^[\s\n\r\t]*(\w+)")


@dataclass(frozen=True)
class Step:
    """A single, already-resolved event emitted by the agent-turn runner.

    Immutable value object in the same spirit as ``EpisodeResult`` /
    ``IngestResult``: it is the only shape the runner yields and the only shape
    the UI consumes. The env's ``step()`` never half-completes, so every
    ``Step`` arrives finished, with ``status`` either ``"done"`` or ``"error"``.
    Any pending-to-done animation belongs to the UI and has no field here.
    """

    # Action kind: "DESCRIBE" | "SAMPLE" | "QUERY" | "ANSWER".
    action_type: str
    # SQL text (QUERY), a table name (DESCRIBE/SAMPLE), or the answer (ANSWER).
    argument: str
    # The env's formatted result text for this step (``obs.result``).
    result_text: str
    # "done" | "error".
    status: str
    # Wall-clock seconds spent inside this step's ``env.step()`` call.
    duration_s: float
    # ``obs.error`` for this step; "" when ``status == "done"``.
    error: str = ""
    # Re-read result rows, attached to successfully executed QUERY steps (the UI
    # shows the last such QUERY in its table); otherwise [].
    rows: list[tuple] = field(default_factory=list)
    # Column headers matching ``rows``; otherwise [].
    columns: list[str] = field(default_factory=list)


def _active_db_id(env: SQLEnvironment) -> str:
    """The db_id of the active episode (for read-only re-exec path resolution)."""
    episode = env._episode
    if episode is None:
        return ""
    return episode.question_record.database_name


def _resolve_db_path(env: SQLEnvironment, db_id: str) -> str | None:
    """Return the active ``.sqlite`` path as a ``str``, or ``None`` if unresolved.

    Delegates entirely to the shared ``sql_environment.resolve_db_path``.

    - **Why delegate:** the runner opens its own read-only connection (the env's
      connection is busy). Delegating means that connection goes through the
      exact candidate lookup and ``.resolve()``/``db_root in candidate.parents``
      containment guard the env's ``_open_db`` uses (R2: defense-in-depth). The
      guard is shared rather than copied, so this independent path can never
      land outside ``db_dir``.
    - **Input:** ``env.begin_episode`` → ``_open_db`` has already validated
      ``db_id`` as a ``[A-Za-z0-9_]+`` identifier rooted under ``db_dir``.
    - **Failure:** a ``None`` result means no contained file was found; the
      shown table then degrades gracefully.
    """
    located = resolve_db_path(env.db_dir, db_id)
    if located is None:
        return None
    return str(located)


def reexecute_select(
    db_path: str,
    sql: str,
    *,
    timeout_s: float = 5.0,
    max_rows: int = 20,
) -> tuple[list[tuple], list[str]]:
    """Re-execute a single SELECT read-only to get real rows + column headers.

    The env formats QUERY results into text and discards the raw tuples (and the
    observation has no column headers), so the shown ``gr.Dataframe`` needs a
    faithful re-read. Guards applied, in order:

    * empty / whitespace-only SQL is rejected;
    * the first keyword must be ``SELECT`` or ``WITH`` (case-insensitive), so
      SQL starting with a comment or ``(`` is rejected too;
    * single statement only: trailing ``;`` are tolerated, any other ``;``
      (even one inside a string literal) is rejected — deliberately conservative;
    * the connection is opened with ``mode=ro``, so a write smuggled behind
      ``WITH`` fails inside SQLite instead of modifying the file;
    * a progress handler aborts execution once ``timeout_s`` has elapsed.

    Opens and closes its OWN connection (the env is untouched). ``columns`` comes
    from ``cursor.description``.

    Args:
        db_path: filesystem path of the ``.sqlite`` file. It is percent-encoded
            before being embedded in the ``file:`` URI, so ``#``, ``?`` or ``%``
            in the path can neither truncate it nor inject URI parameters.
        sql: the SQL text the agent executed.
        timeout_s: read deadline in seconds.
        max_rows: maximum number of rows returned; ``<= 0`` returns no rows.

    Returns:
        ``(rows, columns)``. On ANY error (bad SQL, write attempt, missing
        table, empty SQL, timeout, SQL/path text SQLite cannot accept) returns
        ``([], [])`` — never raises to the caller; the answer + SQL are still
        shown.
    """
    sql_stripped = sql.strip()
    if not sql_stripped:
        return [], []

    first_keyword_match = _FIRST_KEYWORD_PATTERN.match(sql_stripped)
    first_keyword = first_keyword_match.group(1).upper() if first_keyword_match else ""
    if first_keyword not in ("SELECT", "WITH"):
        return [], []

    single_statement = sql_stripped.rstrip(";").strip()
    if ";" in single_statement:
        return [], []

    # SQLite parses this as a URI: an unescaped "#" starts a fragment and "?"
    # starts the query string, so a raw path containing either would be cut
    # short — and "?mode=ro" would be lost, opening (or creating) the wrong file
    # read-write. quote() keeps "/" and escapes the rest; SQLite decodes the
    # %HH sequences back before opening the file.
    db_uri = f"file:{quote(db_path)}?mode=ro"

    connection: sqlite3.Connection | None = None
    try:
        connection = sqlite3.connect(db_uri, uri=True)
        deadline = time.monotonic() + timeout_s

        def _progress_callback() -> int:
            # Non-zero aborts the running statement; sqlite3 surfaces that as
            # OperationalError("interrupted"), handled below.
            return 1 if time.monotonic() > deadline else 0

        connection.set_progress_handler(_progress_callback, 1000)
        cursor = connection.cursor()
        cursor.execute(single_statement)
        # Explicit guard: fetchmany() is not relied on to treat 0 as "no rows".
        rows = cursor.fetchmany(max_rows) if max_rows > 0 else []
        columns = [str(description[0]) for description in (cursor.description or [])]
        return list(rows), columns
    except (sqlite3.Error, sqlite3.Warning, ValueError):
        # sqlite3.Warning: older Pythons raise it (not sqlite3.Error) for some
        # malformed multi-statement input. ValueError (incl. UnicodeEncodeError):
        # NUL bytes or lone surrogates in sql/db_path that SQLite cannot accept.
        return [], []
    finally:
        if connection is not None:
            connection.close()


def run_agent_turn(
    question: str,
    env: SQLEnvironment,
    policy: Policy,
    *,
    budget: int = 15,
) -> Iterator[Step]:
    """Drive ONE non-gold agent turn and yield a ``Step`` per action.

    The streaming sibling of ``evaluate()``: mirrors ``action =
    policy.select_action(obs); obs = env.step(action)`` but yields each resolved
    step instead of accumulating metrics. Append-only (ADR 0006): never rewrites a
    prior step. The caller seeds the episode via ``env.begin_episode(...)`` before
    iterating; the loop here terminates on ``obs.done`` (ANSWER submitted or
    budget exhausted) or after ``budget`` actions, whichever comes first.

    For each successfully executed QUERY the runner re-executes the SELECT
    read-only (``reexecute_select``) to attach real ``rows`` + ``columns`` to that
    step (the env discards QUERY tuples and exposes no headers); the last such
    QUERY is the one the UI shows in the table. Non-QUERY and failed-QUERY steps
    carry empty ``rows``/``columns``.

    Args:
        question: the user's plain-English question (no gold answer exists).
        env: a real ``SQLEnvironment`` with an active non-gold episode. Used only
            for ``step()`` execution — never ``reset()``/gold scoring.
        policy: any object satisfying the ``Policy`` Protocol (default app policy
            is a deterministic stub; the real ``ModelPolicy`` swaps in at F006).
        budget: max actions to drive (default 15, matching ``env.step_budget``).

    Yields:
        ``Step`` instances in action order. ``Step.error`` is always a ``str``
        ("" on success), and ``Step.duration_s`` times only the ``env.step()``
        call (not the policy or the read-only re-exec).

    Raises:
        Does not raise for model/SQL mistakes (those become ``status="error"``
        steps). Setup errors (e.g. a missing DB from ``begin_episode``) are the
        caller's responsibility and surface before iteration begins.
    """
    del question  # the episode already carries the question (begin_episode seeded it)

    # Resolved once per turn (on the first next(), since this is a generator);
    # None disables the shown-table re-read for every step of the turn.
    db_path = _resolve_db_path(env, _active_db_id(env))

    observation = env._build_observation()  # current (initial) observation
    actions_taken = 0

    while not observation.done and actions_taken < budget:
        action = policy.select_action(observation)
        # Normalised for the Step/UI only; the env receives `action` untouched.
        action_type = str(action.action_type).strip().upper()
        argument = str(action.argument)

        # perf_counter is the stdlib clock intended for short durations;
        # monotonic can be much coarser on some platforms (e.g. Windows before
        # Python 3.13), which would report fast steps as 0.0 s.
        start = time.perf_counter()
        observation = env.step(action)
        duration_s = time.perf_counter() - start
        actions_taken += 1

        # Coerce a missing error (e.g. None) to "" so Step.error keeps its
        # documented "" when the step succeeded.
        error_text = observation.error or ""
        status = "error" if error_text else "done"

        rows: list[tuple] = []
        columns: list[str] = []
        if action_type == "QUERY" and status == "done" and db_path is not None:
            # C4 caveat (non-deterministic SQL): we re-execute the model's SELECT
            # a SECOND time (read-only) purely to recover rows + column headers,
            # because the env formats QUERY results to text and discards the raw
            # tuples (+ cursor.description). For a non-deterministic SELECT
            # (RANDOM(), datetime('now'), unordered LIMIT) the re-read can differ
            # from what the agent answered on. Root cause is the env discarding
            # the tuples; a future env could expose the raw rows + description so
            # no re-exec is needed. Acceptable for the demo (queries are typically
            # deterministic aggregates); flagged here so it is not a silent trap.
            rows, columns = reexecute_select(db_path, argument)

        yield Step(
            action_type=action_type,
            argument=argument,
            result_text=observation.result,
            status=status,
            duration_s=duration_s,
            error=error_text,
            rows=rows,
            columns=columns,
        )