# sql-drift-env / training/random_agent.py
# Uploaded by visheshrathi using huggingface_hub (commit 5850885, verified).
"""Deterministic random agent used in the CPU smoke-test loop.

The agent mirrors the tool surface exposed by
:class:`models.SqlDriftAction` but makes no LLM call; all choices
come from a seeded :class:`random.Random`. It is intentionally noisy
so ``tests/integration/test_env_random_smoke.py`` sees reward variance
(> 0.1 std across 5 rollouts × 10 scenarios) and exercises every tool.

Design notes
------------
The agent is stateful per episode: it remembers every table name it
has observed via :class:`ListTablesResult` / :class:`DescribeTableResult`
so later ``describe_table`` / ``sample_rows`` / ``run_query`` /
``explain_query`` / ``submit_rewrite`` actions can target real
identifiers rather than fabricated ones (which would only ever yield
``ToolError`` with code ``unknown_table``).

The agent is not meant to score well: it is a ground-truth
"does the env crash under arbitrary-but-syntactically-valid input?"
harness.
"""
from __future__ import annotations
import random
from dataclasses import dataclass, field
from models import (
ConsultDBAPayload,
DescribeTablePayload,
DescribeTableResult,
ExplainQueryPayload,
ListTablesPayload,
ListTablesResult,
ReadChangelogPayload,
RunQueryPayload,
SampleRowsPayload,
SqlDriftAction,
SqlDriftObservation,
SubmitRewritePayload,
ToolName,
)
# Pool for the uniform exploration draw in ``RandomAgent.act``.
# ``SUBMIT_REWRITE`` is intentionally left out: if it were drawn uniformly
# alongside these, it could fire on the very first step and end the
# episode before any real exploration happens. ``act`` instead emits it
# through a separate ``submit_probability`` draw, at most once per episode.
_EXPLORATORY_TOOLS: tuple[ToolName, ...] = tuple(
    [
        ToolName.LIST_TABLES,
        ToolName.DESCRIBE_TABLE,
        ToolName.SAMPLE_ROWS,
        ToolName.RUN_QUERY,
        ToolName.EXPLAIN_QUERY,
        ToolName.READ_CHANGELOG,
        ToolName.CONSULT_DBA,
    ]
)
@dataclass
class RandomAgent:
    """Seeded agent that emits valid ``SqlDriftAction`` envelopes.

    Every random choice goes through a private :class:`random.Random`
    seeded from :attr:`seed`. Given the same seed and the same sequence
    of observations, the agent therefore emits the same sequence of
    actions.

    Attributes:
        seed: Seed for the internal RNG; re-applied on every :meth:`reset`.
        submit_probability: Per-step probability of emitting
            ``SUBMIT_REWRITE`` once at least one table name is known.
    """

    seed: int = 0
    # Probability of drawing ``SUBMIT_REWRITE`` once at least one table
    # name is known. Small by design — we want diverse rollouts.
    submit_probability: float = 0.08

    # --- Per-episode state, (re)initialised by reset() ---------------
    # Excluded from repr: a Random instance only prints an object address.
    _rng: random.Random = field(init=False, repr=False)
    # Table names harvested from tool results, in first-seen order. The
    # order matters for determinism: _pick_table draws from this list.
    _known_tables: list[str] = field(init=False)
    # True once SUBMIT_REWRITE has been emitted in the current episode.
    _submitted: bool = field(init=False)

    def __post_init__(self) -> None:
        # Initial state is exactly the state after a reset with no new seed.
        self.reset()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def reset(self, seed: int | None = None, scenario_id: str | None = None) -> None:
        """Clear per-episode state and reseed the RNG.

        The RNG is reseeded on *every* call. When ``seed`` is given, it is
        used and also replaces :attr:`seed` for later resets. Otherwise
        the stored :attr:`seed` is reused, so a reset without a new seed
        replays the same random draws as the previous episode.

        Args:
            seed: Optional new seed; ``None`` keeps the current one.
            scenario_id: Currently ignored by this agent.
        """
        if seed is not None:
            self.seed = seed
        self._rng = random.Random(self.seed)
        self._known_tables = []
        self._submitted = False

    # ------------------------------------------------------------------
    # Observation ingestion — harvest table names so later calls target
    # real identifiers.
    # ------------------------------------------------------------------
    def observe(self, obs: SqlDriftObservation) -> None:
        """Record any table names exposed by ``obs.tool_result``.

        Names from a ``ListTablesResult`` (its ``tables``) or a
        ``DescribeTableResult`` (its ``table``, when truthy) are appended
        to the known-table list if not already present. Any other result
        type is ignored.
        """
        result = obs.tool_result
        if isinstance(result, ListTablesResult):
            for t in result.tables:
                if t not in self._known_tables:
                    self._known_tables.append(t)
        elif isinstance(result, DescribeTableResult) and (
            result.table and result.table not in self._known_tables
        ):
            self._known_tables.append(result.table)

    # ------------------------------------------------------------------
    # Policy
    # ------------------------------------------------------------------
    def act(self, obs: SqlDriftObservation) -> SqlDriftAction:
        """Ingest ``obs`` and return the next ``SqlDriftAction``.

        Exploration heuristic:

        1. While no table name is known, always emit ``LIST_TABLES`` so
           downstream tools have something to aim at. This path makes
           no RNG draw.
        2. Otherwise, if ``SUBMIT_REWRITE`` has not yet been emitted this
           episode, emit it with probability :attr:`submit_probability`.
           The agent itself does not stop afterwards. Presumably the
           environment ends the episode on submit; confirm against the env.
        3. Otherwise, sample a tool uniformly from :data:`_EXPLORATORY_TOOLS`.
        """
        self.observe(obs)
        if not self._known_tables:
            return SqlDriftAction(tool=ToolName.LIST_TABLES, payload=ListTablesPayload())
        # Short-circuit: once submitted, no RNG draw is spent on this check.
        if not self._submitted and self._rng.random() < self.submit_probability:
            self._submitted = True
            return self._build_submit()
        tool = self._rng.choice(_EXPLORATORY_TOOLS)
        return self._build(tool)

    # ------------------------------------------------------------------
    # Per-tool payload builders
    # ------------------------------------------------------------------
    def _pick_table(self) -> str:
        """Return a known table name chosen by the seeded RNG.

        Callers must ensure at least one table is known. ``Random.choice``
        raises ``IndexError`` on an empty list.
        """
        return self._rng.choice(self._known_tables)

    def _build(self, tool: ToolName) -> SqlDriftAction:
        """Build a ``SqlDriftAction`` for ``tool`` with a randomised payload.

        Raises:
            RuntimeError: if ``tool`` has no builder here (e.g.
                ``SUBMIT_REWRITE``, which goes through :meth:`_build_submit`).
        """
        if tool is ToolName.LIST_TABLES:
            return SqlDriftAction(tool=tool, payload=ListTablesPayload())
        if tool is ToolName.DESCRIBE_TABLE:
            return SqlDriftAction(tool=tool, payload=DescribeTablePayload(table=self._pick_table()))
        if tool is ToolName.SAMPLE_ROWS:
            # The limit is drawn before the table; this order is part of the
            # deterministic RNG sequence.
            limit = self._rng.randint(1, 5)
            return SqlDriftAction(
                tool=tool,
                payload=SampleRowsPayload(table=self._pick_table(), limit=limit),
            )
        if tool is ToolName.RUN_QUERY:
            return SqlDriftAction(
                tool=tool,
                payload=RunQueryPayload(sql=f"SELECT * FROM {self._pick_table()} LIMIT 1"),
            )
        if tool is ToolName.EXPLAIN_QUERY:
            return SqlDriftAction(
                tool=tool,
                payload=ExplainQueryPayload(sql=f"SELECT * FROM {self._pick_table()}"),
            )
        if tool is ToolName.READ_CHANGELOG:
            return SqlDriftAction(tool=tool, payload=ReadChangelogPayload())
        if tool is ToolName.CONSULT_DBA:
            return SqlDriftAction(
                tool=tool,
                payload=ConsultDBAPayload(
                    question=self._rng.choice(
                        (
                            "What is the biggest anti-pattern here?",
                            "How do I make this faster?",
                            "Did anything change?",
                        )
                    )
                ),
            )
        raise RuntimeError(f"unhandled tool {tool!r}")

    def _build_submit(self) -> SqlDriftAction:
        """Build a ``SUBMIT_REWRITE`` action selecting everything from a known table."""
        # The submit is intentionally trivial: a random-table SELECT *.
        # It will rarely match ground truth, but that's the point — we
        # want the agent to terminate episodes so we observe the full
        # reward pipeline (including the baseline-verbatim gate).
        table = self._pick_table()
        return SqlDriftAction(
            tool=ToolName.SUBMIT_REWRITE,
            payload=SubmitRewritePayload(sql=f"SELECT * FROM {table}"),
        )
# Public surface of this module: only the agent class is exported.
__all__ = [
    "RandomAgent",
]