# Source: sql-drift-env/server/sql_drift_env_environment.py
# Uploaded by visheshrathi using huggingface_hub (commit 5850885, verified).
"""OpenEnv ``Environment`` implementation for SQLDrift.
Responsibilities:
* Own the private :class:`engine.runtime.RuntimeEpisodeState` and the
composite :class:`engine.reward.SqlDriftRubric` for the current episode.
* Dispatch each of the eight tool-call payloads to a dedicated
``_handle_<tool>`` method that returns a typed
:class:`models.ToolResult` (or :class:`models.ToolError`).
* Fire drift on a schedule blended with a cooldown: ``max(scheduled,
first_run_query_step + cooldown)`` before the agent acts on the step
where drift applies, then recompute the post-drift ground truth hash.
* Publish public observations (:class:`models.SqlDriftObservation`) and a
strictly sanitised public state snapshot (:class:`models.SqlDriftState`).
Privacy: ``self._runtime`` holds the DuckDB handle, ground-truth hashes,
baseline runtime, and seed. They stay inside this class; the rubric reads
them via a closure, and ``env.state`` exposes only a fixed whitelist of fields.
"""
from __future__ import annotations
import contextlib
import math
import re
import secrets
from random import Random
from typing import TYPE_CHECKING, Any, Literal
import duckdb
import sqlglot
from openenv.core.env_server.interfaces import Environment
from pydantic import BaseModel, ConfigDict, Field
from actors import dba_oracle
from actors.engineering_manager import author_changelog
from engine.drift import apply_drift
from engine.profiler import (
QueryWatchdogEscalationError,
execute_hash_timed,
execute_once_timed,
execute_once_with_columns,
)
from engine.reward import (
SPEEDUP_CAP_FOR_INFTY,
STEP_REBATE_DESCRIBE_TABLE,
STEP_REBATE_EXPLAIN_QUERY,
STEP_REBATE_LIST_TABLES,
STEP_REBATE_READ_CHANGELOG,
STEP_REBATE_RUN_QUERY,
STEP_REBATE_SAMPLE_ROWS,
SqlDriftRubric,
canonicalize_sql,
effective_speedup,
)
from engine.runtime import RuntimeEpisodeState
from engine.verifier import canonical_row_hash
from models import (
REWARD_COMPONENT_KEYS,
ConsultDBAPayload,
ConsultDBAResult,
DescribeTablePayload,
DescribeTableResult,
EpisodePhase,
ExplainQueryPayload,
ExplainQueryResult,
ListTablesPayload,
ListTablesResult,
ReadChangelogPayload,
ReadChangelogResult,
RunQueryPayload,
RunQueryResult,
SampleRowsPayload,
SampleRowsResult,
SqlDriftAction,
SqlDriftObservation,
SqlDriftState,
SubmitRewritePayload,
SubmitRewriteResult,
ToolError,
ToolErrorCode,
ToolResult,
)
from scenarios import REGISTRY, get_spec
from skill_library import PlaybookEntry, Store, load_all, retrieve
from utilities.logger import get_module_logger, log_env_reset, log_env_step, log_interaction
from . import settings
if TYPE_CHECKING:
from scenarios.base import ScenarioSpec
# Module-scoped logger used for drift/watchdog/skill-library diagnostics below.
_LOG = get_module_logger(__name__)
# Defaults re-exported from ``settings`` so callers can import them from here
# (all three are listed in ``__all__``).
DEFAULT_STEP_BUDGET: int = settings.DEFAULT_STEP_BUDGET  # default for ``budget_steps`` at reset()
MAX_RESULT_ROWS: int = settings.MAX_RESULT_ROWS  # row cap passed to ``run_query`` execution
QUERY_TIMEOUT_S: float = settings.QUERY_TIMEOUT_S  # fallback per-step query timeout
class _ResetOptions(BaseModel):
    # Typed view of the keyword arguments that ``SqlDriftEnvironment.reset()``
    # validates out of ``**kwargs``.
    # ``extra="ignore"``: keys not declared here (``reset()`` separately reads
    # ``learned_hints`` straight from ``kwargs``) do not fail validation.
    model_config = ConfigDict(extra="ignore")
    # Scenario to materialise; ``None`` lets reset() pick one from the seed.
    scenario_id: str | None = None
    # Passed to ``dba_oracle.is_enabled`` to decide whether consult_dba is usable.
    enable_dba_oracle: bool | None = None
    # Forwarded to ``spec.materialize(..., difficulty=...)``.
    difficulty: Literal["easy", "normal", "hard"] = "normal"
    # Step budget for the episode; must be at least 1.
    budget_steps: int = Field(default=DEFAULT_STEP_BUDGET, ge=1)
