Spaces:
Sleeping
Sleeping
"""OpenEnv ``Environment`` implementation for SQLDrift.

Responsibilities:

* Own the private :class:`engine.runtime.RuntimeEpisodeState` and the
  composite :class:`engine.reward.SqlDriftRubric` for the current episode.
* Dispatch each of the eight tool-call payloads to a dedicated
  ``_handle_<tool>`` method that returns a typed
  :class:`models.ToolResult` (or :class:`models.ToolError`).
* Fire drift on a schedule blended with a cooldown: ``max(scheduled,
  first_run_query_step + cooldown)`` before the agent acts on the step
  where drift applies, then recompute the post-drift ground truth hash.
* Publish public observations (:class:`models.SqlDriftObservation`) and a
  strictly sanitised public state snapshot (:class:`models.SqlDriftState`).

Privacy: ``self._runtime`` holds the DuckDB handle, ground-truth hashes,
baseline runtime, and seed. They stay inside this class; the rubric reads
them via a closure, and ``env.state`` exposes only a fixed whitelist of fields.
"""
| from __future__ import annotations | |
| import contextlib | |
| import math | |
| import re | |
| import secrets | |
| from random import Random | |
| from typing import TYPE_CHECKING, Any, Literal | |
| import duckdb | |
| import sqlglot | |
| from openenv.core.env_server.interfaces import Environment | |
| from pydantic import BaseModel, ConfigDict, Field | |
| from actors import dba_oracle | |
| from actors.engineering_manager import author_changelog | |
| from engine.drift import apply_drift | |
| from engine.profiler import ( | |
| QueryWatchdogEscalationError, | |
| execute_hash_timed, | |
| execute_once_timed, | |
| execute_once_with_columns, | |
| ) | |
| from engine.reward import ( | |
| SPEEDUP_CAP_FOR_INFTY, | |
| STEP_REBATE_DESCRIBE_TABLE, | |
| STEP_REBATE_EXPLAIN_QUERY, | |
| STEP_REBATE_LIST_TABLES, | |
| STEP_REBATE_READ_CHANGELOG, | |
| STEP_REBATE_RUN_QUERY, | |
| STEP_REBATE_SAMPLE_ROWS, | |
| SqlDriftRubric, | |
| canonicalize_sql, | |
| effective_speedup, | |
| ) | |
| from engine.runtime import RuntimeEpisodeState | |
| from engine.verifier import canonical_row_hash | |
| from models import ( | |
| REWARD_COMPONENT_KEYS, | |
| ConsultDBAPayload, | |
| ConsultDBAResult, | |
| DescribeTablePayload, | |
| DescribeTableResult, | |
| EpisodePhase, | |
| ExplainQueryPayload, | |
| ExplainQueryResult, | |
| ListTablesPayload, | |
| ListTablesResult, | |
| ReadChangelogPayload, | |
| ReadChangelogResult, | |
| RunQueryPayload, | |
| RunQueryResult, | |
| SampleRowsPayload, | |
| SampleRowsResult, | |
| SqlDriftAction, | |
| SqlDriftObservation, | |
| SqlDriftState, | |
| SubmitRewritePayload, | |
| SubmitRewriteResult, | |
| ToolError, | |
| ToolErrorCode, | |
| ToolResult, | |
| ) | |
| from scenarios import REGISTRY, get_spec | |
| from skill_library import PlaybookEntry, Store, load_all, retrieve | |
| from utilities.logger import get_module_logger, log_env_reset, log_env_step, log_interaction | |
| from . import settings | |
| if TYPE_CHECKING: | |
| from scenarios.base import ScenarioSpec | |
# Module-level logger, keyed by this module's import name.
_LOG = get_module_logger(__name__)

# Tunables mirrored from ``settings`` so callers can import them from this
# module (all three are listed in ``__all__``).
# Default for ``_ResetOptions.budget_steps``.
DEFAULT_STEP_BUDGET: int = settings.DEFAULT_STEP_BUDGET
# Passed as ``max_rows`` when a ``run_query`` result is fetched.
MAX_RESULT_ROWS: int = settings.MAX_RESULT_ROWS
# Fallback per-step query timeout used by ``_resolve_timeout_s``.
QUERY_TIMEOUT_S: float = settings.QUERY_TIMEOUT_S
class _ResetOptions(BaseModel):
    """Validated view of the keyword options passed to ``reset()``.

    ``reset()`` feeds its ``**kwargs`` through ``model_validate``; keys that
    are not declared here are ignored rather than rejected
    (``extra="ignore"``).
    """

    model_config = ConfigDict(extra="ignore")
    # Scenario to load; ``reset()`` derives one from the seed when this is None.
    scenario_id: str | None = None
    # Tri-state flag handed to ``dba_oracle.is_enabled`` by ``reset()``.
    enable_dba_oracle: bool | None = None
    # Forwarded to ``spec.materialize(seed, difficulty=...)``.
    difficulty: Literal["easy", "normal", "hard"] = "normal"
    # Step budget for the episode; validation rejects values below 1.
    budget_steps: int = Field(default=DEFAULT_STEP_BUDGET, ge=1)
