sql-drift-env / training /tool_env.py
visheshrathi's picture
Upload folder using huggingface_hub
5850885 verified
"""TRL-compatible tool environment for SQLDrift.
TRL's :class:`trl.GRPOTrainer` drives multi-turn OpenEnv rollouts via
the ``environment_factory=`` kwarg: it instantiates one instance of the
class per generation, calls :meth:`reset` with the sampled dataset row
columns as keyword arguments, and exposes every public method with a
typed ``Args:`` docstring as a function-calling tool. We therefore
unroll the eight SQLDrift tools (``list_tables``, ``describe_table``,
``sample_rows``, ``run_query``, ``explain_query``, ``read_changelog``,
``submit_rewrite``, ``consult_dba``) as individual methods with
minimal, model-facing surface area — exactly the shape TRL's tool
schema extractor expects.
Side-channel state the trainer's reward function reads:
* ``self.reward`` — the *latest* scalar reward from the env (overwritten
on every tool call). TRL reward functions consume this after the
episode finishes, so the final value is what matters for GRPO group
comparisons.
* ``self.episode_return`` — the running sum of per-step rewards.
* ``self.terminal_reward`` — sticky copy of the final reward emitted
when the env set ``done=True``.
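* ``self.done`` — True once the episode terminated.
* ``self.submitted`` — whether the latest tool result reported
``accepted`` (lets reward functions tell a real submission apart from
budget exhaustion).
* ``self.drift_fired`` — latest ``drift_fired`` flag reported by the env.
* ``self.components`` — most recent per-rubric component scores.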
The default client factory opens a synchronous OpenEnv session against
the caller-supplied ``env_url`` (falling back to the
``SQL_DRIFT_ENV_URL`` environment variable). Tests substitute an
in-process factory via
:func:`set_client_factory` so CPU-only CI can exercise the class
without a running server.
"""
from __future__ import annotations

import math
import os
from collections.abc import Callable
from typing import Any, Protocol

from client import SqlDriftEnv
from utilities.env_loader import env_str
# Fallback SQLDrift server URL, resolved once at import time. ``env_str`` is
# given the variable name and "http://localhost:8000" — presumably the value
# used when the variable is unset (confirm against utilities.env_loader).
_DEFAULT_ENV_URL = env_str("SQL_DRIFT_ENV_URL", "http://localhost:8000")
class _SyncClientProtocol:
"""Minimal subset of :class:`openenv.core.env_client.EnvClient` we rely on."""
def reset(self, *, seed: int | None = None, **kwargs: Any) -> Any: ...
def step(self, action: Any) -> Any: ...
def close(self) -> None: ...
ClientFactory = Callable[[], _SyncClientProtocol]
def _default_client_factory(base_url: str) -> _SyncClientProtocol:
    """Create the stock SQLDrift client for ``base_url`` in its sync form.

    The value returned is whatever ``SqlDriftEnv(base_url=...).sync()``
    produces; it is bound to an annotated local first so the declared
    return type is explicit at the call site.
    """
    sync_client: _SyncClientProtocol = SqlDriftEnv(base_url=base_url).sync()
    return sync_client
# Module-wide factory override written by ``set_client_factory``; while it is
# ``None``, ``SqlDriftToolEnv.__init__`` uses ``_default_client_factory``.
_client_factory_override: ClientFactory | None = None
def set_client_factory(factory: ClientFactory | None) -> None:
    """Install (or clear) the module-wide client factory override.

    Intended for tests and in-process rollouts: every
    ``SqlDriftToolEnv`` constructed afterwards obtains its client by
    calling ``factory()`` instead of opening the default HTTP/WS-backed
    session. Passing ``None`` restores the default factory.
    """
    # Rebinding module state, hence the ``global`` declaration.
    global _client_factory_override
    _client_factory_override = factory
# ---------------------------------------------------------------------------
# Internal helpers — kept private (underscore-prefixed) so TRL does NOT
# expose them as tools.
# ---------------------------------------------------------------------------
def _format_columns(columns: list[str]) -> str:
    """Join names with ``", "``; an empty (or falsy) input renders as ``(none)``."""
    if not columns:
        return "(none)"
    return ", ".join(columns)
def _format_rows(columns: list[str], rows: list[list[Any]], *, limit: int = 20) -> str:
    """Render a result set as a header line plus pipe-separated rows.

    The first line is ``columns: <names>``. Each following line is one
    row with its cells passed through ``str`` and joined by ``" | "``.
    At most ``limit`` rows are shown; when rows are left over, a trailing
    ``... (N more)`` line reports how many, so the model's context does
    not balloon on large results. An empty result renders as ``(0 rows)``.
    """
    header = f"columns: {_format_columns(columns)}"
    if not rows:
        return f"{header}\n(0 rows)"

    shown = rows[:limit]
    rendered_rows = "\n".join(" | ".join(map(str, row)) for row in shown)

    hidden = len(rows) - len(shown)
    tail = f"\n... ({hidden} more)" if hidden > 0 else ""
    return f"{header}\n{rendered_rows}{tail}"
# ---------------------------------------------------------------------------
# Tool-env class — TRL consumes the public methods as tool schemas.
# Method order here matches :class:`models.ToolName` so the model's tool
# menu stays aligned with the server's dispatch table.
# ---------------------------------------------------------------------------
class SqlDriftToolEnv:
"""TRL-facing wrapper around the SQLDrift OpenEnv client.
One instance per GRPO generation. TRL discovers each public method
(list_tables, describe_table, …) as a callable tool; the class
itself stores the per-episode reward so the reward function can
read it back after the rollout finishes.
"""
def __init__(self, *, env_url: str | None = None) -> None:
base_url = env_url or os.environ.get("SQL_DRIFT_ENV_URL", _DEFAULT_ENV_URL)
if _client_factory_override is None:
self._client = _default_client_factory(base_url)
else:
self._client = _client_factory_override()
self.reward: float = 0.0
self.terminal_reward: float = 0.0
self.episode_return: float = 0.0
self.done: bool = False
self.submitted: bool = False
self.drift_fired: bool = False
self.components: dict[str, float] = {}
self._last_observation_text: str = ""
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def reset(self, **kwargs: Any) -> str | None:
"""Begin a new SQLDrift episode.
TRL invokes ``reset`` with every dataset column as a keyword
argument. We accept ``seed``, ``scenario_id``,
``enable_dba_oracle``, ``budget_steps``, and ``difficulty`` (all
optional — the env supplies deterministic defaults when omitted)
and ignore the rest.
"""
seed = kwargs.get("seed")
scenario_id = kwargs.get("scenario_id")
enable_dba_oracle = kwargs.get("enable_dba_oracle")
budget_steps = kwargs.get("budget_steps")
difficulty = kwargs.get("difficulty")
reset_kwargs: dict[str, Any] = {}
if scenario_id is not None:
reset_kwargs["scenario_id"] = scenario_id
if enable_dba_oracle is not None:
reset_kwargs["enable_dba_oracle"] = enable_dba_oracle
if budget_steps is not None:
reset_kwargs["budget_steps"] = budget_steps
if difficulty is not None:
reset_kwargs["difficulty"] = difficulty
result = self._client.reset(
seed=int(seed) if seed is not None else None,
**reset_kwargs,
)
observation = _extract_observation(result)
self.reward = 0.0
self.terminal_reward = 0.0
self.episode_return = 0.0
self.done = False
self.submitted = False
self.drift_fired = False
self.components = {}
return _format_observation_prelude(observation)
def close(self) -> None:
"""Release the underlying client session (idempotent)."""
with _suppress_client_errors():
self._client.close()
# ------------------------------------------------------------------
# Tools — exposed to the model via TRL's tool-schema extractor.
# Each method has a typed signature + ``Args:`` docstring so the
# schema generator can produce a clean JSON tool spec.
# ------------------------------------------------------------------
def list_tables(self) -> str:
"""Enumerate every table visible in the current SQLDrift database.
Returns:
A comma-separated list of table names.
"""
return self._dispatch(SqlDriftEnv.action_list_tables())
def describe_table(self, table: str) -> str:
"""Return column names and types for a single table.
Args:
table: Name of the table to describe.
Returns:
A formatted list of ``column: type`` pairs, one per line.
"""
return self._dispatch(SqlDriftEnv.action_describe_table(table=table))
def sample_rows(self, table: str, limit: int = 5) -> str:
"""Peek at up to five rows from a table for fast schema intuition.
Args:
table: Name of the table to sample from.
limit: Number of rows to return (1-5, inclusive).
Returns:
A rendered table of the sampled rows.
"""
return self._dispatch(SqlDriftEnv.action_sample_rows(table=table, limit=int(limit)))
def run_query(self, sql: str) -> str:
"""Execute a read-only SELECT and return its result set.
Args:
sql: The SELECT statement to execute. DDL/DML is rejected.
Returns:
The rendered result table (truncated to a handful of rows).
"""
return self._dispatch(SqlDriftEnv.action_run_query(sql=sql))
def explain_query(self, sql: str) -> str:
"""Return the DuckDB query plan for a SELECT without executing it.
Args:
sql: The SELECT statement to plan.
Returns:
The textual query plan.
"""
return self._dispatch(SqlDriftEnv.action_explain_query(sql=sql))
def read_changelog(self) -> str:
"""Read every drift deploy note that has been published so far.
Returns:
The concatenated changelog entries, most recent last.
"""
return self._dispatch(SqlDriftEnv.action_read_changelog())
def submit_rewrite(self, sql: str) -> str:
"""Commit a final SELECT and end the episode.
Args:
sql: The rewritten SELECT to submit as the final answer.
Returns:
A short verdict summarising whether the rewrite matched the
ground truth and its measured runtime.
"""
return self._dispatch(SqlDriftEnv.action_submit_rewrite(sql=sql))
def consult_dba(self, question: str) -> str:
"""Ask the on-call DBA for a hint (enabled only when the DBA
oracle flag is on; each call incurs an escalating penalty).
Args:
question: Free-text question for context — the actual hint
is scenario-keyed, not question-keyed.
Returns:
The DBA's hint plus its tier.
"""
return self._dispatch(SqlDriftEnv.action_consult_dba(question=question))
# ------------------------------------------------------------------
# Internal dispatch — wraps every env.step() call, records reward
# state, and turns tool errors into plain-text returns so the
# model sees a recoverable message instead of a crash.
# ------------------------------------------------------------------
def _dispatch(self, action: Any) -> str:
if self.done:
# Episode already terminated (budget exhausted or a prior
# submit). Raising a ValueError is the TRL-sanctioned way
# to tell the model the episode is over (see the TRL
# OpenEnv integration guide on "Error handling").
raise ValueError("Episode is already finished; the environment is closed.")
result = self._client.step(action)
obs = _extract_observation(result)
self._ingest_observation(obs)
return _format_tool_result(obs)
def _ingest_observation(self, obs: Any) -> None:
reward_val = _safe_float(getattr(obs, "reward", 0.0))
self.reward = reward_val
self.episode_return += reward_val
components = getattr(obs, "reward_components", None)
if isinstance(components, dict):
self.components = {k: float(v) for k, v in components.items()}
self.drift_fired = bool(getattr(obs, "drift_fired", self.drift_fired))
done = bool(getattr(obs, "done", False))
if done:
self.done = True
self.terminal_reward = reward_val
# ``submit_rewrite`` is the only tool that ends an episode
# by setting ``submitted=True``; budget exhaustion ends it
# without submission. Surface both so reward functions can
# decide whether the rollout produced a real submission.
tool_result = getattr(obs, "tool_result", None)
submitted = getattr(tool_result, "accepted", None)
if submitted is not None:
self.submitted = bool(submitted)
# ---------------------------------------------------------------------------
# Observation rendering — kept module-private so TRL sees them as
# helpers rather than tools.
# ---------------------------------------------------------------------------
def _extract_observation(result: Any) -> Any:
    """Unwrap the observation carried by a client reset/step return value.

    When ``result`` has a non-``None`` ``observation`` attribute (the
    :class:`StepResult`-style wrapper the sync client emits), that
    attribute is returned. Otherwise ``result`` is handed back untouched,
    which covers shims that already return the bare observation.
    """
    inner = getattr(result, "observation", None)
    if inner is not None:
        return inner
    return result
def _safe_float(value: Any) -> float:
    """Coerce any reward-shaped value to a finite float (0.0 on failure).

    ``None``, values ``float()`` cannot convert (including integers too
    large for a float), NaN, and ±infinity all map to ``0.0`` so that a
    single bad reward cannot poison running sums built from the result.
    """
    if value is None:
        return 0.0
    try:
        out = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: e.g. an int too large to represent as a float.
        return 0.0
    # isfinite() rejects both NaN and ±inf, matching the "finite" contract.
    return out if math.isfinite(out) else 0.0
def _format_observation_prelude(obs: Any) -> str:
    """Build the opening message shown to the model after a reset.

    Sections appear in a fixed order — schema synopsis, baseline query,
    remaining step budget, learned hints — separated by blank lines.
    A section is skipped when its attribute is missing or falsy; the
    budget is the exception and is skipped only when it is ``None``, so
    a budget of ``0`` is still shown. Returns ``""`` if nothing applies.
    """
    synopsis = getattr(obs, "schema_synopsis", "")
    baseline = getattr(obs, "baseline_sql", "")
    budget = getattr(obs, "budget_steps_remaining", None)
    hints = getattr(obs, "learned_hints", "")

    sections = (
        f"Schema synopsis:\n{synopsis}" if synopsis else None,
        f"Baseline query:\n{baseline}" if baseline else None,
        f"Remaining step budget: {budget}" if budget is not None else None,
        f"Learned hints:\n{hints}" if hints else None,
    )
    return "\n\n".join(section for section in sections if section is not None)
def _format_tool_result(obs: Any) -> str:
    """Turn ``obs.tool_result`` into the plain text returned to the model.

    The rendering is chosen by the concrete result type; anything
    unrecognised falls back to ``str(tool_result)``. If the observation
    carries non-empty ``learned_hints`` they are appended under a
    ``Learned hints:`` heading in every case, including the
    "no tool result" one.
    """

    def _with_hints(rendered: str) -> str:
        # Shared suffix logic for every branch below.
        learned = getattr(obs, "learned_hints", "")
        if learned:
            return f"{rendered}\n\nLearned hints:\n{learned}"
        return rendered

    payload = getattr(obs, "tool_result", None)
    if payload is None:
        return _with_hints("(no tool result)")

    # Dispatching on the concrete class keeps this free of any runtime
    # dependency on the Pydantic discriminator tags. Every branch yields
    # a string the model can parse.
    from models import (  # local import avoids a circular hazard at module import
        ConsultDBAResult,
        DescribeTableResult,
        ExplainQueryResult,
        ListTablesResult,
        ReadChangelogResult,
        RunQueryResult,
        SampleRowsResult,
        SubmitRewriteResult,
        ToolError,
    )

    if isinstance(payload, ToolError):
        return _with_hints(f"error[{payload.code.value}]: {payload.message}")

    if isinstance(payload, ListTablesResult):
        return _with_hints(_format_columns(payload.tables))

    if isinstance(payload, DescribeTableResult):
        column_lines = "\n".join(f"{c['name']}: {c['type']}" for c in payload.columns)
        return _with_hints(f"table: {payload.table}\n{column_lines}")

    if isinstance(payload, (SampleRowsResult, RunQueryResult)):
        return _with_hints(_format_rows(payload.columns, payload.rows))

    if isinstance(payload, ExplainQueryResult):
        return _with_hints(payload.plan)

    if isinstance(payload, ReadChangelogResult):
        if payload.entries:
            return _with_hints("\n---\n".join(payload.entries))
        return _with_hints("(no changelog yet)")

    if isinstance(payload, SubmitRewriteResult):
        if payload.matches_ground_truth:
            verdict = "matched ground truth"
        else:
            verdict = "did NOT match"
        return _with_hints(f"submitted (runtime={payload.runtime_ms:.2f}ms) — {verdict}")

    if isinstance(payload, ConsultDBAResult):
        return _with_hints(f"[tier {payload.tier}] {payload.hint}")

    return _with_hints(str(payload))
class _suppress_client_errors:
    """Tiny context manager — swallows teardown-time errors from the client.

    Only ordinary :class:`Exception` subclasses are suppressed.
    ``KeyboardInterrupt``, ``SystemExit`` and other bare
    ``BaseException`` types still propagate, so a best-effort ``close()``
    can never eat a shutdown request.
    """

    def __enter__(self) -> None:
        return None

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
        # A truthy return tells the interpreter to suppress the exception.
        return exc_type is not None and issubclass(exc_type, Exception)
# Names exported by a star-import of this module; the underscore-prefixed
# helpers above are intentionally left out.
__all__ = [
"ClientFactory",
"SqlDriftToolEnv",
"set_client_factory",
]