Spaces:
Sleeping
Sleeping
| """Single-source-of-truth system prompt for SQLDrift agents. | |
| Every component that constructs an agent context (GRPO trainer, | |
| ``random_agent``, eval harness, demo notebook) goes through | |
| :func:`render_system_prompt` so the tool surface described to the | |
| model stays in lockstep with :mod:`models.ToolName` and the tool | |
| payload schemas. | |
| The rendered string is tokenizer-agnostic: no chat-template | |
| markers, no special tokens, no role wrappers. Callers that need a | |
| chat format wrap the returned string with their tokenizer's | |
| ``apply_chat_template``. | |
| """ | |
| from __future__ import annotations | |
| from models import EpisodePhase, SqlDriftObservation, ToolName | |
| # ----------------------------------------------------------------------------- | |
| # Tool catalog — keep in sync with ``models.ToolName`` + payload schemas. | |
| # The env enforces argument shapes server-side, but the agent needs a | |
| # human-readable cheat sheet to plan its turn. | |
| # ----------------------------------------------------------------------------- | |
# Rows of (tool, call signature, one-line purpose). The order of the rows
# is the insertion order of ``TOOL_DOCS`` below.
_TOOL_DOC_ROWS: tuple[tuple[ToolName, str, str], ...] = (
    (
        ToolName.LIST_TABLES,
        "list_tables()",
        "Enumerate tables visible to the session (cheap, always safe).",
    ),
    (
        ToolName.DESCRIBE_TABLE,
        "describe_table(table: str)",
        "Return column names + types for one table.",
    ),
    (
        ToolName.SAMPLE_ROWS,
        "sample_rows(table: str, limit: int ∈ [1, 5] = 5)",
        "Peek at up to 5 rows for fast schema intuition.",
    ),
    (
        ToolName.RUN_QUERY,
        "run_query(sql: str)",
        (
            "Execute a read-only SELECT against the live database. Timing "
            "counts toward the step budget; repeat-failing queries are "
            "penalised."
        ),
    ),
    (
        ToolName.EXPLAIN_QUERY,
        "explain_query(sql: str)",
        "Return the DuckDB plan for a SELECT (no execution).",
    ),
    (
        ToolName.READ_CHANGELOG,
        "read_changelog()",
        (
            "Read all drift-related deploy notes published so far. "
            "Always consult this after drift is announced in an "
            "observation."
        ),
    ),
    (
        ToolName.SUBMIT_REWRITE,
        "submit_rewrite(sql: str)",
        (
            "Commit your final SELECT. Terminates the episode. "
            "Reward requires the result to match ground truth AND the "
            "rewrite to be ≥1.2x faster than the baseline query."
        ),
    ),
    (
        ToolName.CONSULT_DBA,
        "consult_dba(question: str)",
        (
            "Ask the on-call DBA for a hint. Each consultation escalates "
            "the hint tier and incurs a compounding penalty; use "
            "sparingly and only after diagnostics."
        ),
    ),
)

# Per-tool cheat sheet: each value holds exactly the keys "signature" and
# "purpose", which the catalog renderer reads when building the prompt.
TOOL_DOCS: dict[ToolName, dict[str, str]] = {
    tool: {"signature": signature, "purpose": purpose}
    for tool, signature, purpose in _TOOL_DOC_ROWS
}
# One short steering sentence per episode phase; the prompt renderer looks
# the current phase up here and falls back to an empty string if absent.
PHASE_NUDGES: dict[EpisodePhase, str] = {
    # Exploration only: schema and data inspection.
    EpisodePhase.DIAGNOSE: (
        "You are in DIAGNOSE. Explore the schema and sample data. "
        "Do NOT submit a rewrite yet."
    ),
    # Candidate drafting, then submission.
    EpisodePhase.REWRITE: (
        "You are in REWRITE. Draft candidate queries with run_query "
        "and, once confident, call submit_rewrite."
    ),
    # Drift has happened: changelog first, then adapt.
    EpisodePhase.DRIFT_RECOVERY: (
        "Drift has fired. Read the changelog, re-describe affected "
        "tables, and adapt your rewrite before submitting."
    ),
    # Episode is wrapping up.
    EpisodePhase.FINALIZE: (
        "The episode is finalizing; no further tools will help."
    ),
}
# Fixed role/ground-rules paragraph that opens every rendered system prompt.
# Written as one sentence per entry and joined with single spaces.
SYSTEM_PROMPT_HEADER = " ".join(
    (
        "You are a senior SQL engineer operating an analytical database "
        "that is under live schema and business-rule drift.",
        "Your job is to repair and optimize a slow baseline SELECT under "
        "tight step and runtime budgets.",
        "Prefer read-only tools; never emit DDL or DML "
        "(INSERT/UPDATE/DELETE).",
        "When a changelog is published, treat it as authoritative.",
    )
)
def _render_tool_catalog(dba_enabled: bool = False) -> str:
    """Render the bullet list of tools shown to the agent.

    Walks ``ToolName`` in its iteration order and emits one
    ``- <signature>: <purpose>`` line per tool, taken from ``TOOL_DOCS``.
    A tool missing from ``TOOL_DOCS`` raises ``KeyError``.

    Args:
        dba_enabled: When false, the ``CONSULT_DBA`` entry is left out of
            the catalog.

    Returns:
        The heading line followed by the tool bullets, newline-joined.
    """
    heading = "Tools available (exact JSON shapes enforced by the env):"
    # Keep every tool except CONSULT_DBA when the DBA is switched off.
    entries = [
        TOOL_DOCS[member]
        for member in ToolName
        if member != ToolName.CONSULT_DBA or dba_enabled
    ]
    bullets = [
        f"- {entry['signature']}: {entry['purpose']}" for entry in entries
    ]
    return "\n".join([heading, *bullets])
