"""Gradio-free streaming runner for one agent turn (F004).
The streaming sibling of ``evaluation.policies.evaluate()``: it drives the same
Protocol-driven ``select_action`` -> ``step`` loop against a real
``SQLEnvironment``, but **yields one frozen ``Step`` per action** instead of
accumulating aggregate metrics. The Gradio app (``app.py``) maps each ``Step`` to
a chat accordion, the result table, and the shown-SQL panel.
This module is **dependency-light by contract**: it imports NO gradio and NO
heavy training/serving deps (``torch``/``transformers``/``trl``). It is the engine
the demo rides on and is unit-testable headless against a scripted policy.
Append-only (ADR 0006): the runner never rewrites a prior ``Step``.
"""
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass, field
import re
import sqlite3
import time
try:
from ..evaluation.policies import Policy
from .sql_environment import SQLEnvironment, resolve_db_path
except ImportError: # pragma: no cover - Docker/flat-layout fallback
try:
from evaluation.policies import Policy # type: ignore[no-redef]
from server.sql_environment import ( # type: ignore[no-redef]
SQLEnvironment,
resolve_db_path,
)
except ImportError:
from sql_env.evaluation.policies import Policy # type: ignore[no-redef]
from sql_env.server.sql_environment import ( # type: ignore[no-redef]
SQLEnvironment,
resolve_db_path,
)
# Captures the leading SQL keyword (e.g. SELECT / WITH) after optional
# whitespace. ``\s`` already covers ``\n``, ``\r`` and ``\t``, so no extra
# character class is needed. A leading comment (``-- ...``) does not match,
# which makes the SELECT/WITH gate in ``reexecute_select`` reject it.
_FIRST_KEYWORD_PATTERN = re.compile(r"^\s*(\w+)")
@dataclass(frozen=True)
class Step:
    """A single resolved event in a streamed agent turn.

    Immutable value record (same spirit as ``EpisodeResult``/``IngestResult``):
    the one shape ``run_agent_turn`` yields and the UI renders. The env always
    finishes a ``step()`` before returning, so every ``Step`` is emitted in a
    terminal state (``status`` is ``"done"`` or ``"error"``). Any pending/done
    animation is purely a UI concern and has no field here.

    Attributes:
        action_type: one of ``"DESCRIBE"``, ``"SAMPLE"``, ``"QUERY"``, ``"ANSWER"``.
        argument: the action payload: SQL for QUERY, a table name for
            DESCRIBE/SAMPLE, the answer text for ANSWER.
        result_text: the env's formatted result text for this step (``obs.result``).
        status: ``"done"`` or ``"error"``.
        duration_s: wall-clock seconds spent inside this step's ``env.step()``.
        error: ``obs.error`` for this step (normally ``""`` when ``status == "done"``).
        rows: re-read result tuples. ``run_agent_turn`` fills these for each
            successfully executed QUERY (the UI table shows the last one).
            Empty otherwise.
        columns: header names matching ``rows``. Empty whenever ``rows`` is.
    """

    action_type: str
    argument: str
    result_text: str
    status: str
    duration_s: float
    error: str = field(default="")
    rows: list[tuple] = field(default_factory=list)
    columns: list[str] = field(default_factory=list)
def _active_db_id(env: SQLEnvironment) -> str:
    """Return the active episode's database name, or ``""`` if no episode is live.

    Used only to resolve the ``.sqlite`` file for the read-only re-execution
    path. Reads the env's private ``_episode`` directly.
    """
    current_episode = env._episode
    return "" if current_episode is None else current_episode.question_record.database_name