# sqlglot expression keys accepted as the top-level statement by
# ``_validate_read_only_sql``; any other key is rejected as not read-only.
_READ_ONLY_EXPRESSION_KEYS: frozenset[str] = frozenset({"select", "with"})
# DuckDB exposes a family of table-valued functions and scalar helpers
# that read from the host filesystem or leak introspection state —
# ``read_csv``, ``read_parquet``, ``read_json``, ``read_text``,
# ``parquet_metadata``, ``duckdb_secrets``, ``glob``, etc. They are
# *technically* SELECT-shaped calls so the statement-key check alone
# admits them. We reject any function whose lowercased name starts with
# one of these prefixes or exactly matches one of the known-dangerous
# standalone names. Agent-facing SQL has no legitimate need for any of
# them — the DuckDB connection is pre-populated by the scenario builder.
#
# Prefix matches: compared with ``str.startswith`` against the lowercased
# function name in ``_is_denylisted_function_name``.
_DENYLIST_PREFIXES: tuple[str, ...] = (
    "read_",
    "write_",
    "copy_",
    "duckdb_",
    "pragma_",
    "sniff_",
    "parquet_",
    "arrow_",
    "json_table",
    "json_each",
    "sqlite_",
    "load_",
    "install_",
)
# Exact matches: the lowercased function name must equal one of these.
_DENYLIST_EXACT: frozenset[str] = frozenset(
    {
        "glob",
        "attach",
        "detach",
        "checkpoint",
        "force_checkpoint",
        "set_secret",
        "create_secret",
        "drop_secret",
        "enable_profiling",
        "disable_profiling",
        "enable_object_cache",
    }
)
def _is_denylisted_function_name(name: str) -> bool:
"""Return True iff ``name`` (case-insensitively) matches a sandbox-escape."""
lowered = name.lower()
if lowered in _DENYLIST_EXACT:
return True
return any(lowered.startswith(p) for p in _DENYLIST_PREFIXES)
def _function_names(node: sqlglot.exp.Func) -> list[str]:
"""All plausible names to check against the denylist for one AST node.
sqlglot lowers a few DuckDB calls into dedicated expression classes
(``ReadCSV``, ``ReadParquet``, …) whose ``.name`` is actually the
first positional arg — the file path — not the function name. We
recover the function name from the class name in that case and fall
back to ``.name`` for the ``Anonymous`` form that covers everything
else. Including both lets one denylist lookup cover both lowerings.
"""
cls = type(node).__name__
out: list[str] = []
# Derive a snake-case function name from the class name. We insert
# an underscore at two kinds of CamelCase boundaries:
#
# * ``aB`` — normal lower-to-upper (``ReadParquet`` → ``read_parquet``)
# * ``ABc`` — end of an acronym run (``ReadCSVAuto`` → ``read_csv_auto``)
#
# Purely-lowercase class names (``Anonymous``) produce no prefix
# match; we fall through to ``.name`` below for those.
if cls and cls[0].isupper():
snake = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", cls).lower()
out.append(snake)
name_attr = getattr(node, "name", None)
if isinstance(name_attr, str) and name_attr:
out.append(name_attr)
return out
# Shape of a plain unquoted SQL identifier (letter/underscore, then letters,
# digits or underscores). ``_validate_read_only_sql`` rejects table
# identifiers that do not match it.
_VALID_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def _resolve_timeout_s(timeout_s: float | None) -> float:
"""Caller-supplied per-step timeout or the module default.
``timeout_s`` is accepted on every OpenEnv ``step()`` (the abstract
base mandates the keyword). When the caller provides a positive
value we honour it as the wall-clock budget for any DuckDB query
this step runs; ``None`` and non-positive values fall back to the
module-level :data:`QUERY_TIMEOUT_S` so a mis-configured client
cannot silently disable the watchdog.
"""
if timeout_s is None or timeout_s <= 0:
return QUERY_TIMEOUT_S
return float(timeout_s)
def _initial_schema_synopsis(spec: ScenarioSpec, synopsis: str) -> str:
"""Reset-time synopsis with future drift details removed.
Drift scenarios should not reveal the exact schema/business-rule
change before the changelog is published at runtime. We therefore
trim the authored synopsis at the first ``" Under drift"`` clause on
reset and only surface the pre-drift schema shape.
"""
if spec.drift_config is None:
return synopsis
predrift, marker, _ = synopsis.partition(" Under drift")
return predrift if marker else synopsis
def _validate_read_only_sql(sql: str) -> None:
    """Reject anything that isn't a single-statement read-only SELECT/CTE.

    This is the gate for SQL supplied through the ``run_query``,
    ``explain_query`` and ``submit_rewrite`` tools. Callers translate the
    ``ValueError`` into a typed :class:`models.ToolError` with
    :attr:`ToolErrorCode.INVALID_TOOL_ARGUMENT`.

    Checks, in order:

    * The text must parse (and tokenize) under the DuckDB dialect.
    * It must contain exactly one non-empty statement.
    * That statement's expression key must be in
      :data:`_READ_ONLY_EXPRESSION_KEYS`.
    * No function node anywhere in the tree may resolve to a denylisted
      name (see :data:`_DENYLIST_PREFIXES` / :data:`_DENYLIST_EXACT`); this
      blocks filesystem readers such as ``read_csv`` and introspection
      helpers such as ``duckdb_secrets``.
    * Every table reference backed by an identifier must be a valid
      unquoted SQL name; this blocks ``SELECT * FROM 'path/to/x.csv'``,
      which DuckDB would treat as a file path.

    Args:
        sql: The agent-supplied SQL text.

    Raises:
        ValueError: If any of the checks above fails.
    """
    try:
        statements = sqlglot.parse(sql, dialect="duckdb")
    except sqlglot.errors.SqlglotError as exc:
        # Catch the sqlglot base class, not just ``ParseError``: the
        # tokenizer raises ``TokenError`` (e.g. for an unterminated string
        # literal), which would otherwise escape the ``ValueError`` contract
        # and propagate out of ``step()`` instead of becoming a ToolError.
        raise ValueError(f"SQL failed to parse: {exc}") from exc
    non_empty = [s for s in statements if s is not None]
    if len(non_empty) != 1:
        raise ValueError("multi-statement SQL is not allowed; submit one SELECT")
    expr = non_empty[0]
    if expr.key not in _READ_ONLY_EXPRESSION_KEYS:
        raise ValueError(
            f"only read-only SELECT/CTE queries are allowed (got {expr.key.upper()} statement)"
        )
    for node in expr.walk():
        # (1) Function-valued sandbox escapes. Inspect both the class
        # name (catches ``ReadCSV`` / ``ReadParquet`` lowerings where
        # ``.name`` holds the file path, not the function name) and
        # ``.name`` (catches the generic ``Anonymous`` form).
        if isinstance(node, sqlglot.exp.Func):
            for fn_name in _function_names(node):
                if _is_denylisted_function_name(fn_name):
                    raise ValueError(
                        f"function {fn_name!r} is not allowed — agent-facing SQL may "
                        "only touch the scenario's in-memory tables"
                    )
        # (2) Bare-path FROM form: ``SELECT * FROM 'x.csv'`` or
        # ``SELECT * FROM '/etc/passwd'``. sqlglot normalises both
        # single- and double-quoted identifiers to
        # ``Identifier(quoted=True)``, so we can't rely on the quote
        # flavour to distinguish a file path from a legitimately-quoted
        # table name. Instead we require every agent-facing table name
        # to be a valid unquoted SQL identifier — the scenarios never
        # emit anything else, and paths always contain ``/``, ``.`` or
        # ``~`` which fail the identifier regex.
        if isinstance(node, sqlglot.exp.Table):
            inner = node.this
            if isinstance(inner, sqlglot.exp.Identifier):
                ident_name = inner.name
                if ident_name and not _VALID_IDENTIFIER_RE.match(ident_name):
                    raise ValueError(
                        f"table identifier {ident_name!r} is not a valid unquoted SQL "
                        "name — reading from file paths or other engine-specific "
                        "resources is not allowed"
                    )