| _READ_ONLY_EXPRESSION_KEYS: frozenset[str] = frozenset({"select", "with"}) | |
| # DuckDB exposes a family of table-valued functions and scalar helpers | |
| # that read from the host filesystem or leak introspection state — | |
| # ``read_csv``, ``read_parquet``, ``read_json``, ``read_text``, | |
| # ``parquet_metadata``, ``duckdb_secrets``, ``glob``, etc. They are | |
| # *technically* SELECT-shaped calls so the statement-key check alone | |
| # admits them. We reject any function whose lowercased name starts with | |
| # one of these prefixes or exactly matches one of the known-dangerous | |
| # standalone names. Agent-facing SQL has no legitimate need for any of | |
| # them — the DuckDB connection is pre-populated by the scenario builder. | |
| _DENYLIST_PREFIXES: tuple[str, ...] = ( | |
| "read_", | |
| "write_", | |
| "copy_", | |
| "duckdb_", | |
| "pragma_", | |
| "sniff_", | |
| "parquet_", | |
| "arrow_", | |
| "json_table", | |
| "json_each", | |
| "sqlite_", | |
| "load_", | |
| "install_", | |
| ) | |
| _DENYLIST_EXACT: frozenset[str] = frozenset( | |
| { | |
| "glob", | |
| "attach", | |
| "detach", | |
| "checkpoint", | |
| "force_checkpoint", | |
| "set_secret", | |
| "create_secret", | |
| "drop_secret", | |
| "enable_profiling", | |
| "disable_profiling", | |
| "enable_object_cache", | |
| } | |
| ) | |
| def _is_denylisted_function_name(name: str) -> bool: | |
| """Return True iff ``name`` (case-insensitively) matches a sandbox-escape.""" | |
| lowered = name.lower() | |
| if lowered in _DENYLIST_EXACT: | |
| return True | |
| return any(lowered.startswith(p) for p in _DENYLIST_PREFIXES) | |
def _function_names(node: sqlglot.exp.Func) -> list[str]:
    """All plausible names to check against the denylist for one AST node.

    sqlglot lowers a few DuckDB calls into dedicated expression classes
    (``ReadCSV``, ``ReadParquet``, …) whose ``.name`` is actually the
    first positional arg — the file path — not the function name. We
    recover the function name from the class name in that case and use
    ``.name`` only for the anonymous form, where it really is the
    function name.

    Args:
        node: A function-call expression from the parsed statement.

    Returns:
        Candidate names in check order: the snake-cased class name first
        (when the class name starts with an uppercase letter), then the
        node's own function name when the node stores one.
    """
    cls = type(node).__name__
    out: list[str] = []
    # Derive a snake-case function name from the class name. We insert
    # an underscore at two kinds of CamelCase boundaries:
    #
    #   * ``aB``  — normal lower-to-upper (``ReadParquet`` → ``read_parquet``)
    #   * ``ABc`` — end of an acronym run (``ReadCSVAuto`` → ``read_csv_auto``)
    #
    # The generic ``Anonymous`` class yields ``"anonymous"``, which matches
    # no denylist entry; those nodes are covered by ``.name`` below.
    if cls and cls[0].isupper():
        snake = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", cls).lower()
        out.append(snake)
    # ``.name`` is the text of the node's ``this`` argument. That is the
    # function name only when the node keeps its name there: ``Anonymous``
    # nodes (raw string, or an ``Identifier`` for a quoted function name)
    # and any other node whose ``this`` is a raw string. For typed nodes
    # ``this`` is the first *argument*, so checking it would reject
    # harmless calls such as ``LOWER('glob')`` as if they invoked ``glob``.
    raw_this = getattr(node, "args", {}).get("this")
    if isinstance(raw_this, str) or isinstance(node, sqlglot.exp.Anonymous):
        name_attr = getattr(node, "name", None)
        if isinstance(name_attr, str) and name_attr:
            out.append(name_attr)
    return out