def render_system_prompt(
    *,
    scenario_id: str,
    learned_hints: str = "",
    phase: EpisodePhase = EpisodePhase.DIAGNOSE,
    budget_steps_remaining: int | None = None,
    drift_fired: bool = False,
    dba_enabled: bool = False,
) -> str:
    """Build the system prompt for one episode.

    The prompt is a sequence of sections separated by blank lines: the
    fixed header, the tool catalog, the scenario line, then the optional
    budget, phase-nudge, drift-reminder and learned-hints sections. Empty
    sections are dropped and the final string is stripped.

    Args:
        scenario_id: Identifier of the current scenario, echoed verbatim.
        learned_hints: Pre-rendered hint text; appended under a
            "Learned hints" heading when non-empty.
        phase: Current episode phase, used to look up the nudge sentence
            in ``PHASE_NUDGES`` (no nudge if the phase has no entry).
        budget_steps_remaining: Remaining step budget; the budget section
            is omitted when this is ``None``.
        drift_fired: When true, adds a reminder to call ``read_changelog``.
        dba_enabled: Passed through to the tool-catalog renderer, which
            omits the ``consult_dba`` tool when this is false.

    Returns:
        The assembled prompt text.
    """
    catalog = _render_tool_catalog(dba_enabled=dba_enabled)

    if budget_steps_remaining is None:
        budget_section = ""
    else:
        budget_section = (
            f"Remaining step budget: {budget_steps_remaining}. Each tool "
            "call costs one step; plan accordingly."
        )

    drift_section = (
        "Drift has already fired in this episode — if you have not yet "
        "called read_changelog since, do that FIRST."
        if drift_fired
        else ""
    )

    hints_section = (
        "Learned hints (from past episodes):\n" + learned_hints
        if learned_hints
        else ""
    )

    # Section order is part of the prompt contract; falsy entries vanish.
    sections = [
        SYSTEM_PROMPT_HEADER,
        catalog,
        f"Current scenario: {scenario_id}",
        budget_section,
        PHASE_NUDGES.get(phase, ""),
        drift_section,
        hints_section,
    ]
    return "\n\n".join(filter(None, sections)).strip()
def render_prompt_from_observation(
    *,
    scenario_id: str,
    observation: SqlDriftObservation,
    dba_enabled: bool = False,
) -> str:
    """Render the system prompt using fields read from an observation.

    Reads ``learned_hints``, ``phase``, ``budget_steps_remaining`` and
    ``drift_fired`` from ``observation`` and hands them to
    :func:`render_system_prompt`.

    Args:
        scenario_id: Identifier of the current scenario.
        observation: Observation supplying the per-step prompt inputs.
        dba_enabled: Whether the ``consult_dba`` tool should be listed in
            the tool catalog. Defaults to ``False``, which matches the
            previous behaviour of this wrapper.

    Returns:
        The rendered system prompt.
    """
    return render_system_prompt(
        scenario_id=scenario_id,
        learned_hints=observation.learned_hints,
        phase=observation.phase,
        budget_steps_remaining=observation.budget_steps_remaining,
        drift_fired=observation.drift_fired,
        # Previously never forwarded, so prompts built through this wrapper
        # could not advertise consult_dba even when the caller enabled it.
        dba_enabled=dba_enabled,
    )
| __all__ = [ | |
| "PHASE_NUDGES", | |
| "SYSTEM_PROMPT_HEADER", | |
| "TOOL_DOCS", | |
| "render_prompt_from_observation", | |
| "render_system_prompt", | |
| ] | |