# sql_env/training/prompts.py
# Hugging Face Hub page header (not part of the module): uploaded by hjerpe,
# commit 9e64e71 (verified), "Upload folder using huggingface_hub", 2.84 kB.
"""Prompt helpers for GRPO training rollouts.
Canonical system prompt and observation formatting for SQL exploration.
The system prompt matches the SFT/GRPO tool-calling format used by
``scripts/generate_sft_data.py`` and ``notebooks/train_grpo.ipynb``.
"""
try:
from sql_env.models import SQLObservation
except ImportError:
from models import SQLObservation
# Maximum number of characters of ``obs.result`` that ``format_observation``
# keeps before cutting the text and appending "... [truncated]".
_MAX_RESULT_CHARS = 2000
# Base system prompt returned by ``get_system_prompt`` (which may prepend
# "/no_think\n" to it). Per the module docstring, this wording matches the
# tool-calling format used for SFT data generation and GRPO training, so the
# text should be edited with care. It lists the three tools (describe, query,
# answer), the accepted answer formats, and the one-tool-per-turn rule.
_SYSTEM_PROMPT = (
"You answer questions about a SQL database. "
"Use ONLY the provided tools.\n\n"
"Strategy:\n"
"1. Call describe(table_name=...) to see columns\n"
"2. Call query(sql=...) to run SELECT queries\n"
"3. Call answer(value=...) to submit your final answer\n\n"
"Answer format: submit ONLY the data values from your query result.\n"
"- Single value: 42 or ford\n"
"- Multiple values: alice, bob, charlie\n"
"- Table rows: col1 | col2 (one row per line)\n"
"- No results: []\n\n"
"IMPORTANT: Call only ONE tool at a time, then read the "
"response before deciding what to do next."
)
def get_system_prompt(*, enable_thinking: bool = False) -> str:
    """Build the system prompt used for SQL exploration rollouts.

    Parameters
    ----------
    enable_thinking
        Keyword-only switch. When falsy (the default) the prompt is
        prefixed with a ``/no_think`` line, which is intended to turn off
        Qwen3 thinking mode. When truthy the base prompt is returned
        without any prefix.

    Returns
    -------
    str
        The module's base system prompt, optionally preceded by
        ``"/no_think\\n"``. The same argument always yields the same text.
    """
    # An empty prefix leaves the base prompt untouched.
    prefix = "" if enable_thinking else "/no_think\n"
    return prefix + _SYSTEM_PROMPT
def format_observation(obs: SQLObservation) -> str:
    """Render an environment observation as the text of a user turn.

    Parameters
    ----------
    obs
        Observation to render. The attributes read are ``result``,
        ``question``, ``schema_info``, ``error``, ``action_history``,
        ``step_count``, ``budget_remaining``, ``done`` and ``reward``.

    Returns
    -------
    str
        Newline-joined text containing, in order: the question, the schema
        (``(none)`` when falsy), the last result (``(empty)`` when falsy,
        cut to ``_MAX_RESULT_CHARS`` characters with a ``... [truncated]``
        marker when longer), the error if any, the action history if any,
        the step/budget/done status lines, and the final reward once the
        observation is marked done.
    """
    shown_result = obs.result
    if not shown_result:
        shown_result = "(empty)"
    if len(shown_result) > _MAX_RESULT_CHARS:
        # Keep only the head of an oversized result and flag the cut.
        shown_result = "{}... [truncated]".format(shown_result[:_MAX_RESULT_CHARS])

    parts = [f"Question: {obs.question}", "", "Schema:"]
    parts.append(obs.schema_info or "(none)")
    parts += ["", "Last Result:", shown_result]

    # Optional sections are emitted only when they carry content.
    if obs.error:
        parts += ["", f"Error: {obs.error}"]
    if obs.action_history:
        parts += ["", "Action History:"]
        for entry in obs.action_history:
            parts.append(f"- {entry}")

    parts += [
        "",
        f"Step: {obs.step_count}",
        f"Budget Remaining: {obs.budget_remaining}",
        f"Done: {obs.done}",
    ]

    if obs.done:
        # ``!s`` applies str(), which already renders a missing reward as "None".
        parts.append(f"Final Reward: {obs.reward!s}")

    return "\n".join(parts)