| _VALID_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") | |
| def _resolve_timeout_s(timeout_s: float | None) -> float: | |
| """Caller-supplied per-step timeout or the module default. | |
| ``timeout_s`` is accepted on every OpenEnv ``step()`` (the abstract | |
| base mandates the keyword). When the caller provides a positive | |
| value we honour it as the wall-clock budget for any DuckDB query | |
| this step runs; ``None`` and non-positive values fall back to the | |
| module-level :data:`QUERY_TIMEOUT_S` so a mis-configured client | |
| cannot silently disable the watchdog. | |
| """ | |
| if timeout_s is None or timeout_s <= 0: | |
| return QUERY_TIMEOUT_S | |
| return float(timeout_s) | |
| def _initial_schema_synopsis(spec: ScenarioSpec, synopsis: str) -> str: | |
| """Reset-time synopsis with future drift details removed. | |
| Drift scenarios should not reveal the exact schema/business-rule | |
| change before the changelog is published at runtime. We therefore | |
| trim the authored synopsis at the first ``" Under drift"`` clause on | |
| reset and only surface the pre-drift schema shape. | |
| """ | |
| if spec.drift_config is None: | |
| return synopsis | |
| predrift, marker, _ = synopsis.partition(" Under drift") | |
| return predrift if marker else synopsis | |
def _validate_read_only_sql(sql: str) -> None:
    """Reject anything that isn't a single-statement read-only SELECT/CTE.

    Raises ``ValueError`` so the caller can translate to a typed
    :class:`models.ToolError` with :attr:`ToolErrorCode.INVALID_TOOL_ARGUMENT`.
    This is the only place that mediates what the policy may execute;
    scenario builders and drift DDL call DuckDB directly with privileged
    SQL and deliberately bypass this check.

    Beyond the statement-level gate, this walker also rejects two
    sandbox-escape vectors that would otherwise ride along inside a
    perfectly-shaped SELECT:

    1. Table-valued functions that read from the host filesystem
       (``read_csv``, ``read_parquet``, ``read_json_auto``, ``glob``,
       ``read_text``, …) or leak engine introspection (``duckdb_secrets``
       carries credentials; ``duckdb_settings`` / ``duckdb_functions``
       can enumerate available exploits). See :data:`_DENYLIST_PREFIXES`
       / :data:`_DENYLIST_EXACT`.
    2. ``SELECT * FROM 'path/to/x.csv'`` — DuckDB treats a bare string
       literal in a FROM clause as a filesystem path and auto-detects
       the format. There is no function node to inspect in this form,
       so we separately reject any :class:`sqlglot.exp.Table` whose
       name is not a plain unquoted-style identifier.

    Args:
        sql: Agent-supplied SQL text.

    Raises:
        ValueError: If the SQL cannot be tokenized or parsed, is empty,
            contains more than one statement, is not a SELECT/CTE, calls
            a denylisted function, or references a table whose name is
            not a valid unquoted identifier.
    """
    try:
        statements = sqlglot.parse(sql, dialect="duckdb")
    except sqlglot.errors.SqlglotError as exc:
        # Catch the shared base class: tokenizer failures (for example an
        # unterminated string literal) raise ``TokenError``, which is a
        # sibling of ``ParseError`` and would otherwise escape as an
        # unhandled exception instead of a typed tool error.
        raise ValueError(f"SQL failed to parse: {exc}") from exc

    non_empty = [s for s in statements if s is not None]
    if not non_empty:
        raise ValueError("SQL is empty; submit one SELECT")
    if len(non_empty) != 1:
        raise ValueError("multi-statement SQL is not allowed; submit one SELECT")

    expr = non_empty[0]
    if expr.key not in _READ_ONLY_EXPRESSION_KEYS:
        raise ValueError(
            f"only read-only SELECT/CTE queries are allowed (got {expr.key.upper()} statement)"
        )

    for node in expr.walk():
        # (1) Function-valued sandbox escapes. Inspect both the class
        # name (catches ``ReadCSV`` / ``ReadParquet`` lowerings where
        # ``.name`` holds the file path, not the function name) and
        # ``.name`` (catches the generic ``Anonymous`` form).
        if isinstance(node, sqlglot.exp.Func):
            for fn_name in _function_names(node):
                if _is_denylisted_function_name(fn_name):
                    raise ValueError(
                        f"function {fn_name!r} is not allowed — agent-facing SQL may "
                        "only touch the scenario's in-memory tables"
                    )
        # (2) Bare-path FROM form: ``SELECT * FROM 'x.csv'`` or
        # ``SELECT * FROM '/etc/passwd'``. sqlglot normalises both
        # single- and double-quoted identifiers to
        # ``Identifier(quoted=True)``, so we can't rely on the quote
        # flavour to distinguish a file path from a legitimately-quoted
        # table name. Instead we require every agent-facing table name
        # to be a valid unquoted SQL identifier — the scenarios never
        # emit anything else, and paths always contain ``/``, ``.`` or
        # ``~`` which fail the identifier regex.
        if isinstance(node, sqlglot.exp.Table):
            inner = node.this
            if isinstance(inner, sqlglot.exp.Identifier):
                ident_name = inner.name
                if ident_name and not _VALID_IDENTIFIER_RE.match(ident_name):
                    raise ValueError(
                        f"table identifier {ident_name!r} is not a valid unquoted SQL "
                        "name — reading from file paths or other engine-specific "
                        "resources is not allowed"
                    )
