analyst-buddy / server /agent_loop.py
hjerpe's picture
F006/F008: serve Qwen models + model switcher (vanilla-first)
656f91e verified
Raw
History Blame Contribute Delete
9.54 kB
"""Gradio-free streaming runner for one agent turn (F004).
The streaming sibling of ``evaluation.policies.evaluate()``: it drives the same
Protocol-driven ``select_action`` -> ``step`` loop against a real
``SQLEnvironment``, but **yields one frozen ``Step`` per action** instead of
accumulating aggregate metrics. The Gradio app (``app.py``) maps each ``Step`` to
a chat accordion, the result table, and the shown-SQL panel.
This module is **dependency-light by contract**: it imports NO gradio and NO
heavy training/serving deps (``torch``/``transformers``/``trl``). It is the engine
the demo rides on and is unit-testable headless against a scripted policy.
Append-only (ADR 0006): the runner never rewrites a prior ``Step``.
"""
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass, field
import re
import sqlite3
import time
try:
from ..evaluation.policies import Policy
from .sql_environment import SQLEnvironment, resolve_db_path
except ImportError: # pragma: no cover - Docker/flat-layout fallback
try:
from evaluation.policies import Policy # type: ignore[no-redef]
from server.sql_environment import ( # type: ignore[no-redef]
SQLEnvironment,
resolve_db_path,
)
except ImportError:
from sql_env.evaluation.policies import Policy # type: ignore[no-redef]
from sql_env.server.sql_environment import ( # type: ignore[no-redef]
SQLEnvironment,
resolve_db_path,
)
# Captures the first word of a SQL string, skipping any leading whitespace.
# ``reexecute_select`` upper-cases the captured word and only proceeds when it
# is SELECT or WITH.
_FIRST_KEYWORD_PATTERN = re.compile(r"^[\s\n\r\t]*(\w+)")
@dataclass(frozen=True)
class Step:
    """A single, already-resolved event streamed out of one agent turn.

    This is the one shape the runner yields and the UI consumes. A ``Step`` is
    only ever produced after its ``env.step()`` call has returned, so ``status``
    is always final; any pending -> done animation is purely a UI matter and has
    no field here.

    Attributes:
        action_type: ``"DESCRIBE"``, ``"SAMPLE"``, ``"QUERY"`` or ``"ANSWER"``.
        argument: the SQL string (QUERY), the table name (DESCRIBE/SAMPLE) or
            the answer text (ANSWER).
        result_text: the env's formatted result text for this step
            (``obs.result``).
        status: ``"done"`` or ``"error"``.
        duration_s: wall-clock seconds spent inside this step's ``env.step()``.
        error: ``obs.error`` for this step; ``""`` when ``status == "done"``.
        rows: raw result tuples, populated only for the shown-table QUERY.
        columns: header names matching ``rows``; otherwise ``[]``.
    """

    action_type: str
    argument: str
    result_text: str
    status: str
    duration_s: float
    error: str = ""
    rows: list[tuple] = field(default_factory=list)
    columns: list[str] = field(default_factory=list)
def _active_db_id(env: SQLEnvironment) -> str:
    """Return the database name of the env's active episode.

    Used to locate the ``.sqlite`` file for the read-only re-execution path.
    Yields ``""`` when the env has no episode (``env._episode`` is ``None``).
    """
    active = env._episode
    return "" if active is None else active.question_record.database_name
def _resolve_db_path(env: SQLEnvironment, db_id: str) -> str | None:
    """Locate the active ``.sqlite`` file and return its path as a string.

    A thin adapter around the shared ``sql_environment.resolve_db_path``: it
    passes ``env.db_dir`` and ``db_id`` straight through so the runner can open
    its OWN read-only connection without re-implementing path resolution. The
    candidate lookup and the containment guard that keeps the result under
    ``db_dir`` live in that shared helper (the same one the env's ``_open_db``
    relies on), so they are deliberately not duplicated here.

    Returns:
        The resolved path as ``str``, or ``None`` when the helper finds no
        file (the shown table then simply stays empty).
    """
    located = resolve_db_path(env.db_dir, db_id)
    if located is None:
        return None
    return str(located)