class SqlDriftEnvironment(Environment[SqlDriftAction, SqlDriftObservation, SqlDriftState]):
    """OpenEnv environment for SQL repair + optimization under schema drift."""

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        skill_store: Store | None = None,
        cleanup_on_close: bool = False,
    ) -> None:
        """Create an environment with no active episode.

        Args:
            skill_store: Optional skill-library store used to render learned
                hints and to persist new playbook entries.
            cleanup_on_close: When True, ``close()`` deletes the skill-store
                directory. Meant for server-managed per-session stores so
                disk usage doesn't grow monotonically; see
                design/codereview.md (session store issue).
        """
        # No episode exists until reset() builds one.
        self._runtime: RuntimeEpisodeState | None = None
        self._skill_store: Store | None = skill_store
        self._cleanup_on_close: bool = cleanup_on_close

        def _rubric_context() -> RuntimeEpisodeState:
            # Resolved at call time, so the rubric always reads the runtime
            # of whichever episode is current when it scores.
            return self._require_runtime()

        super().__init__(rubric=SqlDriftRubric(ctx_provider=_rubric_context))
# ------------------------------------------------------------------
# OpenEnv contract
# ------------------------------------------------------------------
@log_env_reset
def reset(
    self,
    seed: int | None = None,
    episode_id: str | None = None,
    **kwargs: Any,
) -> SqlDriftObservation:
    """Start a new episode and return its step-0 observation.

    Args:
        seed: Episode seed; a random 31-bit value is drawn when omitted.
        episode_id: Episode identifier; derived from the seed when omitted.
        **kwargs: Validated through :class:`_ResetOptions` (``scenario_id``,
            ``enable_dba_oracle``, ``difficulty``, ``budget_steps``). A
            ``learned_hints`` entry, when present, replaces the hints that
            would otherwise be rendered from the skill library.

    Returns:
        The initial observation: DIAGNOSE phase, baseline SQL, the
        pre-drift schema synopsis, and all reward components at zero.
    """
    options = _ResetOptions.model_validate(kwargs)
    oracle_on = dba_oracle.is_enabled(options.enable_dba_oracle)

    # Fill in whatever the caller left unspecified. Episode id and scenario
    # choice both derive from the seed.
    if seed is None:
        seed = secrets.randbits(31)
    if episode_id is None:
        episode_id = f"ep-{seed:08x}"
    scenario_id = options.scenario_id
    if scenario_id is None:
        scenario_id = self._pick_scenario_for_seed(seed)

    spec = get_spec(scenario_id)
    instance = spec.materialize(seed, difficulty=options.difficulty)

    # Scenarios with drift get a scheduled step drawn from a seed-keyed RNG.
    drift_cfg = instance.drift_config
    drift_scheduled_step: int | None = None
    if drift_cfg is not None:
        drift_scheduled_step = Random(seed).randint(drift_cfg.min_step, drift_cfg.max_step)

    # The previous episode is torn down only after the new instance has
    # been materialised.
    self._close_existing_runtime()
    self._runtime = RuntimeEpisodeState(
        episode_id=episode_id,
        seed=seed,
        scenario_id=scenario_id,
        instance=instance,
        conn=instance.conn,
        gt_result_hash_predrift=instance.gt_result_hash_predrift,
        gt_result_hash_postdrift=None,
        baseline_runtime_ms=instance.baseline_runtime_ms,
        baseline_tokens=instance.baseline_tokens,
        baseline_sql_canonical=canonicalize_sql(instance.baseline_sql),
        baseline_postdrift_raises=False,
        drift_scheduled_step=drift_scheduled_step,
        budget_steps=options.budget_steps,
        dba_oracle_enabled=oracle_on,
    )
    self._reset_rubric()

    # Caller-supplied hints win; either way they are capped at 800 chars.
    hints = kwargs.get("learned_hints")
    if hints is None:
        hints = self._render_learned_hints(spec, include_drift_cards=False)
    hints = hints if len(hints) <= 800 else hints[:800]

    rt = self._require_runtime()
    return SqlDriftObservation(
        step=0,
        phase=EpisodePhase.DIAGNOSE,
        last_tool=None,
        tool_result=None,
        drift_fired=False,
        drift_acknowledged=False,
        learned_hints=hints,
        baseline_sql=instance.baseline_sql,
        schema_synopsis=_initial_schema_synopsis(spec, instance.schema_synopsis),
        budget_steps_remaining=rt.budget_steps_remaining,
        reward_components=dict.fromkeys(REWARD_COMPONENT_KEYS, 0.0),
        done=False,
        reward=None,
    )