class SqlDriftEnvironment(Environment[SqlDriftAction, SqlDriftObservation, SqlDriftState]):
    """OpenEnv environment for SQL repair + optimization under schema drift.

    One instance drives one episode at a time. ``reset()`` materialises a
    scenario into a private :class:`RuntimeEpisodeState`; ``step()`` fires
    drift when due, dispatches the action's payload to a ``_handle_<tool>``
    method, and scores the step through the rubric. Only the fields copied
    into :class:`SqlDriftState` by ``state`` are exposed publicly.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        skill_store: Store | None = None,
        cleanup_on_close: bool = False,
    ) -> None:
        """Create an environment with no active episode.

        Args:
            skill_store: Optional skill-library store used to render learned
                hints and to persist successful rewrites.
            cleanup_on_close: When True, ``close()`` deletes the skill
                store's directory.
        """
        self._runtime: RuntimeEpisodeState | None = None
        self._skill_store: Store | None = skill_store
        # When True, the skill-store directory is deleted when close() is called.
        # Set this for server-managed per-session stores so disk usage doesn't grow
        # monotonically; see design/codereview.md (session store issue).
        self._cleanup_on_close: bool = cleanup_on_close
        # The rubric pulls the private runtime lazily through this closure,
        # so it always sees the episode that is current at scoring time.
        super().__init__(
            rubric=SqlDriftRubric(ctx_provider=lambda: self._require_runtime()),
        )

    # ------------------------------------------------------------------
    # OpenEnv contract
    # ------------------------------------------------------------------
    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> SqlDriftObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Episode seed; a random 31-bit value is drawn when omitted.
            episode_id: Episode identifier; derived from the seed when omitted.
            **kwargs: Options validated by :class:`_ResetOptions`
                (``scenario_id``, ``enable_dba_oracle``, ``difficulty``,
                ``budget_steps``) plus an optional ``learned_hints`` string
                that overrides the skill-library rendering.

        Returns:
            The step-0 observation carrying the baseline SQL, the pre-drift
            schema synopsis and the learned hints (capped at 800 characters).
        """
        options = _ResetOptions.model_validate(kwargs)
        scenario_id = options.scenario_id
        enable_dba_oracle = dba_oracle.is_enabled(options.enable_dba_oracle)
        difficulty = options.difficulty
        budget_steps = options.budget_steps

        if seed is None:
            seed = secrets.randbits(31)
        if episode_id is None:
            episode_id = f"ep-{seed:08x}"
        if scenario_id is None:
            scenario_id = self._pick_scenario_for_seed(seed)

        spec = get_spec(scenario_id)
        instance = spec.materialize(seed, difficulty=difficulty)

        # The drift step is a pure function of the seed and the scenario's
        # configured [min_step, max_step] window.
        drift_scheduled_step: int | None = None
        if instance.drift_config is not None:
            drift_scheduled_step = Random(seed).randint(
                instance.drift_config.min_step,
                instance.drift_config.max_step,
            )

        # The previous episode is torn down only after the new scenario has
        # materialised successfully.
        self._close_existing_runtime()
        self._runtime = RuntimeEpisodeState(
            episode_id=episode_id,
            seed=seed,
            scenario_id=scenario_id,
            instance=instance,
            conn=instance.conn,
            gt_result_hash_predrift=instance.gt_result_hash_predrift,
            gt_result_hash_postdrift=None,
            baseline_runtime_ms=instance.baseline_runtime_ms,
            baseline_tokens=instance.baseline_tokens,
            baseline_sql_canonical=canonicalize_sql(instance.baseline_sql),
            baseline_postdrift_raises=False,
            drift_scheduled_step=drift_scheduled_step,
            budget_steps=budget_steps,
            dba_oracle_enabled=enable_dba_oracle,
        )
        self._reset_rubric()

        # ``learned_hints`` is read from the raw kwargs; drift cards are
        # withheld at reset.
        learned_hints = kwargs.get("learned_hints")
        if learned_hints is None:
            learned_hints = self._render_learned_hints(spec, include_drift_cards=False)
        if len(learned_hints) > 800:
            learned_hints = learned_hints[:800]

        rt = self._require_runtime()
        return SqlDriftObservation(
            step=0,
            phase=EpisodePhase.DIAGNOSE,
            last_tool=None,
            tool_result=None,
            drift_fired=False,
            drift_acknowledged=False,
            learned_hints=learned_hints,
            baseline_sql=instance.baseline_sql,
            schema_synopsis=_initial_schema_synopsis(spec, instance.schema_synopsis),
            budget_steps_remaining=rt.budget_steps_remaining,
            reward_components={key: 0.0 for key in REWARD_COMPONENT_KEYS},
            done=False,
            reward=None,
        )

    def step(
        self,
        action: SqlDriftAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> SqlDriftObservation:
        """Execute one tool call and return the resulting observation.

        The step counter is advanced and the per-step flags are cleared
        first; drift is then applied if it is due, *before* the action is
        dispatched. The rubric scores the step after the observation is
        assembled.

        Args:
            action: The tool call to execute.
            timeout_s: Optional wall-clock budget for queries run by this
                step; see :func:`_resolve_timeout_s` for the fallback.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The observation for this step, including reward and components.

        Raises:
            ValueError: If the episode has already finished.
            RuntimeError: If ``reset()`` has not been called.
            QueryWatchdogEscalationError: Re-raised after the episode has
                been marked poisoned and its budget exhausted.
        """
        rt = self._require_runtime()
        if rt.submitted or rt.budget_steps_remaining <= 0:
            raise ValueError("Episode is already finished; call reset() to start a new episode.")

        rt.step_count += 1
        rt.last_step_was_tool_error = False
        rt.last_step_was_repeat_failing_query = False
        rt.last_step_repeat_failing_query_count = 0
        rt.last_step_productive_rebate = 0.0

        self._maybe_fire_drift()
        effective_timeout_s = _resolve_timeout_s(timeout_s)
        try:
            tool_result = self._dispatch(action, timeout_s=effective_timeout_s)
        except QueryWatchdogEscalationError:
            # The connection can no longer be trusted: flag it so close()
            # skips it, and push the step counter to the budget so further
            # step() calls are refused.
            rt.connection_poisoned = True
            rt.phase = EpisodePhase.FINALIZE
            rt.step_count = max(rt.step_count, rt.budget_steps)
            _LOG.error("episode %s aborted after watchdog escalation", rt.episode_id)
            raise

        rt.last_step_was_tool_error = isinstance(tool_result, ToolError)
        if rt.last_step_was_tool_error:
            rt.consecutive_tool_errors += 1
        else:
            rt.consecutive_tool_errors = 0

        done = rt.submitted or rt.budget_steps_remaining <= 0
        obs = SqlDriftObservation(
            step=rt.step_count,
            phase=rt.phase,
            last_tool=action.tool,
            tool_result=tool_result,
            drift_fired=rt.drift_fired,
            drift_acknowledged=rt.drift_acknowledged,
            learned_hints="",
            baseline_sql="",
            schema_synopsis="",
            budget_steps_remaining=rt.budget_steps_remaining,
            reward_components={key: 0.0 for key in REWARD_COMPONENT_KEYS},
            done=done,
            reward=None,
        )
        # Drift cards are surfaced only once the agent has read a changelog.
        if rt.drift_acknowledged:
            spec = get_spec(rt.scenario_id)
            obs.learned_hints = self._render_learned_hints(spec, include_drift_cards=True)

        obs.reward = self._apply_rubric(action, obs)
        if self.rubric is not None:
            obs.reward_components = self.rubric.component_scores()
        if done and rt.submitted:
            self._maybe_persist_learned_entry()
        return obs

    def render(self) -> dict[str, Any]:
        """Render the current public state and log the render interaction.

        Returns:
            The JSON-mode dump of the sanitised :class:`SqlDriftState`.
        """
        rt = self._require_runtime()
        state = self.state
        payload = state.model_dump(mode="json")
        log_interaction(
            event_type="render",
            agent_id=rt.episode_id,
            observation_returned=payload,
            done=rt.submitted or rt.budget_steps_remaining <= 0,
        )
        return payload

    # ``state`` is read as an attribute (``self.state`` in ``render()``), so
    # it must be a property rather than a plain method.
    @property
    def state(self) -> SqlDriftState:
        """Sanitised public state snapshot (explicit whitelist)."""
        rt = self._require_runtime()
        return SqlDriftState(
            episode_id=rt.episode_id,
            step_count=rt.step_count,
            scenario_id=rt.scenario_id,
            phase=rt.phase,
            budget_steps_remaining=rt.budget_steps_remaining,
            drift_fired=rt.drift_fired,
            consultations_used=rt.consultations_used,
            submitted=rt.submitted,
        )

    def effective_speedup(self) -> float | None:
        """Return the current episode's effective speedup, if any.

        Returns ``None`` when no episode is active.
        """
        rt = self._runtime
        if rt is None:
            return None
        # Inside a method body this name resolves to the module-level
        # ``engine.reward.effective_speedup`` import, not to this method.
        return effective_speedup(rt)

    def close(self) -> None:
        """Release the active episode and, if configured, the skill store directory."""
        self._close_existing_runtime()
        if self._cleanup_on_close and self._skill_store is not None:
            import shutil

            store_dir = self._skill_store.dir
            shutil.rmtree(store_dir, ignore_errors=True)

    # ------------------------------------------------------------------
    # Skill-library wiring
    # ------------------------------------------------------------------
    def _render_learned_hints(self, spec: ScenarioSpec, *, include_drift_cards: bool = True) -> str:
        """Render skill-library hints for ``spec``, capped at 800 characters.

        Drift cards are requested only when ``include_drift_cards`` is True
        and the scenario has a drift configuration.
        """
        playbook, drift_cards = load_all(self._skill_store)
        drift_kind = None
        if include_drift_cards and spec.drift_config is not None:
            drift_kind = spec.drift_config.kind
        result = retrieve(
            query_tags=spec.tags,
            drift_kind=drift_kind,
            playbook=playbook,
            drift_cards=drift_cards,
        )
        return result.render(max_chars=800)

    def _maybe_persist_learned_entry(self) -> None:
        """Append a PlaybookEntry on terminal success with a meaningful speedup.

        Nothing is written unless a skill store is configured, the episode
        was submitted, and the rubric's ``r_correct`` component is at
        least 1.0.

        Failures to persist are logged but never re-raised: a training
        rollout should not crash because the on-disk playbook is under
        contention. The skill store itself is crash-safe (atomic writes
        + file-lock) so at-most-once semantics are sufficient here.
        """
        if self._skill_store is None:
            return
        rt = self._require_runtime()
        if not rt.submitted:
            return
        if self.rubric is None:
            return
        scores = self.rubric.component_scores()
        if scores.get("r_correct", 0.0) < 1.0:
            return

        spec = get_spec(rt.scenario_id)
        raw_speedup = effective_speedup(rt)
        # effective_speedup cannot return None here — rt.submitted is True
        # so submitted_runtime_ms is populated — but we guard defensively.
        # ``+∞`` (drift invalidated the baseline) is capped so the on-disk
        # playbook doesn't serialize ``Infinity``, which would round-trip
        # as a JSON parse error on load.
        if raw_speedup is None or math.isinf(raw_speedup):
            speedup_val = float(SPEEDUP_CAP_FOR_INFTY)
        else:
            speedup_val = float(raw_speedup)

        entry = PlaybookEntry(
            tag_set=spec.tags,
            before_snippet=rt.instance.baseline_sql[:200],
            after_snippet=(rt.submitted_sql or "")[:200],
            avg_speedup=speedup_val,
            scenario_family=spec.family,
            source="learned",
        )
        try:
            self._skill_store.append_playbook(entry)
        except Exception as exc:
            # Deliberate best-effort boundary; see the docstring.
            _LOG.warning("skill-library append_playbook failed: %s", exc)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _grant_step_rebate_once(self, *, attr: str, rebate: float) -> None:
        """Add ``rebate`` to this step's rebate the first time ``attr`` is unset.

        ``attr`` names a boolean flag on the runtime state; it is set to
        True on the first grant so later calls are no-ops.
        """
        rt = self._require_runtime()
        if getattr(rt, attr):
            return
        setattr(rt, attr, True)
        rt.last_step_productive_rebate += rebate

    def _grant_step_rebate_for_table(
        self, *, rewarded_tables_attr: str, table: str, rebate: float
    ) -> None:
        """Add ``rebate`` once per distinct ``table``.

        ``rewarded_tables_attr`` names a collection on the runtime state
        that records which tables have already been rewarded.
        """
        rt = self._require_runtime()
        rewarded = getattr(rt, rewarded_tables_attr)
        if table in rewarded:
            return
        rewarded.add(table)
        rt.last_step_productive_rebate += rebate

    # Takes no ``self`` and is invoked as ``self._pick_scenario_for_seed(seed)``,
    # so it has to be a static method.
    @staticmethod
    def _pick_scenario_for_seed(seed: int) -> str:
        """Deterministic round-robin over the sorted scenario registry.

        Raises:
            RuntimeError: If no scenarios are registered.
        """
        ids = sorted(REGISTRY)
        if not ids:
            raise RuntimeError("no scenarios registered")
        return ids[seed % len(ids)]

    def _require_runtime(self) -> RuntimeEpisodeState:
        """Return the active runtime state.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._runtime is None:
            raise RuntimeError("SqlDriftEnvironment.reset() must be called before step()/state.")
        return self._runtime

    def _close_existing_runtime(self) -> None:
        """Close the active episode's DuckDB connection and forget the runtime.

        A connection flagged as poisoned is not closed, only logged;
        ``duckdb.Error`` raised while closing is suppressed.
        """
        if self._runtime is not None:
            if self._runtime.connection_poisoned:
                _LOG.error(
                    "skipping close for poisoned DuckDB connection in episode %s",
                    self._runtime.episode_id,
                )
            else:
                with contextlib.suppress(duckdb.Error):
                    self._runtime.conn.close()
            self._runtime = None

    def _maybe_fire_drift(self) -> None:
        """Apply drift when the step index crosses the schedule/cooldown threshold.

        Drift never fires twice, never fires for scenarios without a
        scheduled step, and never fires before the first ``run_query``.
        """
        rt = self._require_runtime()
        if rt.drift_fired:
            return
        if rt.drift_scheduled_step is None:
            return
        if rt.first_run_query_step is None:
            return
        cfg = rt.instance.drift_config
        assert cfg is not None
        minimum = max(rt.drift_scheduled_step, rt.first_run_query_step + cfg.cooldown_steps)
        if rt.step_count < minimum:
            return
        self._fire_drift()

    def _fire_drift(self) -> None:
        """Apply drift, author a changelog, and resolve the post-drift GT hash.

        Failure to recompute the post-drift GT hash is an authoring bug
        (the scenario's ``gt_sql_postdrift`` must execute against the
        just-mutated DB) and we re-raise loudly so it cannot silently
        make every post-drift submission score ``r_correct=0``.

        Raises:
            RuntimeError: If the authored ``gt_sql_postdrift`` fails.
        """
        rt = self._require_runtime()
        cfg = rt.instance.drift_config
        assert cfg is not None
        apply_drift(rt.conn, cfg.kind, cfg.payload)
        rt.drift_fired_step = rt.step_count
        rt.phase = EpisodePhase.DRIFT_RECOVERY
        rt.changelog_entries.append(author_changelog(cfg))

        # Probe whether the unmodified baseline still runs after drift.
        try:
            rt.conn.execute(rt.instance.baseline_sql).fetchall()
            rt.baseline_postdrift_raises = False
        except duckdb.Error:
            rt.baseline_postdrift_raises = True

        if rt.instance.gt_sql_postdrift is not None:
            try:
                rows = rt.conn.execute(rt.instance.gt_sql_postdrift).fetchall()
            except duckdb.Error as exc:
                raise RuntimeError(
                    f"scenario {rt.scenario_id!r}: authored gt_sql_postdrift failed "
                    f"after drift: {exc}"
                ) from exc
            rt.gt_result_hash_postdrift = canonical_row_hash(rows)

    # ------------------------------------------------------------------
    # Tool dispatch
    # ------------------------------------------------------------------
    def _dispatch(self, action: SqlDriftAction, *, timeout_s: float) -> ToolResult:
        """Route ``action.payload`` to its ``_handle_<tool>`` method.

        ``duckdb.Error`` and ``TimeoutError`` escaping a handler are
        converted to ``ToolError`` results (``DB_ERROR`` and
        ``QUERY_TIMEOUT``) with the message truncated to 2000 characters.
        """
        payload = action.payload
        try:
            if isinstance(payload, ListTablesPayload):
                return self._handle_list_tables()
            if isinstance(payload, DescribeTablePayload):
                return self._handle_describe_table(payload)
            if isinstance(payload, SampleRowsPayload):
                return self._handle_sample_rows(payload)
            if isinstance(payload, RunQueryPayload):
                return self._handle_run_query(payload, timeout_s=timeout_s)
            if isinstance(payload, ExplainQueryPayload):
                return self._handle_explain_query(payload, timeout_s=timeout_s)
            if isinstance(payload, ReadChangelogPayload):
                return self._handle_read_changelog()
            if isinstance(payload, SubmitRewritePayload):
                return self._handle_submit_rewrite(payload, timeout_s=timeout_s)
            if isinstance(payload, ConsultDBAPayload):
                return self._handle_consult_dba(payload)
        except duckdb.Error as exc:
            return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])
        except TimeoutError as exc:
            return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])
        # Unreachable — the discriminated-union validator rejects unknown payloads.
        return ToolError(
            code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
            message=f"unknown payload type: {type(payload).__name__}",
        )

    def _handle_list_tables(self) -> ListTablesResult:
        """List the tables in the ``main`` schema, sorted by name."""
        rt = self._require_runtime()
        rows = rt.conn.execute(
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_schema = 'main' ORDER BY table_name"
        ).fetchall()
        self._grant_step_rebate_once(attr="listed_tables_rewarded", rebate=STEP_REBATE_LIST_TABLES)
        self._mark_diagnostic()
        return ListTablesResult(tables=[r[0] for r in rows])

    def _handle_describe_table(
        self, payload: DescribeTablePayload
    ) -> DescribeTableResult | ToolError:
        """Return column names and types for ``payload.table``.

        Returns an ``UNKNOWN_TABLE`` error when the catalog has no columns
        for that table name.
        """
        rt = self._require_runtime()
        rows = rt.conn.execute(
            "SELECT column_name, data_type FROM information_schema.columns "
            "WHERE table_name = ? ORDER BY ordinal_position",
            [payload.table],
        ).fetchall()
        if not rows:
            return ToolError(
                code=ToolErrorCode.UNKNOWN_TABLE,
                message=f"unknown table: {payload.table}",
            )
        self._grant_step_rebate_for_table(
            rewarded_tables_attr="described_tables_rewarded",
            table=payload.table,
            rebate=STEP_REBATE_DESCRIBE_TABLE,
        )
        self._mark_diagnostic()
        return DescribeTableResult(
            table=payload.table,
            columns=[{"name": r[0], "type": r[1]} for r in rows],
        )

    def _handle_sample_rows(self, payload: SampleRowsPayload) -> SampleRowsResult | ToolError:
        """Return up to ``payload.limit`` rows from ``payload.table``.

        The table name is first checked against the catalog with a bound
        parameter; unknown names yield an ``UNKNOWN_TABLE`` error.
        """
        rt = self._require_runtime()
        exists = rt.conn.execute(
            "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = ?",
            [payload.table],
        ).fetchone()
        if not exists or exists[0] == 0:
            return ToolError(
                code=ToolErrorCode.UNKNOWN_TABLE,
                message=f"unknown table: {payload.table}",
            )
        # Identifiers cannot be bound as parameters, so the name is
        # interpolated as a quoted identifier. Embedded double quotes are
        # doubled so the name cannot terminate the identifier early, and the
        # limit is coerced to ``int`` before interpolation.
        quoted_table = '"' + payload.table.replace('"', '""') + '"'
        cur = rt.conn.execute(f"SELECT * FROM {quoted_table} LIMIT {int(payload.limit)}")
        columns = [d[0] for d in cur.description] if cur.description else []
        rows = [list(r) for r in cur.fetchall()]
        self._grant_step_rebate_for_table(
            rewarded_tables_attr="sampled_tables_rewarded",
            table=payload.table,
            rebate=STEP_REBATE_SAMPLE_ROWS,
        )
        self._mark_diagnostic()
        return SampleRowsResult(table=payload.table, columns=columns, rows=rows)

    def _handle_run_query(
        self, payload: RunQueryPayload, *, timeout_s: float
    ) -> RunQueryResult | ToolError:
        """Validate and execute an agent query, returning rows or a typed error.

        Error codes: ``INVALID_TOOL_ARGUMENT`` (validation failed),
        ``QUERY_TIMEOUT``, ``DB_ERROR`` (also updates the repeat-failure
        counters), ``RESULT_TOO_LARGE`` (row cap exceeded).
        """
        rt = self._require_runtime()
        sql = payload.sql
        try:
            _validate_read_only_sql(sql)
        except ValueError as exc:
            return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])

        # Drift timing: after a valid
        # ``run_query`` attempt, the pre-drift probe invariant is
        # satisfied regardless of whether the execution ultimately
        # returned rows, raised, or was capped for size. Assigning
        # *before* execution means truncation, DB errors, and timeouts
        # can no longer suppress drift firing in later steps.
        if rt.first_run_query_step is None:
            rt.first_run_query_step = rt.step_count

        try:
            result = execute_once_with_columns(
                rt.conn, sql, timeout_s=timeout_s, max_rows=MAX_RESULT_ROWS
            )
        except TimeoutError as exc:
            return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])
        except duckdb.Error as exc:
            # Canonicalize *before* hashing so whitespace-/case-only
            # variants of the same broken query count as the same repeat
            # offence. canonicalize_sql falls back to a whitespace fold
            # for SQL that sqlglot can't parse — still normalises the
            # vast majority of "retried the same typo" cases.
            failure_hash = canonical_row_hash([(canonicalize_sql(sql),)])
            count = rt.failed_query_counts.get(failure_hash, 0) + 1
            rt.failed_query_counts[failure_hash] = count
            rt.failed_query_hashes.add(failure_hash)
            rt.last_step_repeat_failing_query_count = count
            rt.last_step_was_repeat_failing_query = count > 1
            return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])

        if result.truncated:
            return ToolError(
                code=ToolErrorCode.RESULT_TOO_LARGE,
                message=(
                    f"result exceeded {MAX_RESULT_ROWS}-row cap — narrow the "
                    "projection, add a LIMIT, or aggregate"
                ),
            )
        self._grant_step_rebate_once(attr="run_query_rewarded", rebate=STEP_REBATE_RUN_QUERY)
        self._mark_diagnostic()
        return RunQueryResult(
            columns=result.columns,
            rows=[list(r) for r in result.rows],
            runtime_ms=result.elapsed_ms,
            row_count=len(result.rows),
        )

    def _handle_explain_query(
        self, payload: ExplainQueryPayload, *, timeout_s: float
    ) -> ExplainQueryResult | ToolError:
        """Return the ``EXPLAIN`` plan for a validated query (first 10,000 characters).

        Database and timeout errors are not caught here; ``_dispatch``
        converts them to ``ToolError`` results.
        """
        rt = self._require_runtime()
        try:
            _validate_read_only_sql(payload.sql)
        except ValueError as exc:
            return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])
        # EXPLAIN is plan-only (no data materialisation) but we still
        # route it through the watchdog so a pathological query cannot
        # burn the step budget past the caller's wall-clock deadline.
        explain_rows, _ = execute_once_timed(rt.conn, f"EXPLAIN {payload.sql}", timeout_s=timeout_s)
        # The plan text is taken from the last column of each returned row.
        plan = "\n".join(str(r[-1]) if r else "" for r in explain_rows)
        self._grant_step_rebate_once(
            attr="explain_query_rewarded", rebate=STEP_REBATE_EXPLAIN_QUERY
        )
        self._mark_diagnostic()
        return ExplainQueryResult(plan=plan[:10_000])

    def _handle_read_changelog(self) -> ReadChangelogResult:
        """Return a copy of the changelog entries.

        When at least one entry exists the drift is marked acknowledged
        and the one-time changelog rebate is granted. The call counts as
        a diagnostic action either way.
        """
        rt = self._require_runtime()
        if rt.changelog_entries:
            rt.drift_acknowledged = True
            self._grant_step_rebate_once(
                attr="changelog_rewarded_after_drift",
                rebate=STEP_REBATE_READ_CHANGELOG,
            )
        self._mark_diagnostic()
        return ReadChangelogResult(entries=list(rt.changelog_entries))

    def _handle_submit_rewrite(
        self, payload: SubmitRewritePayload, *, timeout_s: float
    ) -> SubmitRewriteResult | ToolError:
        """Execute the final rewrite, compare it to ground truth, and end the episode.

        The submission is rejected (without ending the episode) when no
        diagnostic action has been taken, when validation fails, or when
        execution times out or raises a database error.
        """
        rt = self._require_runtime()
        if not rt.diagnostic_actions_taken:
            return ToolError(
                code=ToolErrorCode.SUBMIT_BEFORE_DIAGNOSE,
                message=(
                    "submit_rewrite rejected: the agent must take at least one "
                    "diagnostic action (list_tables, describe_table, sample_rows, "
                    "run_query, explain_query, or read_changelog) before submitting."
                ),
            )

        sql = payload.sql
        try:
            _validate_read_only_sql(sql)
        except ValueError as exc:
            return ToolError(code=ToolErrorCode.INVALID_TOOL_ARGUMENT, message=str(exc)[:2000])

        try:
            agent_hash, elapsed_ms = execute_hash_timed(rt.conn, sql, timeout_s=timeout_s)
        except TimeoutError as exc:
            return ToolError(code=ToolErrorCode.QUERY_TIMEOUT, message=str(exc)[:2000])
        except duckdb.Error as exc:
            return ToolError(code=ToolErrorCode.DB_ERROR, message=str(exc)[:2000])

        # Compare against the post-drift hash once drift has fired and a
        # post-drift hash exists; otherwise use the pre-drift hash.
        gt_hash = (
            rt.gt_result_hash_postdrift
            if rt.drift_fired and rt.gt_result_hash_postdrift is not None
            else rt.gt_result_hash_predrift
        )
        matches = agent_hash == gt_hash

        rt.submitted = True
        rt.submitted_sql = sql
        rt.submitted_sql_canonical = canonicalize_sql(sql)
        rt.submitted_result_hash = agent_hash
        rt.submitted_runtime_ms = elapsed_ms
        rt.phase = EpisodePhase.FINALIZE
        return SubmitRewriteResult(
            accepted=True,
            runtime_ms=elapsed_ms,
            matches_ground_truth=matches,
        )

    def _handle_consult_dba(self, payload: ConsultDBAPayload) -> ConsultDBAResult | ToolError:
        """Return the next DBA hint tier for the current scenario.

        Each successful call increments ``consultations_used``; the tier
        is that count capped at 3. Returns ``INVALID_TOOL_ARGUMENT`` when
        the oracle is disabled or the scenario has no hints.
        """
        rt = self._require_runtime()
        if not rt.dba_oracle_enabled:
            return ToolError(
                code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
                message="consult_dba disabled — set enable_dba_oracle=True at reset()",
            )
        if not dba_oracle.has_hints(rt.scenario_id):
            return ToolError(
                code=ToolErrorCode.INVALID_TOOL_ARGUMENT,
                message=f"no DBA hints registered for scenario={rt.scenario_id!r}",
            )
        rt.consultations_used += 1
        tier = min(rt.consultations_used, 3)
        hint = dba_oracle.get_hint(rt.scenario_id, tier)
        del payload  # question is free-text context only; hints are scenario-keyed.
        return ConsultDBAResult(tier=tier, hint=hint)

    def _mark_diagnostic(self) -> None:
        """Record a successful diagnostic tool call and advance the phase machine."""
        rt = self._require_runtime()
        rt.diagnostic_actions_taken += 1
        if rt.phase == EpisodePhase.DIAGNOSE:
            rt.phase = EpisodePhase.REWRITE
# Public API of this module: the environment class and the three tunables
# mirrored from ``settings``.
__all__ = [
    "DEFAULT_STEP_BUDGET",
    "MAX_RESULT_ROWS",
    "QUERY_TIMEOUT_S",
    "SqlDriftEnvironment",
]