def _resolve_db_path(env: SQLEnvironment, db_id: str) -> str | None:
    """Return the active ``.sqlite`` file path as a ``str``, or ``None``.

    A thin adapter over the shared ``sql_environment.resolve_db_path``. That
    helper applies the same candidate resolution and ``.resolve()`` /
    ``db_root in candidate.parents`` containment guard as the env's
    ``_open_db`` (R2: defense-in-depth).

    Reusing it lets the runner open its OWN read-only connection (the env's
    connection is busy) without re-implementing the env's execution. Because the
    guard is shared rather than copied, this independent path can never escape
    ``db_dir``. ``db_id`` has already been validated by ``env.begin_episode`` ->
    ``_open_db``.

    Returns:
        The resolved path as a string. ``None`` when no contained file resolves;
        the shown table then degrades gracefully.
    """
    located = resolve_db_path(env.db_dir, db_id)
    if located is None:
        return None
    return str(located)
def reexecute_select(
    db_path: str,
    sql: str,
    *,
    timeout_s: float = 5.0,
    max_rows: int = 20,
) -> tuple[list[tuple], list[str]]:
    """Re-execute a single SELECT read-only to get real rows + column headers.

    The env formats QUERY results into text and discards the raw tuples (and the
    observation has no column headers), so the shown ``gr.Dataframe`` needs a
    faithful re-read. The query must be SELECT/WITH-only, a single statement and
    parameterless. It runs with ``mode=ro``, is capped at ``max_rows`` and has a
    ``timeout_s`` read deadline.

    This opens and closes its OWN connection (the env is untouched).
    ``columns`` comes from ``cursor.description``.

    Args:
        db_path: filesystem path of the ``.sqlite`` file. It is percent-encoded
            before being embedded in the ``file:`` URI, so ``?``, ``#`` or ``%``
            in the path cannot be misread as URI syntax.
        sql: the model's SQL. It must start with ``SELECT`` or ``WITH`` and
            contain no ``;`` other than trailing ones. A ``;`` inside a string
            literal is also rejected, which is conservative by design.
        timeout_s: wall-clock budget for execute + fetch. A SQLite progress
            handler interrupts the statement once the deadline passes.
        max_rows: maximum number of rows returned. Values ``<= 0`` return no
            rows (headers are still reported).

    Returns:
        ``(rows, columns)``. On ANY error (bad SQL, write attempt, missing
        table, empty SQL, interrupted by the deadline) this returns
        ``([], [])`` and never raises to the caller; the answer + SQL are
        still shown.
    """
    from urllib.parse import quote

    sql_stripped = sql.strip()
    if not sql_stripped:
        return [], []

    first_keyword_match = _FIRST_KEYWORD_PATTERN.match(sql_stripped)
    first_keyword = first_keyword_match.group(1).upper() if first_keyword_match else ""
    # Keyword gate only: ``mode=ro`` below is the real write guard. For example,
    # ``WITH ... DELETE`` passes this check but cannot write on a read-only
    # connection; that error is caught and returns ([], []).
    if first_keyword not in ("SELECT", "WITH"):
        return [], []

    # Tolerate trailing semicolons, but refuse anything that might hold a
    # second statement.
    single_statement = sql_stripped.rstrip(";").strip()
    if ";" in single_statement:
        return [], []

    # In a SQLite URI, '?' starts the query string, '#' a fragment and '%' an
    # escape, so encode them. '/', '\\' and ':' stay literal so ordinary POSIX
    # and Windows paths are embedded exactly as before.
    encoded_path = quote(str(db_path), safe="/\\:")
    db_uri = f"file:{encoded_path}?mode=ro"

    connection: sqlite3.Connection | None = None
    try:
        connection = sqlite3.connect(db_uri, uri=True)
        deadline = time.monotonic() + timeout_s

        def _progress_callback() -> int:
            # A non-zero return aborts the running statement with
            # sqlite3.OperationalError, which the except clause below absorbs.
            return 1 if time.monotonic() > deadline else 0

        connection.set_progress_handler(_progress_callback, 1000)
        cursor = connection.cursor()
        cursor.execute(single_statement)
        columns = [str(description[0]) for description in (cursor.description or [])]
        # CPython's fetchmany only stops early once it has collected exactly
        # ``size`` rows, so a non-positive size would pull the whole result set.
        # Short-circuit to "no rows" to honour the cap.
        rows = cursor.fetchmany(max_rows) if max_rows > 0 else []
        return list(rows), columns
    except (sqlite3.Error, ValueError):
        # sqlite3.Error: bad SQL, missing table/file, write on mode=ro, deadline
        # interrupt. ValueError: e.g. SQL text that cannot be encoded for SQLite
        # (UnicodeEncodeError). Best-effort by contract: degrade, never raise.
        return [], []
    finally:
        if connection is not None:
            connection.close()