@log_env_step
def step(
    self,
    action: SqlDriftAction,
    timeout_s: float | None = None,
    **kwargs: Any,
) -> SqlDriftObservation:
    """Execute one tool call and return the resulting observation.

    Order of operations: advance the step counter, clear the per-step
    flags, fire drift if it is due, dispatch the tool, then score the
    step through the rubric.

    Args:
        action: The tool call to execute.
        timeout_s: Optional per-step query timeout; resolved through
            :func:`_resolve_timeout_s`.
        **kwargs: Accepted and not used by this method.

    Raises:
        ValueError: If the episode has already been submitted or has no
            step budget left.
        QueryWatchdogEscalationError: Re-raised after marking the
            connection poisoned and exhausting the step budget.
    """
    rt = self._require_runtime()
    if rt.submitted or rt.budget_steps_remaining <= 0:
        raise ValueError("Episode is already finished; call reset() to start a new episode.")

    rt.step_count += 1
    # Per-step bookkeeping starts from a clean slate every step.
    rt.last_step_was_tool_error = False
    rt.last_step_was_repeat_failing_query = False
    rt.last_step_repeat_failing_query_count = 0
    rt.last_step_productive_rebate = 0.0

    # Drift is applied before the agent's action on the step where it is due.
    self._maybe_fire_drift()

    budget_s = _resolve_timeout_s(timeout_s)
    try:
        tool_result = self._dispatch(action, timeout_s=budget_s)
    except QueryWatchdogEscalationError:
        # Flag the connection as poisoned (so it is not closed later) and
        # end the episode by consuming the whole step budget.
        rt.connection_poisoned = True
        rt.phase = EpisodePhase.FINALIZE
        rt.step_count = max(rt.step_count, rt.budget_steps)
        _LOG.error("episode %s aborted after watchdog escalation", rt.episode_id)
        raise

    errored = isinstance(tool_result, ToolError)
    rt.last_step_was_tool_error = errored
    rt.consecutive_tool_errors = rt.consecutive_tool_errors + 1 if errored else 0

    done = rt.submitted or rt.budget_steps_remaining <= 0
    obs = SqlDriftObservation(
        step=rt.step_count,
        phase=rt.phase,
        last_tool=action.tool,
        tool_result=tool_result,
        drift_fired=rt.drift_fired,
        drift_acknowledged=rt.drift_acknowledged,
        learned_hints="",
        baseline_sql="",
        schema_synopsis="",
        budget_steps_remaining=rt.budget_steps_remaining,
        reward_components=dict.fromkeys(REWARD_COMPONENT_KEYS, 0.0),
        done=done,
        reward=None,
    )

    # Drift-specific hint cards are surfaced only once the agent has
    # acknowledged the drift.
    if rt.drift_acknowledged:
        obs.learned_hints = self._render_learned_hints(
            get_spec(rt.scenario_id), include_drift_cards=True
        )

    obs.reward = self._apply_rubric(action, obs)
    if self.rubric is not None:
        obs.reward_components = self.rubric.component_scores()

    if done and rt.submitted:
        self._maybe_persist_learned_entry()
    return obs
def render(self) -> dict[str, Any]:
    """Return the public state as a JSON-compatible dict and log the render.

    Raises:
        RuntimeError: If no episode is active.
    """
    rt = self._require_runtime()
    snapshot = self.state.model_dump(mode="json")
    finished = rt.submitted or rt.budget_steps_remaining <= 0
    log_interaction(
        event_type="render",
        agent_id=rt.episode_id,
        observation_returned=snapshot,
        done=finished,
    )
    return snapshot