def reexecute_select(
    db_path: str,
    sql: str,
    *,
    timeout_s: float = 5.0,
    max_rows: int = 20,
) -> tuple[list[tuple], list[str]]:
    """Re-execute a single SELECT read-only to get real rows + column headers.

    The env formats QUERY results into text and discards the raw tuples (and the
    observation has no column headers), so the shown ``gr.Dataframe`` needs a
    faithful re-read. The statement must start with SELECT or WITH, contain no
    inner ``;``, and is run parameterless on a private ``mode=ro`` connection
    that is opened and closed here (the env's connection is untouched).

    Args:
        db_path: filesystem path of the ``.sqlite`` file. It is percent-encoded
            before being placed in the ``file:`` URI, so characters that are
            special in a URI (``#``, ``?``, ``%``, spaces) cannot truncate the
            path or drop the ``mode=ro`` flag.
        sql: the statement to re-run; surrounding whitespace and trailing
            semicolons are ignored.
        timeout_s: read deadline in seconds, enforced through a SQLite progress
            handler that aborts the statement once the deadline has passed.
        max_rows: maximum number of rows returned. A non-positive value returns
            no rows (the column headers are still reported).

    Returns:
        ``(rows, columns)`` where ``columns`` comes from ``cursor.description``.
        On ANY failure (empty SQL, non-SELECT, multiple statements, bad SQL,
        write attempt, missing table or file, timeout) returns ``([], [])`` —
        this function never raises to the caller; the answer + SQL are still
        shown.
    """
    # Function-scope import so this block stays self-contained.
    from urllib.parse import quote

    sql_stripped = sql.strip()
    if not sql_stripped:
        return [], []

    first_keyword_match = _FIRST_KEYWORD_PATTERN.match(sql_stripped)
    first_keyword = first_keyword_match.group(1).upper() if first_keyword_match else ""
    if first_keyword not in ("SELECT", "WITH"):
        return [], []

    # Conservative single-statement gate: any ';' left after trimming the
    # trailing ones is rejected (this also rejects ';' inside string literals).
    single_statement = sql_stripped.rstrip(";").strip()
    if ";" in single_statement:
        return [], []

    connection: sqlite3.Connection | None = None
    try:
        # Percent-encode the path: an unescaped '#' or '?' would end the URI
        # path early (losing ``mode=ro``) and '%XX' would be decoded by SQLite.
        database_uri = f"file:{quote(str(db_path))}?mode=ro"
        connection = sqlite3.connect(database_uri, uri=True)
        deadline = time.monotonic() + timeout_s

        def _progress_callback() -> int:
            # A non-zero return makes SQLite abort the running statement.
            return 1 if time.monotonic() > deadline else 0

        # Poll the deadline every 1000 SQLite VM instructions.
        connection.set_progress_handler(_progress_callback, 1000)
        cursor = connection.cursor()
        cursor.execute(single_statement)
        # fetchmany() does not treat a non-positive size as "zero rows", so the
        # cap is enforced explicitly for that case.
        rows = cursor.fetchmany(max_rows) if max_rows > 0 else []
        columns = [str(description[0]) for description in (cursor.description or [])]
        return list(rows), columns
    except (sqlite3.Error, sqlite3.Warning, ValueError):
        # sqlite3.Warning is not a sqlite3.Error subclass; ValueError covers
        # inputs the driver rejects before SQLite sees them (e.g. a NUL
        # character on some Python versions) and un-encodable paths.
        return [], []
    finally:
        if connection is not None:
            connection.close()
def run_agent_turn(
    question: str,
    env: SQLEnvironment,
    policy: Policy,
    *,
    budget: int = 15,
) -> Iterator[Step]:
    """Drive ONE non-gold agent turn, yielding a ``Step`` after every action.

    The streaming counterpart of ``evaluate()``: it runs the same
    ``policy.select_action(obs)`` -> ``env.step(action)`` cycle but hands each
    resolved step to the caller instead of aggregating metrics. Append-only
    (ADR 0006): a step that has been yielded is never rewritten. The caller is
    expected to have seeded the episode with ``env.begin_episode(...)``
    beforehand. The loop stops as soon as the observation reports ``done`` or
    ``budget`` actions have been driven, whichever happens first.

    Every QUERY that the env executed without error is re-run read-only via
    ``reexecute_select`` so the step can carry real ``rows`` and ``columns``
    (the env only exposes formatted text). The last such QUERY is the one the
    UI shows in the table. All other steps carry empty ``rows``/``columns``.

    Args:
        question: the user's plain-English question. Unused here because the
            seeded episode already carries it.
        env: a ``SQLEnvironment`` with an active non-gold episode; only its
            observation builder and ``step()`` are used.
        policy: any object satisfying the ``Policy`` Protocol.
        budget: maximum number of actions to drive (default 15).

    Yields:
        ``Step`` instances, in the order the actions were taken.

    Raises:
        Nothing for model/SQL mistakes — those surface as ``status="error"``
        steps. Setup problems (e.g. a missing DB in ``begin_episode``) belong
        to the caller and occur before iteration starts.
    """
    # The episode seeded by ``begin_episode`` already holds the question.
    del question

    readonly_db = _resolve_db_path(env, _active_db_id(env))
    obs = env._build_observation()  # the turn's starting observation
    steps_driven = 0

    while True:
        if obs.done or not (steps_driven < budget):
            break

        chosen = policy.select_action(obs)
        kind = str(chosen.action_type).strip().upper()
        payload = str(chosen.argument)

        # Time only the env call itself.
        began = time.monotonic()
        obs = env.step(chosen)
        elapsed = time.monotonic() - began
        steps_driven += 1

        failed = bool(obs.error)
        table_rows: list[tuple] = []
        table_columns: list[str] = []
        if kind == "QUERY" and not failed and readonly_db is not None:
            # C4 caveat (non-deterministic SQL): the model's SELECT is run a
            # SECOND time, read-only, solely to recover the raw tuples and
            # column headers that the env drops when it formats the result to
            # text. For a non-deterministic SELECT (RANDOM(), datetime('now'),
            # unordered LIMIT) this re-read may differ from what the agent
            # actually saw. The root cause is the env discarding the tuples; an
            # env that exposed rows + description would remove the need for
            # this. Accepted for the demo (queries are typically deterministic
            # aggregates) and called out here so it is not a silent trap.
            table_rows, table_columns = reexecute_select(readonly_db, payload)

        yield Step(
            action_type=kind,
            argument=payload,
            result_text=obs.result,
            status="error" if failed else "done",
            duration_s=elapsed,
            error=obs.error,
            rows=table_rows,
            columns=table_columns,
        )