Spaces:
Sleeping
Sleeping
File size: 7,027 Bytes
5850885 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 | """Deterministic random agent used in the CPU smoke-test loop.
The agent mirrors the tool surface exposed by
:class:`models.SqlDriftAction` but makes no LLM call; all choices
come from a seeded :class:`random.Random`. It is intentionally noisy
so ``tests/integration/test_env_random_smoke.py`` sees reward variance
(> 0.1 std across 5 rollouts × 10 scenarios) and exercises every tool.
Design notes
------------
The agent is stateful per episode: it remembers every table name it
has observed via :class:`ListTablesResult` / :class:`DescribeTableResult`
so later ``run_query`` / ``submit_rewrite`` / ``sample_rows`` actions
can target real identifiers rather than fabricated ones (which would
only ever yield ``ToolError`` with code ``unknown_table``).
The agent is not meant to score well: it is a ground-truth
"does the env crash under arbitrary-but-syntactically-valid input?"
harness.
"""
from __future__ import annotations
import random
from dataclasses import dataclass, field
from models import (
ConsultDBAPayload,
DescribeTablePayload,
DescribeTableResult,
ExplainQueryPayload,
ListTablesPayload,
ListTablesResult,
ReadChangelogPayload,
RunQueryPayload,
SampleRowsPayload,
SqlDriftAction,
SqlDriftObservation,
SubmitRewritePayload,
ToolName,
)
# Pool that RandomAgent.act samples from uniformly once a table name is known.
# SUBMIT_REWRITE is intentionally not listed here: it ends the episode, so it
# is drawn separately (at most once per episode, via submit_probability) —
# otherwise an early draw would cut exploration short. Keep the order stable:
# random.Random.choice indexes into this tuple, so reordering changes every
# seeded rollout.
_EXPLORATORY_TOOLS: tuple[ToolName, ...] = (
    # schema discovery
    ToolName.LIST_TABLES, ToolName.DESCRIBE_TABLE,
    # data / query probing
    ToolName.SAMPLE_ROWS, ToolName.RUN_QUERY, ToolName.EXPLAIN_QUERY,
    # out-of-band context
    ToolName.READ_CHANGELOG, ToolName.CONSULT_DBA,
)
@dataclass
class RandomAgent:
    """Seeded agent that emits valid ``SqlDriftAction`` envelopes.

    Every random choice goes through a private :class:`random.Random`
    seeded from :attr:`seed`. Two agents built with the same ``seed`` and
    fed the same observations therefore emit the same action sequence.

    Attributes:
        seed: Seed for the private RNG. :meth:`reset` rewinds the RNG to
            this value (after optionally replacing it).
        submit_probability: Per-step probability of drawing
            ``SUBMIT_REWRITE`` once at least one table name is known.
    """

    seed: int = 0
    # Probability of drawing ``SUBMIT_REWRITE`` once at least one table
    # name is known. Small by design — we want diverse rollouts.
    submit_probability: float = 0.08

    # Per-episode state, (re)initialised by _start_episode().
    # ``_rng`` is kept out of the generated repr: random.Random has no
    # custom repr, so it would print a memory address that differs between
    # runs and make otherwise-deterministic logs non-reproducible.
    _rng: random.Random = field(init=False, repr=False)
    _known_tables: list[str] = field(init=False)
    _submitted: bool = field(init=False)

    def __post_init__(self) -> None:
        self._start_episode()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def reset(self, seed: int | None = None, scenario_id: str | None = None) -> None:
        """Reset per-episode state for a new episode.

        The RNG is *always* rewound to :attr:`seed`. Without a new ``seed``,
        each episode therefore replays the same random stream, given
        identical observations. Passing ``seed`` replaces :attr:`seed`, and
        the new value also applies to later resets.

        Args:
            seed: New seed to adopt, or ``None`` to keep the current one.
            scenario_id: Accepted for interface compatibility; not used.
        """
        if seed is not None:
            self.seed = seed
        self._start_episode()

    def _start_episode(self) -> None:
        """Rewind the RNG to :attr:`seed` and forget all per-episode memory."""
        self._rng = random.Random(self.seed)
        self._known_tables = []
        self._submitted = False

    # ------------------------------------------------------------------
    # Observation ingestion — harvest table names so later calls target
    # real identifiers.
    # ------------------------------------------------------------------
    def observe(self, obs: SqlDriftObservation) -> None:
        """Record any table names carried by ``obs.tool_result``.

        Names from ``ListTablesResult.tables`` and a non-empty
        ``DescribeTableResult.table`` are appended in first-seen order
        without duplicates. Results of any other type are ignored.
        Insertion order matters: :meth:`_pick_table` indexes into this
        list with the seeded RNG.
        """
        result = obs.tool_result
        if isinstance(result, ListTablesResult):
            for t in result.tables:
                if t not in self._known_tables:
                    self._known_tables.append(t)
        elif isinstance(result, DescribeTableResult) and (
            result.table and result.table not in self._known_tables
        ):
            self._known_tables.append(result.table)

    # ------------------------------------------------------------------
    # Policy
    # ------------------------------------------------------------------
    def act(self, obs: SqlDriftObservation) -> SqlDriftAction:
        """Return the next ``SqlDriftAction`` given ``obs``.

        ``obs`` is first passed to :meth:`observe`. Then:

        1. While no table name is known, return ``LIST_TABLES``. This
           consumes no randomness and repeats for as long as no table is
           discovered.
        2. Otherwise, if no submit has been emitted yet this episode, emit
           ``SUBMIT_REWRITE`` with probability :attr:`submit_probability`.
           This happens at most once per episode.
        3. Otherwise, sample a tool uniformly from ``_EXPLORATORY_TOOLS``.
        """
        self.observe(obs)
        if not self._known_tables:
            return SqlDriftAction(tool=ToolName.LIST_TABLES, payload=ListTablesPayload())
        # Short-circuit: the RNG is only consulted while a submit is still allowed.
        if not self._submitted and self._rng.random() < self.submit_probability:
            self._submitted = True
            return self._build_submit()
        tool = self._rng.choice(_EXPLORATORY_TOOLS)
        return self._build(tool)

    # ------------------------------------------------------------------
    # Per-tool payload builders
    # ------------------------------------------------------------------
    def _pick_table(self) -> str:
        """Return a random known table name (caller ensures the list is non-empty)."""
        return self._rng.choice(self._known_tables)

    def _build(self, tool: ToolName) -> SqlDriftAction:
        """Build an action envelope for one exploratory ``tool``.

        Raises:
            RuntimeError: if ``tool`` has no payload builder here.
        """
        if tool is ToolName.LIST_TABLES:
            return SqlDriftAction(tool=tool, payload=ListTablesPayload())
        if tool is ToolName.DESCRIBE_TABLE:
            return SqlDriftAction(tool=tool, payload=DescribeTablePayload(table=self._pick_table()))
        if tool is ToolName.SAMPLE_ROWS:
            limit = self._rng.randint(1, 5)
            return SqlDriftAction(
                tool=tool,
                payload=SampleRowsPayload(table=self._pick_table(), limit=limit),
            )
        if tool is ToolName.RUN_QUERY:
            return SqlDriftAction(
                tool=tool,
                payload=RunQueryPayload(sql=f"SELECT * FROM {self._pick_table()} LIMIT 1"),
            )
        if tool is ToolName.EXPLAIN_QUERY:
            return SqlDriftAction(
                tool=tool,
                payload=ExplainQueryPayload(sql=f"SELECT * FROM {self._pick_table()}"),
            )
        if tool is ToolName.READ_CHANGELOG:
            return SqlDriftAction(tool=tool, payload=ReadChangelogPayload())
        if tool is ToolName.CONSULT_DBA:
            return SqlDriftAction(
                tool=tool,
                payload=ConsultDBAPayload(
                    question=self._rng.choice(
                        (
                            "What is the biggest anti-pattern here?",
                            "How do I make this faster?",
                            "Did anything change?",
                        )
                    )
                ),
            )
        raise RuntimeError(f"unhandled tool {tool!r}")

    def _build_submit(self) -> SqlDriftAction:
        """Build a ``SUBMIT_REWRITE`` action over a randomly chosen known table."""
        # The submit is intentionally trivial: a random-table SELECT *.
        # It will rarely match ground truth, but that's the point — we
        # want the agent to terminate episodes so we observe the full
        # reward pipeline (including the baseline-verbatim gate).
        table = self._pick_table()
        return SqlDriftAction(
            tool=ToolName.SUBMIT_REWRITE,
            payload=SubmitRewritePayload(sql=f"SELECT * FROM {table}"),
        )
# Public API of this module: only the agent class is exported.
__all__ = [
    "RandomAgent",
]
|