@property
def state(self) -> SqlDriftState:
    """Public snapshot of the episode, built from a fixed field whitelist.

    Only the fields named below are copied out of the private runtime.

    Raises:
        RuntimeError: If no episode is active.
    """
    rt = self._require_runtime()
    public_fields = {
        "episode_id": rt.episode_id,
        "step_count": rt.step_count,
        "scenario_id": rt.scenario_id,
        "phase": rt.phase,
        "budget_steps_remaining": rt.budget_steps_remaining,
        "drift_fired": rt.drift_fired,
        "consultations_used": rt.consultations_used,
        "submitted": rt.submitted,
    }
    return SqlDriftState(**public_fields)
def effective_speedup(self) -> float | None:
    """Effective speedup of the current episode, or ``None`` before any reset."""
    rt = self._runtime
    # The call below resolves to the module-level ``effective_speedup``
    # imported from ``engine.reward``, not to this method.
    return None if rt is None else effective_speedup(rt)
def close(self) -> None:
    """Release the current episode and, if configured, delete the skill store.

    The store directory is removed only when ``cleanup_on_close`` was set
    at construction and a store is attached; removal errors are ignored.
    """
    self._close_existing_runtime()

    if not self._cleanup_on_close or self._skill_store is None:
        return

    import shutil

    shutil.rmtree(self._skill_store.dir, ignore_errors=True)
# ------------------------------------------------------------------
# Skill-library wiring
# ------------------------------------------------------------------
def _render_learned_hints(self, spec: ScenarioSpec, *, include_drift_cards: bool = True) -> str:
    """Render skill-library hints for ``spec``, capped at 800 characters.

    Args:
        spec: Scenario whose tags drive retrieval.
        include_drift_cards: When True and the scenario has a drift config,
            its drift kind is passed to retrieval; otherwise ``None`` is.
    """
    playbook, drift_cards = load_all(self._skill_store)

    wants_drift = include_drift_cards and spec.drift_config is not None
    drift_kind = spec.drift_config.kind if wants_drift else None

    retrieved = retrieve(
        query_tags=spec.tags,
        drift_kind=drift_kind,
        playbook=playbook,
        drift_cards=drift_cards,
    )
    return retrieved.render(max_chars=800)
def _maybe_persist_learned_entry(self) -> None:
"""Append a PlaybookEntry on terminal success with a meaningful speedup.
Failures to persist are logged but never re-raised: a training
rollout should not crash because the on-disk playbook is under
contention. The skill store itself is crash-safe (atomic writes
+ file-lock) so at-most-once semantics are sufficient here.
"""
if self._skill_store is None:
return
rt = self._require_runtime()
if not rt.submitted:
return
if self.rubric is None:
return
scores = self.rubric.component_scores()
if scores.get("r_correct", 0.0) < 1.0:
return
spec = get_spec(rt.scenario_id)
raw_speedup = effective_speedup(rt)
# effective_speedup cannot return None here — rt.submitted is True
# so submitted_runtime_ms is populated — but we guard defensively.
# ``+∞`` (drift invalidated the baseline) is capped so the on-disk
# playbook doesn't serialize ``Infinity``, which would round-trip
# as a JSON parse error on load.
if raw_speedup is None or math.isinf(raw_speedup):
speedup_val = float(SPEEDUP_CAP_FOR_INFTY)
else:
speedup_val = float(raw_speedup)
entry = PlaybookEntry(
tag_set=spec.tags,
before_snippet=rt.instance.baseline_sql[:200],
after_snippet=(rt.submitted_sql or "")[:200],
avg_speedup=speedup_val,
scenario_family=spec.family,
source="learned",
)
try:
self._skill_store.append_playbook(entry)
except Exception as exc:
_LOG.warning("skill-library append_playbook failed: %s", exc)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _grant_step_rebate_once(self, *, attr: str, rebate: float) -> None:
rt = self._require_runtime()
if getattr(rt, attr):
return
setattr(rt, attr, True)
rt.last_step_productive_rebate += rebate
def _grant_step_rebate_for_table(
self, *, rewarded_tables_attr: str, table: str, rebate: float
) -> None:
rt = self._require_runtime()
rewarded = getattr(rt, rewarded_tables_attr)
if table in rewarded:
return
rewarded.add(table)
rt.last_step_productive_rebate += rebate
@staticmethod
def _pick_scenario_for_seed(seed: int) -> str:
    """Map ``seed`` onto a scenario id, round-robin over the sorted registry.

    Raises:
        RuntimeError: If the scenario registry is empty.
    """
    scenario_ids = sorted(REGISTRY)
    if scenario_ids:
        return scenario_ids[seed % len(scenario_ids)]
    raise RuntimeError("no scenarios registered")
