Spaces:
Running on Zero
Running on Zero
| """Gradio-free streaming runner for one agent turn (F004). | |
| The streaming sibling of ``evaluation.policies.evaluate()``: it drives the same | |
| Protocol-driven ``select_action`` -> ``step`` loop against a real | |
| ``SQLEnvironment``, but **yields one frozen ``Step`` per action** instead of | |
| accumulating aggregate metrics. The Gradio app (``app.py``) maps each ``Step`` to | |
| a chat accordion, the result table, and the shown-SQL panel. | |
| This module is **dependency-light by contract**: it imports NO gradio and NO | |
| heavy training/serving deps (``torch``/``transformers``/``trl``). It is the engine | |
| the demo rides on and is unit-testable headless against a scripted policy. | |
| Append-only (ADR 0006): the runner never rewrites a prior ``Step``. | |
| """ | |
| from __future__ import annotations | |
| from collections.abc import Iterator | |
| from dataclasses import dataclass, field | |
| import re | |
| import sqlite3 | |
| import time | |
| try: | |
| from ..evaluation.policies import Policy | |
| from .sql_environment import SQLEnvironment, resolve_db_path | |
| except ImportError: # pragma: no cover - Docker/flat-layout fallback | |
| try: | |
| from evaluation.policies import Policy # type: ignore[no-redef] | |
| from server.sql_environment import ( # type: ignore[no-redef] | |
| SQLEnvironment, | |
| resolve_db_path, | |
| ) | |
| except ImportError: | |
| from sql_env.evaluation.policies import Policy # type: ignore[no-redef] | |
| from sql_env.server.sql_environment import ( # type: ignore[no-redef] | |
| SQLEnvironment, | |
| resolve_db_path, | |
| ) | |
# Captures the first word token of a SQL string, skipping any leading
# whitespace. ``reexecute_select`` upper-cases the captured token and only
# proceeds when it is SELECT or WITH.
_FIRST_KEYWORD_PATTERN = re.compile(r"^[\s\n\r\t]*(\w+)")
@dataclass(frozen=True)
class Step:
    """One streamed event from the agent turn.

    A frozen value dataclass mirroring ``EpisodeResult``/``IngestResult``: the
    single shape the runner yields and the UI maps. The env never half-completes
    a ``step()``, so each ``Step`` is yielded already resolved (``status`` is
    ``"done"`` or ``"error"``); the pending/done *visual* transition is a UI
    concern, not a ``Step`` field.

    ``frozen=True`` blocks attribute rebinding after construction (assignment
    raises ``dataclasses.FrozenInstanceError``), which backs the append-only
    contract of the runner. Note that the ``rows``/``columns`` lists themselves
    remain mutable objects; each instance gets its own fresh list by default.
    """

    # Required fields, in positional order.
    action_type: str  # "DESCRIBE" | "SAMPLE" | "QUERY" | "ANSWER"
    argument: str  # SQL string (QUERY), table name (DESCRIBE/SAMPLE), answer (ANSWER)
    result_text: str  # the env's formatted result text for this step (obs.result)
    status: str  # "done" | "error"
    duration_s: float  # wall-clock seconds for this step's env.step() call

    # Optional fields.
    error: str = ""  # obs.error for this step ("" when status == "done")
    rows: list[tuple] = field(default_factory=list)  # only on the shown-table QUERY
    columns: list[str] = field(default_factory=list)  # headers for `rows`; else []
| def _active_db_id(env: SQLEnvironment) -> str: | |
| """The db_id of the active episode (for read-only re-exec path resolution).""" | |
| episode = env._episode | |
| if episode is None: | |
| return "" | |
| return episode.question_record.database_name | |
def _resolve_db_path(env: SQLEnvironment, db_id: str) -> str | None:
    """Return the active database file path as a string, or ``None``.

    A thin adapter over the shared ``resolve_db_path`` helper imported from
    ``sql_environment``: it passes ``env.db_dir`` and ``db_id`` through and
    converts a non-``None`` result with ``str()``. Delegating (rather than
    re-implementing path resolution here) is intended to keep the runner's
    independent read-only connection on the same resolution and containment
    rules the environment uses, so this path cannot drift from the env's.

    Returns ``None`` when the helper resolves nothing; the caller then skips
    the shown-table re-execution and the step degrades to empty rows/columns.
    """
    contained = resolve_db_path(env.db_dir, db_id)
    if contained is None:
        return None
    return str(contained)
