# sql-drift-env/training/prompt.py
# (Hugging Face page residue) visheshrathi's picture
# (Hugging Face page residue) Upload folder using huggingface_hub
# (Hugging Face page residue) 5850885 verified
"""Single-source-of-truth system prompt for SQLDrift agents.
Every component that constructs an agent context (GRPO trainer,
``random_agent``, eval harness, demo notebook) goes through
:func:`render_system_prompt` so the tool surface described to the
model stays in lockstep with :class:`models.ToolName` and the tool
payload schemas.
The rendered string is tokenizer-agnostic: no chat-template
markers, no special tokens, no role wrappers. Callers that need a
chat format wrap the returned string with their tokenizer's
``apply_chat_template``.
"""
from __future__ import annotations
from models import EpisodePhase, SqlDriftObservation, ToolName
# -----------------------------------------------------------------------------
# Tool catalog — keep in sync with ``models.ToolName`` + payload schemas.
# The env enforces argument shapes server-side, but the agent needs a
# human-readable cheat sheet to plan its turn.
# -----------------------------------------------------------------------------
# Human-readable documentation for each tool, keyed by ``ToolName`` member.
# Every value is a two-key dict: ``"signature"`` (call shape shown to the
# model) and ``"purpose"`` (one-or-two sentence usage guidance). The rows
# below are (tool, signature, purpose) triples expanded into that shape.
TOOL_DOCS: dict[ToolName, dict[str, str]] = {
    tool: {"signature": signature, "purpose": purpose}
    for tool, signature, purpose in (
        # --- schema exploration -------------------------------------------
        (
            ToolName.LIST_TABLES,
            "list_tables()",
            "Enumerate tables visible to the session (cheap, always safe).",
        ),
        (
            ToolName.DESCRIBE_TABLE,
            "describe_table(table: str)",
            "Return column names + types for one table.",
        ),
        (
            ToolName.SAMPLE_ROWS,
            "sample_rows(table: str, limit: int ∈ [1, 5] = 5)",
            "Peek at up to 5 rows for fast schema intuition.",
        ),
        # --- query execution / planning -----------------------------------
        (
            ToolName.RUN_QUERY,
            "run_query(sql: str)",
            "Execute a read-only SELECT against the live database. "
            "Timing counts toward the step budget; repeat-failing queries are "
            "penalised.",
        ),
        (
            ToolName.EXPLAIN_QUERY,
            "explain_query(sql: str)",
            "Return the DuckDB plan for a SELECT (no execution).",
        ),
        # --- drift handling -----------------------------------------------
        (
            ToolName.READ_CHANGELOG,
            "read_changelog()",
            "Read all drift-related deploy notes published so far. Always "
            "consult this after drift is announced in an observation.",
        ),
        # --- terminal action ----------------------------------------------
        (
            ToolName.SUBMIT_REWRITE,
            "submit_rewrite(sql: str)",
            "Commit your final SELECT. Terminates the episode. Reward requires "
            "the result to match ground truth AND the rewrite to be ≥1.2x "
            "faster than the baseline query.",
        ),
        # --- optional escalation ------------------------------------------
        (
            ToolName.CONSULT_DBA,
            "consult_dba(question: str)",
            "Ask the on-call DBA for a hint. Each consultation escalates the "
            "hint tier and incurs a compounding penalty; use sparingly and "
            "only after diagnostics.",
        ),
    )
}
# One short, phase-specific instruction per ``EpisodePhase`` member; the
# matching entry is spliced into the system prompt for the current phase.
PHASE_NUDGES: dict[EpisodePhase, str] = {
    # Exploration only: the model is told not to submit yet.
    EpisodePhase.DIAGNOSE: (
        "You are in DIAGNOSE. Explore the schema and sample data. "
        "Do NOT submit a rewrite yet."
    ),
    # Iterate on candidates, then commit.
    EpisodePhase.REWRITE: (
        "You are in REWRITE. "
        "Draft candidate queries with run_query and, once confident, "
        "call submit_rewrite."
    ),
    # Drift has been announced: re-inspect before submitting.
    EpisodePhase.DRIFT_RECOVERY: (
        "Drift has fired. "
        "Read the changelog, re-describe affected tables, and adapt your "
        "rewrite before submitting."
    ),
    # Terminal phase.
    EpisodePhase.FINALIZE: (
        "The episode is finalizing; no further tools will help."
    ),
}
# Fixed opening paragraph of every system prompt: role, task, and the two
# standing rules (read-only tooling, changelog is authoritative). Written as
# one sentence per item and joined with single spaces.
SYSTEM_PROMPT_HEADER = " ".join(
    (
        "You are a senior SQL engineer operating an analytical database "
        "that is under live schema and business-rule drift.",
        "Your job is to repair and optimize a slow baseline SELECT under "
        "tight step and runtime budgets.",
        "Prefer read-only tools; never emit DDL or DML (INSERT/UPDATE/DELETE).",
        "When a changelog is published, treat it as authoritative.",
    )
)
def _render_tool_catalog(dba_enabled: bool = False) -> str:
    """Render the tool cheat sheet as a newline-separated bullet list.

    The first line is a fixed heading; each following line has the form
    ``- <signature>: <purpose>`` and is taken from ``TOOL_DOCS``, visiting
    tools in ``ToolName`` iteration order.

    Args:
        dba_enabled: When falsy, the ``ToolName.CONSULT_DBA`` entry is left
            out of the listing.

    Returns:
        The heading and bullet lines joined with ``"\\n"``.

    Raises:
        KeyError: If a listed ``ToolName`` member has no ``TOOL_DOCS`` entry.
    """
    heading = "Tools available (exact JSON shapes enforced by the env):"
    # Keep every tool except consult_dba, which is shown only on request.
    shown = [
        tool
        for tool in ToolName
        if tool != ToolName.CONSULT_DBA or dba_enabled
    ]
    entries = (TOOL_DOCS[tool] for tool in shown)
    bullets = (f"- {entry['signature']}: {entry['purpose']}" for entry in entries)
    return "\n".join([heading, *bullets])