def _require_runtime(self) -> RuntimeEpisodeState:
if self._runtime is None:
raise RuntimeError("SqlDriftEnvironment.reset() must be called before step()/state.")
return self._runtime
def _close_existing_runtime(self) -> None:
if self._runtime is not None:
if self._runtime.connection_poisoned:
_LOG.error(
"skipping close for poisoned DuckDB connection in episode %s",
self._runtime.episode_id,
)
else:
with contextlib.suppress(duckdb.Error):
self._runtime.conn.close()
self._runtime = None
def _maybe_fire_drift(self) -> None:
"""Apply drift when the step index crosses the schedule/cooldown threshold."""
rt = self._require_runtime()
if rt.drift_fired:
return
if rt.drift_scheduled_step is None:
return
if rt.first_run_query_step is None:
return
cfg = rt.instance.drift_config
assert cfg is not None
minimum = max(rt.drift_scheduled_step, rt.first_run_query_step + cfg.cooldown_steps)
if rt.step_count < minimum:
return
self._fire_drift()
def _fire_drift(self) -> None:
    """Apply drift, publish its changelog entry, and refresh the ground truth.

    Steps:

    1. Mutate the database via ``apply_drift``.
    2. Record the firing step, enter DRIFT_RECOVERY, and append the
       authored changelog entry.
    3. Re-run the baseline SQL to record whether it now raises.
    4. If the scenario has a post-drift ground-truth query, run it and
       store the hash of its rows.

    Raises:
        RuntimeError: If the authored ``gt_sql_postdrift`` fails against
            the mutated database. This is treated as an authoring bug and
            surfaced loudly so it cannot silently make every post-drift
            submission score ``r_correct=0``.
    """
    rt = self._require_runtime()
    cfg = rt.instance.drift_config
    assert cfg is not None

    apply_drift(rt.conn, cfg.kind, cfg.payload)
    rt.drift_fired_step = rt.step_count
    rt.phase = EpisodePhase.DRIFT_RECOVERY
    rt.changelog_entries.append(author_changelog(cfg))

    # Probe the original baseline against the drifted schema.
    try:
        rt.conn.execute(rt.instance.baseline_sql).fetchall()
    except duckdb.Error:
        rt.baseline_postdrift_raises = True
    else:
        rt.baseline_postdrift_raises = False

    postdrift_sql = rt.instance.gt_sql_postdrift
    if postdrift_sql is None:
        return
    try:
        gt_rows = rt.conn.execute(postdrift_sql).fetchall()
    except duckdb.Error as exc:
        raise RuntimeError(
            f"scenario {rt.scenario_id!r}: authored gt_sql_postdrift failed "
            f"after drift: {exc}"
        ) from exc
    rt.gt_result_hash_postdrift = canonical_row_hash(gt_rows)
# ------------------------------------------------------------------
# Tool dispatch
# ------------------------------------------------------------------
def _dispatch(self, action: SqlDriftAction, *, timeout_s: float) -> ToolResult:
    """Route ``action.payload`` to its ``_handle_<tool>`` method.

    Payload types are tried in a fixed order and the first ``isinstance``
    match wins. DuckDB errors and timeouts escaping a handler are turned
    into typed :class:`ToolError` results (messages capped at 2000 chars).

    Args:
        action: The tool call whose payload selects the handler.
        timeout_s: Query budget forwarded to the handlers that run SQL.
    """
    payload = action.payload
    # (payload type, zero-argument call into the matching handler)
    routes = (
        (ListTablesPayload, lambda: self._handle_list_tables()),
        (DescribeTablePayload, lambda: self._handle_describe_table(payload)),
        (SampleRowsPayload, lambda: self._handle_sample_rows(payload)),
        (RunQueryPayload, lambda: self._handle_run_query(payload, timeout_s=timeout_s)),
        (ExplainQueryPayload, lambda: self._handle_explain_query(payload, timeout_s=timeout_s)),
        (ReadChangelogPayload, lambda: self._handle_read_changelog()),
        (SubmitRewritePayload, lambda: self._handle_submit_rewrite(payload, timeout_s=timeout_s)),
        (ConsultDBAPayload, lambda: self._handle_consult_dba(payload)),
    )
    try:
        for payload_type, run_handler in routes:
            if isinstance(payload, payload_type):
                return run_handler()
    except duckdb.Error as exc:
        return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])
    except TimeoutError as exc:
        return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])

    # Unreachable — the discriminated-union validator rejects unknown payloads.
    return ToolError(
        code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
        message=f"unknown payload type: {type(payload).__name__}",
    )
def _handle_list_tables(self) -> ListTablesResult:
    """List the tables in schema ``main``, sorted by name.

    Grants the one-time list-tables rebate and counts as a diagnostic action.
    """
    rt = self._require_runtime()
    listing = rt.conn.execute(
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = 'main' ORDER BY table_name"
    ).fetchall()
    table_names = [row[0] for row in listing]

    self._grant_step_rebate_once(attr="listed_tables_rewarded", rebate=STEP_REBATE_LIST_TABLES)
    self._mark_diagnostic()
    return ListTablesResult(tables=table_names)
def _handle_describe_table(
    self, payload: DescribeTablePayload
) -> DescribeTableResult | ToolError:
    """Return the column names and types of ``payload.table``.

    A table with no rows in ``information_schema.columns`` yields an
    ``UNKNOWN_TABLE`` error; no rebate is granted and the call does not
    count as a diagnostic action. Otherwise the per-table describe rebate
    is granted the first time each table is described.
    """
    rt = self._require_runtime()
    table = payload.table
    column_rows = rt.conn.execute(
        "SELECT column_name, data_type FROM information_schema.columns "
        "WHERE table_name = ? ORDER BY ordinal_position",
        [table],
    ).fetchall()
    if not column_rows:
        return ToolError(
            code=ToolErrorCode.UNKNOWN_TABLE,
            message=f"unknown table: {table}",
        )

    self._grant_step_rebate_for_table(
        rewarded_tables_attr="described_tables_rewarded",
        table=table,
        rebate=STEP_REBATE_DESCRIBE_TABLE,
    )
    self._mark_diagnostic()
    columns = [{"name": col_name, "type": col_type} for col_name, col_type in column_rows]
    return DescribeTableResult(table=table, columns=columns)