def reexecute_select(
    db_path: str,
    sql: str,
    *,
    timeout_s: float = 5.0,
    max_rows: int = 20,
) -> tuple[list[tuple], list[str]]:
    """Re-execute a single SELECT read-only to get real rows + column headers.

    The env formats QUERY results into text and discards the raw tuples (and the
    observation has no column headers), so the shown ``gr.Dataframe`` needs a
    faithful re-read. This function opens and closes its OWN connection (the
    env is untouched).

    Guards applied before anything touches the database:

    * empty / whitespace-only SQL is rejected;
    * the first word token must be ``SELECT`` or ``WITH`` (case-insensitive);
    * after dropping trailing semicolons, any remaining ``;`` rejects the SQL.
      This is deliberately conservative: a ``;`` inside a string literal is
      rejected too, and the table simply degrades to empty.

    Args:
        db_path: filesystem path of the SQLite file. It is percent-encoded
            before being placed in the ``file:`` URI so that ``#``, ``?`` or
            ``%`` in the path cannot truncate it or strip the ``mode=ro``
            query parameter.
        sql: the statement to run; executed without parameters.
        timeout_s: read deadline in seconds, enforced through a progress
            handler that aborts the statement once the deadline has passed.
        max_rows: maximum number of rows fetched (``cursor.fetchmany``).

    Returns:
        ``(rows, columns)`` where ``columns`` comes from ``cursor.description``.
        On ANY rejected or failed statement (bad SQL, write attempt, missing
        table, empty SQL, unencodable SQL text, unopenable path) returns
        ``([], [])`` — never raises to the caller for these; the answer + SQL
        are still shown.
    """
    # Function-scope import keeps this fix self-contained without editing the
    # module's import block.
    from urllib.parse import quote

    sql_stripped = sql.strip()
    if not sql_stripped:
        return [], []

    first_keyword_match = _FIRST_KEYWORD_PATTERN.match(sql_stripped)
    first_keyword = first_keyword_match.group(1).upper() if first_keyword_match else ""
    if first_keyword not in ("SELECT", "WITH"):
        return [], []

    single_statement = sql_stripped.rstrip(";").strip()
    if ";" in single_statement:
        return [], []

    connection: sqlite3.Connection | None = None
    try:
        # Percent-encode the path: an unescaped '#' would start the URI
        # fragment (silently discarding '?mode=ro'), '?' would start the query
        # string early, and '%HH' would be decoded into a different filename.
        connection = sqlite3.connect(f"file:{quote(db_path)}?mode=ro", uri=True)

        deadline = time.monotonic() + timeout_s

        def _progress_callback() -> int:
            # A non-zero return tells SQLite to abort the running statement.
            return 1 if time.monotonic() > deadline else 0

        # Invoke the deadline check every 1000 SQLite VM instructions.
        connection.set_progress_handler(_progress_callback, 1000)

        cursor = connection.cursor()
        cursor.execute(single_statement)
        rows = cursor.fetchmany(max_rows)
        columns = [str(description[0]) for description in (cursor.description or [])]
        return list(rows), columns
    except (sqlite3.Error, sqlite3.Warning, ValueError):
        # sqlite3.Warning is not an sqlite3.Error subclass, and text that
        # cannot be encoded/passed to SQLite surfaces as ValueError (including
        # UnicodeEncodeError); both must degrade to the empty table as well.
        return [], []
    finally:
        if connection is not None:
            connection.close()
def run_agent_turn(
    question: str,
    env: SQLEnvironment,
    policy: Policy,
    *,
    budget: int = 15,
) -> Iterator[Step]:
    """Stream ONE non-gold agent turn as a sequence of resolved ``Step`` values.

    This is the streaming counterpart of ``evaluate()``: it runs the same
    ``policy.select_action(obs)`` -> ``env.step(action)`` cycle, but yields a
    ``Step`` after every action rather than aggregating metrics. Steps are only
    ever appended (ADR 0006); an already-yielded step is never revisited.

    The caller must have seeded the episode (``env.begin_episode(...)``) before
    iterating. Iteration ends as soon as the observation reports ``done`` or
    once ``budget`` actions have been driven, whichever happens first.

    Every QUERY that the env executed without error is run a second time,
    read-only, through ``reexecute_select`` so the step can carry real ``rows``
    and ``columns`` (the env's observation only has formatted text). The UI
    shows the table of the last such QUERY. All other steps — non-QUERY actions
    and failed QUERYs — carry empty ``rows``/``columns``.

    Args:
        question: the user's plain-English question. Unused here: the episode
            already holds it, so the name is dropped immediately.
        env: a ``SQLEnvironment`` with an active non-gold episode; only its
            observation builder and ``step()`` are used.
        policy: any object satisfying the ``Policy`` Protocol.
        budget: upper bound on the number of actions driven (default 15).

    Yields:
        ``Step`` instances, one per action, in action order.

    Raises:
        Model/SQL mistakes do not raise; they are reported as steps with
        ``status="error"``. Episode setup problems belong to the caller and
        surface before iteration starts.
    """
    del question  # the episode already carries the question (begin_episode seeded it)

    sqlite_path = _resolve_db_path(env, _active_db_id(env))
    obs = env._build_observation()  # the episode's current (initial) observation
    driven = 0

    while not (obs.done or driven >= budget):
        chosen = policy.select_action(obs)
        kind = str(chosen.action_type).strip().upper()
        payload = str(chosen.argument)

        # Time only the env.step() call itself.
        began = time.monotonic()
        obs = env.step(chosen)
        elapsed = time.monotonic() - began
        driven += 1

        outcome = "done" if not obs.error else "error"

        table_rows: list[tuple] = []
        table_columns: list[str] = []
        wants_table = kind == "QUERY" and outcome == "done" and sqlite_path is not None
        if wants_table:
            # Caveat C4 (non-deterministic SQL): the SELECT is executed a
            # second time, read-only, solely to recover the raw rows and the
            # column headers that the env's text formatting throws away. If the
            # statement is not deterministic (RANDOM(), datetime('now'), LIMIT
            # without ORDER BY) this second read may not match what the agent
            # actually saw. The underlying cause is the env not exposing the
            # tuples/description; an env that did would remove the need for
            # this re-read. Tolerated for the demo, where queries are usually
            # deterministic aggregates — noted so it is not a hidden trap.
            table_rows, table_columns = reexecute_select(sqlite_path, payload)

        yield Step(
            action_type=kind,
            argument=payload,
            result_text=obs.result,
            status=outcome,
            duration_s=elapsed,
            error=obs.error,
            rows=table_rows,
            columns=table_columns,
        )