def render_system_prompt(
    *,
    scenario_id: str,
    learned_hints: str = "",
    phase: EpisodePhase = EpisodePhase.DIAGNOSE,
    budget_steps_remaining: int | None = None,
    drift_fired: bool = False,
    dba_enabled: bool = False,
) -> str:
    """Assemble the system prompt for one episode.

    Sections are emitted in a fixed order -- header, tool catalog, scenario,
    step budget, phase nudge, drift reminder, learned hints -- separated by
    blank lines. Empty sections are dropped and the result is stripped of
    leading/trailing whitespace.

    Args:
        scenario_id: Identifier interpolated into the "Current scenario" line.
        learned_hints: Pre-rendered hint text appended under a
            "Learned hints" heading; skipped when empty. No truncation is
            applied here (the original notes say the env caps it at 800
            chars -- not enforced by this function).
        phase: Episode phase used to look up a line in ``PHASE_NUDGES``;
            phases without an entry contribute nothing.
        budget_steps_remaining: When not ``None``, a line stating the
            remaining step budget is included.
        drift_fired: When truthy, a reminder to call ``read_changelog``
            first is included.
        dba_enabled: Forwarded to ``_render_tool_catalog``; when truthy the
            ``consult_dba`` tool is listed in the catalog.

    Returns:
        The rendered prompt as a plain string.
    """
    catalog = _render_tool_catalog(dba_enabled=dba_enabled)
    scenario_line = f"Current scenario: {scenario_id}"

    budget_note = ""
    if budget_steps_remaining is not None:
        budget_note = (
            f"Remaining step budget: {budget_steps_remaining}. Each tool "
            "call costs one step; plan accordingly."
        )

    phase_note = PHASE_NUDGES.get(phase, "")

    drift_note = ""
    if drift_fired:
        drift_note = (
            "Drift has already fired in this episode — if you have not yet "
            "called read_changelog since, do that FIRST."
        )

    hints_block = ""
    if learned_hints:
        hints_block = "Learned hints (from past episodes):\n" + learned_hints

    sections = (
        SYSTEM_PROMPT_HEADER,
        catalog,
        scenario_line,
        budget_note,
        phase_note,
        drift_note,
        hints_block,
    )
    # filter(None, ...) discards the empty-string placeholders above.
    return "\n\n".join(filter(None, sections)).strip()
def render_prompt_from_observation(
    *,
    scenario_id: str,
    observation: SqlDriftObservation,
    dba_enabled: bool = False,
) -> str:
    """Render the system prompt using the fields carried by an observation.

    Convenience wrapper around :func:`render_system_prompt` that reads
    ``learned_hints``, ``phase``, ``budget_steps_remaining`` and
    ``drift_fired`` from ``observation``.

    Args:
        scenario_id: Current scenario id, passed through unchanged.
        observation: Observation supplying the four attributes listed above.
        dba_enabled: Passed through to :func:`render_system_prompt`; when
            True the ``consult_dba`` tool is listed in the tool catalog. It
            is taken from the caller rather than from ``observation``. The
            default of False reproduces the wrapper's earlier output, which
            never listed that tool.

    Returns:
        The rendered system prompt string.
    """
    return render_system_prompt(
        scenario_id=scenario_id,
        learned_hints=observation.learned_hints,
        phase=observation.phase,
        budget_steps_remaining=observation.budget_steps_remaining,
        drift_fired=observation.drift_fired,
        # Previously dropped, so observation-driven prompts could not
        # advertise consult_dba even when the tool was enabled.
        dba_enabled=dba_enabled,
    )
# Names exported by ``from <this module> import *``: the three prompt
# constants and the two public render functions. The private
# ``_render_tool_catalog`` helper is not listed.
__all__ = [
"PHASE_NUDGES",
"SYSTEM_PROMPT_HEADER",
"TOOL_DOCS",
"render_prompt_from_observation",
"render_system_prompt",
]