def _handle_sample_rows(self, payload: SampleRowsPayload) -> SampleRowsResult | ToolError:
    """Return up to ``payload.limit`` rows from ``payload.table``.

    The table must be listed in ``information_schema.tables``; otherwise
    an ``UNKNOWN_TABLE`` error is returned and nothing is rewarded. On
    success the per-table sample rebate is granted the first time each
    table is sampled, and the call counts as a diagnostic action.

    Args:
        payload: Carries the table name and the row limit.

    Returns:
        The sampled rows with their column names, or a ``ToolError``.
    """
    rt = self._require_runtime()
    exists = rt.conn.execute(
        "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = ?",
        [payload.table],
    ).fetchone()
    if not exists or exists[0] == 0:
        return ToolError(
            code=ToolErrorCode.UNKNOWN_TABLE,
            message=f"unknown table: {payload.table}",
        )

    # An identifier cannot be bound as a query parameter, so it is quoted by
    # hand. Embedded double quotes are doubled (the standard SQL escape) so
    # a table name containing ``"`` cannot end the quoted identifier early.
    # The limit is coerced to ``int`` so only a number reaches the SQL text.
    quoted_table = '"' + payload.table.replace('"', '""') + '"'
    row_cap = int(payload.limit)
    cur = rt.conn.execute(f"SELECT * FROM {quoted_table} LIMIT {row_cap}")
    columns = [d[0] for d in cur.description] if cur.description else []
    rows = [list(r) for r in cur.fetchall()]

    self._grant_step_rebate_for_table(
        rewarded_tables_attr="sampled_tables_rewarded",
        table=payload.table,
        rebate=STEP_REBATE_SAMPLE_ROWS,
    )
    self._mark_diagnostic()
    return SampleRowsResult(table=payload.table, columns=columns, rows=rows)
def _handle_run_query(
    self, payload: RunQueryPayload, *, timeout_s: float
) -> RunQueryResult | ToolError:
    """Validate and execute an agent query, returning rows or a typed error.

    Outcomes:

    * Fails validation  -> ``INVALID_TOOL_ARGUMENT``.
    * Times out         -> ``QUERY_TIMEOUT``.
    * DuckDB error      -> ``DB_ERROR``, and the failure is tallied so
      repeats of the same broken query can be detected.
    * Result truncated at ``MAX_RESULT_ROWS`` -> ``RESULT_TOO_LARGE``.
    * Otherwise         -> rows, columns and runtime; grants the one-time
      run-query rebate and counts as a diagnostic action.
    """
    rt = self._require_runtime()
    query = payload.sql
    try:
        _validate_read_only_sql(query)
    except ValueError as exc:
        return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])

    # Drift timing: the first *valid* run_query attempt is recorded here,
    # before execution, so that a later truncation, DB error or timeout
    # cannot keep drift from firing in subsequent steps.
    if rt.first_run_query_step is None:
        rt.first_run_query_step = rt.step_count

    try:
        outcome = execute_once_with_columns(
            rt.conn, query, timeout_s=timeout_s, max_rows=MAX_RESULT_ROWS
        )
    except TimeoutError as exc:
        return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])
    except duckdb.Error as exc:
        # The query is canonicalised before hashing so that variants of
        # the same broken query are counted as one repeat offence.
        fingerprint = canonical_row_hash([(canonicalize_sql(query),)])
        attempts = rt.failed_query_counts.get(fingerprint, 0) + 1
        rt.failed_query_counts[fingerprint] = attempts
        rt.failed_query_hashes.add(fingerprint)
        rt.last_step_repeat_failing_query_count = attempts
        rt.last_step_was_repeat_failing_query = attempts > 1
        return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])

    if outcome.truncated:
        return ToolError(
            code=ToolErrorCode.RESULT_TOO_LARGE,
            message=(
                f"result exceeded {MAX_RESULT_ROWS}-row cap — narrow the "
                "projection, add a LIMIT, or aggregate"
            ),
        )

    self._grant_step_rebate_once(attr="run_query_rewarded", rebate=STEP_REBATE_RUN_QUERY)
    self._mark_diagnostic()
    row_lists = [list(row) for row in outcome.rows]
    return RunQueryResult(
        columns=outcome.columns,
        rows=row_lists,
        runtime_ms=outcome.elapsed_ms,
        row_count=len(outcome.rows),
    )