def run_agent_turn(
    question: str,
    env: SQLEnvironment,
    policy: Policy,
    *,
    budget: int = 15,
) -> Iterator[Step]:
    """Stream one non-gold agent turn, yielding a ``Step`` per executed action.

    This is the streaming counterpart of ``evaluate()``. It runs the same
    ``policy.select_action(obs)`` -> ``env.step(action)`` cycle, but hands each
    resolved step to the caller instead of folding it into metrics. Output is
    append-only (ADR 0006): once yielded, a ``Step`` is never revised.

    The caller must start the episode with ``env.begin_episode(...)`` before
    iterating. The loop stops when the observation reports ``done`` (ANSWER
    submitted or the env's budget spent) or after ``budget`` actions, whichever
    comes first.

    Every QUERY that executes without error is re-run read-only through
    ``reexecute_select`` (when the DB file resolves). Its ``Step`` then carries
    real ``rows`` + ``columns``, which the env does not keep. The UI table shows
    the last such QUERY. All other steps carry empty ``rows``/``columns``.

    Args:
        question: the user's plain-English question (no gold answer exists).
            Unused here: the episode already holds it.
        env: a real ``SQLEnvironment`` with an active non-gold episode. Only
            ``step()`` is driven; never ``reset()`` or gold scoring.
        policy: anything satisfying the ``Policy`` Protocol (the default app
            policy is a deterministic stub; the real ``ModelPolicy`` swaps in
            at F006).
        budget: maximum number of actions to drive (default 15, matching
            ``env.step_budget``).

    Yields:
        ``Step`` objects in the order the actions were taken.

    Raises:
        Nothing for model/SQL mistakes; those surface as ``status="error"``
        steps. Setup problems (e.g. a missing DB in ``begin_episode``) belong
        to the caller and surface before iteration starts.
    """
    del question  # begin_episode already seeded the question into the episode

    # Resolved once up front; None disables the shown-table re-read below.
    db_path = _resolve_db_path(env, _active_db_id(env))
    obs = env._build_observation()  # the episode's current (initial) observation
    taken = 0

    while not obs.done and taken < budget:
        action = policy.select_action(obs)
        kind = str(action.action_type).strip().upper()
        payload = str(action.argument)

        started_at = time.monotonic()
        obs = env.step(action)
        elapsed = time.monotonic() - started_at
        taken += 1

        failed = bool(obs.error)
        table_rows: list[tuple] = []
        table_columns: list[str] = []
        if kind == "QUERY" and not failed and db_path is not None:
            # C4 caveat (non-deterministic SQL): the model's SELECT runs a SECOND
            # time, read-only, only to recover rows + headers, because the env
            # renders QUERY results as text and drops the raw tuples and
            # cursor.description. A non-deterministic SELECT (RANDOM(),
            # datetime('now'), unordered LIMIT) can therefore show rows that
            # differ from what the agent answered on. The root fix is an env
            # that exposes raw rows + description. Acceptable for the demo
            # (mostly deterministic aggregates); noted so it is not a silent trap.
            table_rows, table_columns = reexecute_select(db_path, payload)

        yield Step(
            action_type=kind,
            argument=payload,
            result_text=obs.result,
            status="error" if failed else "done",
            duration_s=elapsed,
            error=obs.error,
            rows=table_rows,
            columns=table_columns,
        )
|