def _handle_explain_query(
    self, payload: ExplainQueryPayload, *, timeout_s: float
) -> ExplainQueryResult | ToolError:
    """Return the ``EXPLAIN`` plan for an agent query, capped at 10,000 chars.

    SQL that fails validation yields ``INVALID_TOOL_ARGUMENT``. DuckDB
    errors and timeouts are not caught here and propagate to the caller.
    On success the one-time explain rebate is granted and the call counts
    as a diagnostic action.
    """
    rt = self._require_runtime()
    try:
        _validate_read_only_sql(payload.sql)
    except ValueError as exc:
        return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])

    # EXPLAIN is plan-only (no data materialisation), but it still goes
    # through the timed executor so a pathological query cannot run past
    # the caller's wall-clock deadline.
    plan_rows, _elapsed = execute_once_timed(
        rt.conn, f"EXPLAIN {payload.sql}", timeout_s=timeout_s
    )
    # The plan text is the last column of each row; empty rows become "".
    plan_lines = [str(row[-1]) if row else "" for row in plan_rows]
    plan_text = "\n".join(plan_lines)

    self._grant_step_rebate_once(
        attr="explain_query_rewarded", rebate=STEP_REBATE_EXPLAIN_QUERY
    )
    self._mark_diagnostic()
    return ExplainQueryResult(plan=plan_text[:10_000])
def _handle_read_changelog(self) -> ReadChangelogResult:
    """Return a copy of the changelog entries published so far.

    Reading a non-empty changelog marks the drift as acknowledged and
    grants the one-time changelog rebate. The call always counts as a
    diagnostic action, even when the changelog is still empty.
    """
    rt = self._require_runtime()
    published = rt.changelog_entries
    if published:
        rt.drift_acknowledged = True
        self._grant_step_rebate_once(
            attr="changelog_rewarded_after_drift",
            rebate=STEP_REBATE_READ_CHANGELOG,
        )
    self._mark_diagnostic()
    return ReadChangelogResult(entries=list(published))
def _handle_submit_rewrite(
    self, payload: SubmitRewritePayload, *, timeout_s: float
) -> SubmitRewriteResult | ToolError:
    """Execute the agent's final rewrite and record the submission.

    The submission is rejected without ending the episode when no
    diagnostic action has been taken yet, when the SQL fails validation,
    or when executing it times out or raises a DuckDB error. Otherwise
    the result hash is compared with the ground truth (post-drift when
    drift has fired and a post-drift hash exists, else pre-drift), the
    submission is recorded on the runtime, and the phase moves to FINALIZE.
    """
    rt = self._require_runtime()
    if not rt.diagnostic_actions_taken:
        return ToolError(
            code=ToolErrorCode.SUBMIT_BEFORE_DIAGNOSE,
            message=(
                "submit_rewrite rejected: the agent must take at least one "
                "diagnostic action (list_tables, describe_table, sample_rows, "
                "run_query, explain_query, or read_changelog) before submitting."
            ),
        )

    rewrite = payload.sql
    try:
        _validate_read_only_sql(rewrite)
    except ValueError as exc:
        return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])

    try:
        result_hash, runtime_ms = execute_hash_timed(rt.conn, rewrite, timeout_s=timeout_s)
    except TimeoutError as exc:
        return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])
    except duckdb.Error as exc:
        return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])

    # Pick the ground truth that applies to the database as it stands now.
    if rt.drift_fired and rt.gt_result_hash_postdrift is not None:
        expected_hash = rt.gt_result_hash_postdrift
    else:
        expected_hash = rt.gt_result_hash_predrift

    rt.submitted = True
    rt.submitted_sql = rewrite
    rt.submitted_sql_canonical = canonicalize_sql(rewrite)
    rt.submitted_result_hash = result_hash
    rt.submitted_runtime_ms = runtime_ms
    rt.phase = EpisodePhase.FINALIZE

    return SubmitRewriteResult(
        accepted=True,
        runtime_ms=runtime_ms,
        matches_ground_truth=result_hash == expected_hash,
    )
def _handle_consult_dba(self, payload: ConsultDBAPayload) -> ConsultDBAResult | ToolError:
    """Return the next DBA hint for the current scenario.

    Each successful call increments the consultation counter; the hint
    tier equals that counter, capped at 3. The call is rejected with
    ``INVALID_TOOL_ARGUMENT`` when the oracle is disabled for the episode
    or the scenario has no hints registered.
    """
    # The question is free-text context only; hints are scenario-keyed.
    del payload

    rt = self._require_runtime()
    if not rt.dba_oracle_enabled:
        return ToolError(
            code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
            message="consult_dba disabled — set enable_dba_oracle=True at reset()",
        )

    scenario_id = rt.scenario_id
    if not dba_oracle.has_hints(scenario_id):
        return ToolError(
            code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
            message=f"no DBA hints registered for scenario={scenario_id!r}",
        )

    rt.consultations_used += 1
    tier = min(rt.consultations_used, 3)
    return ConsultDBAResult(tier=tier, hint=dba_oracle.get_hint(scenario_id, tier))
def _mark_diagnostic(self) -> None:
"""Record a successful diagnostic tool call and advance the phase machine."""
rt = self._require_runtime()
rt.diagnostic_actions_taken += 1
if rt.phase == EpisodePhase.DIAGNOSE:
rt.phase = EpisodePhase.REWRITE
# Public API of this module: the environment class plus the three
# settings-derived defaults defined near the top of the file.
__all__ = [
    "DEFAULT_STEP_BUDGET",
    "MAX_RESULT_ROWS",
    "QUERY_TIMEOUT_S",
    "SqlDriftEnvironment",
]