# sql-drift-env/utilities/verbose_api_rollout.py
# Hosting-page header captured together with the file (not part of the program):
#   visheshrathi's picture / Upload folder using huggingface_hub / 5850885 verified
"""Run one SQLDrift episode with an OpenAI-compatible chat endpoint.

Usage:
    uv run python utilities/verbose_api_rollout.py --scenario 07_drift_column_rename

Configuration is read from the repo-local ``.env`` file:

* ``SQL_DRIFT_API_KEY``
* ``SQL_DRIFT_API_BASE_URL``
* ``SQL_DRIFT_API_MODEL``

The environment itself runs in-process, so you do NOT need to start the
OpenEnv server just to watch one episode.
"""
from __future__ import annotations
import argparse
import sys
import textwrap
from pathlib import Path
from typing import cast
# Make the repo root importable when the script is run as
# ``python utilities/verbose_api_rollout.py`` from any cwd.
# This file sits one directory below the repo root, hence ``parent.parent``.
# The path tweak has to run before the project-local imports that follow.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    # Prepend so the repo's own packages are found first on the import path.
    sys.path.insert(0, str(_REPO_ROOT))
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from models import SqlDriftAction, SqlDriftObservation
from server import SqlDriftEnvironment
from training.llm_agent import (
_TOOL_CONTRACT,
_canonicalise_action,
_parse_completion_as_action,
_summarise_tool_result,
)
from training.prompt import render_system_prompt
from utilities.env_loader import env_str
from utilities.logger import log_interaction
class ApiEpisodeAgent:
    """Minimal chat agent backed by an OpenAI-compatible endpoint.

    The agent keeps a rolling window of user/assistant messages, renders each
    observation into a user message, asks the endpoint for a completion and
    parses that completion into a ``SqlDriftAction``.

    Public attributes:
        last_completion: Stripped text of the most recent model completion.
        last_parsed_ok: Whether that completion was parsed successfully.
    """

    def __init__(
        self,
        *,
        model: str,
        api_key: str,
        base_url: str,
        temperature: float,
        max_tokens: int,
        history_turns: int,
    ) -> None:
        """Create the API client and initialise empty per-episode state.

        Args:
            model: Model name sent with every chat-completion request.
            api_key: Credential passed to the ``OpenAI`` client.
            base_url: Endpoint base URL passed to the ``OpenAI`` client.
            temperature: Sampling temperature for each request.
            max_tokens: Completion token limit for each request.
            history_turns: Number of past user/assistant turns to keep;
                values below 1 are raised to 1.
        """
        self._client = OpenAI(api_key=api_key, base_url=base_url)
        self._model = model
        self._temperature = temperature
        self._max_tokens = max_tokens
        self._history_turns = max(history_turns, 1)
        # Alternating {"role": "user"|"assistant", "content": ...} messages.
        self._history: list[dict[str, str]] = []
        self._scenario_id = "unknown"
        # Built lazily from the first observation seen by ``act``.
        self._system_prompt = ""
        self.last_completion = ""
        self.last_parsed_ok = False

    def reset(self, *, scenario_id: str) -> None:
        """Clear all per-episode state and remember the new scenario id.

        The system prompt is emptied so the next ``act`` call rebuilds it.
        """
        self._history = []
        self._scenario_id = scenario_id
        self._system_prompt = ""
        self.last_completion = ""
        self.last_parsed_ok = False

    def act(self, obs: SqlDriftObservation) -> SqlDriftAction:
        """Query the model for the next action given ``obs``.

        Side effects: builds the system prompt on first use, updates
        ``last_completion`` / ``last_parsed_ok``, appends the exchange to the
        history and trims the history window.

        Returns:
            The action produced by ``_parse_completion_as_action``.

        Raises:
            Exception: Whatever the endpoint call raises (after it is logged).
        """
        if not self._system_prompt:
            self._system_prompt = self._initial_system_prompt(obs)
        user_message = self._render_user_message(obs)
        completion = self._generate(user_message)
        action, parsed_ok = _parse_completion_as_action(completion)
        self.last_completion = completion
        self.last_parsed_ok = parsed_ok
        # When parsing failed, record the canonicalised form of the action that
        # is actually returned rather than the raw, unparseable completion.
        assistant_content = completion if parsed_ok else _canonicalise_action(action)
        self._history.append({"role": "user", "content": user_message})
        self._history.append({"role": "assistant", "content": assistant_content})
        self._trim_history()
        return action

    def _initial_system_prompt(self, obs: SqlDriftObservation) -> str:
        """Build the system prompt from the rendered base prompt and ``obs``.

        The schema synopsis and baseline query are appended when present,
        followed by the tool contract text.
        """
        base = render_system_prompt(
            scenario_id=self._scenario_id,
            learned_hints=obs.learned_hints,
            phase=obs.phase,
            budget_steps_remaining=obs.budget_steps_remaining,
            drift_fired=obs.drift_fired,
        )
        task_block = ""
        if obs.schema_synopsis:
            task_block += f"\n\nSchema synopsis:\n{obs.schema_synopsis}"
        if obs.baseline_sql:
            task_block += f"\n\nBaseline query:\n{obs.baseline_sql}"
        return f"{base}{task_block}\n\n{_TOOL_CONTRACT}"

    def _render_user_message(self, obs: SqlDriftObservation) -> str:
        """Render ``obs`` into the newline-joined user message for this turn."""
        parts: list[str] = []
        if obs.drift_fired and not self._history_mentions_drift():
            parts.append("Drift has fired.")
        parts.append(f"Remaining steps: {obs.budget_steps_remaining}")
        tool_summary = _summarise_tool_result(obs)
        if tool_summary:
            parts.append(tool_summary)
            if obs.learned_hints:
                parts.append("Learned hints:\n" + obs.learned_hints)
        else:
            # No tool result to show: just ask for the next call.
            parts.append("Pick the next tool call.")
        return "\n".join(parts)

    def _history_mentions_drift(self) -> bool:
        """Return whether the most recent user message carries the drift notice.

        NOTE(review): only the latest user turn is inspected, so while
        ``obs.drift_fired`` stays true the notice is re-sent on alternate
        turns -- confirm that is the intended cadence.
        """
        for msg in reversed(self._history):
            if msg["role"] == "user":
                return "Drift has fired." in msg["content"]
        return False

    def _generate(self, user_message: str) -> str:
        """Send system prompt + history + ``user_message`` and return the text.

        Every call is recorded through ``log_interaction``; on failure the
        error is logged and the original exception is re-raised.
        """
        messages = cast(
            list[ChatCompletionMessageParam],
            [{"role": "system", "content": self._system_prompt}]
            + self._history
            + [{"role": "user", "content": user_message}],
        )
        try:
            response = self._client.chat.completions.create(
                model=self._model,
                messages=messages,
                temperature=self._temperature,
                max_tokens=self._max_tokens,
            )
        except Exception as exc:
            log_interaction(
                event_type="llm_call",
                agent_id=self._agent_id(),
                llm_prompt=messages,
                error=repr(exc),
            )
            raise
        text = self._completion_text(response)
        log_interaction(
            event_type="llm_call",
            agent_id=self._agent_id(),
            llm_prompt=messages,
            llm_response=text,
        )
        return text

    @staticmethod
    def _completion_text(response: object) -> str:
        """Extract the stripped text of the first choice in ``response``.

        Returns an empty string when the response has no choices or when the
        first choice carries no content, so such a reply is handled as an
        empty completion instead of raising ``IndexError``/``TypeError``.
        """
        choices = getattr(response, "choices", None)
        if not choices:
            return ""
        content = choices[0].message.content
        if content is None:
            return ""
        # Non-string content is stringified before stripping.
        return str(content).strip()

    def _trim_history(self) -> None:
        """Keep only the newest ``history_turns`` user/assistant pairs."""
        max_messages = self._history_turns * 2
        if len(self._history) > max_messages:
            self._history = self._history[-max_messages:]

    def _agent_id(self) -> str:
        """Return the identifier used when logging this agent's calls."""
        return f"api_agent:{self._scenario_id}:{self._model}"
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Verbose one-episode API rollout for SQLDrift.")
parser.add_argument("--scenario", default="07_drift_column_rename")
parser.add_argument("--seed", type=int, default=7)
parser.add_argument("--difficulty", choices=("easy", "normal", "hard"), default="normal")
parser.add_argument("--budget-steps", type=int, default=25)
parser.add_argument("--max-steps", type=int, default=25)
parser.add_argument("--temperature", type=float, default=0.1)
parser.add_argument("--max-tokens", type=int, default=256)
parser.add_argument("--history-turns", type=int, default=6)
parser.add_argument("--enable-dba-oracle", action="store_true")
return parser.parse_args()
def _required_env(name: str) -> str:
    """Return the setting ``name`` as resolved by ``env_str``.

    Raises:
        RuntimeError: If ``env_str`` yields an empty value for ``name``.
    """
    if resolved := env_str(name, ""):
        return resolved
    raise RuntimeError(
        f"Missing {name}. Set it in `.env` (see `.env.example`) or export it before running."
    )
def _print_header(args: argparse.Namespace, model: str, base_url: str) -> None:
print("=" * 88)
print("SQLDrift verbose API rollout")
print("=" * 88)
print(
f"scenario={args.scenario} seed={args.seed} difficulty={args.difficulty} "
f"budget_steps={args.budget_steps}"
)
print(f"model={model}")
print(f"base_url={base_url}")
def _print_reset(obs: SqlDriftObservation) -> None:
print("\n[reset]")
print(f"phase={obs.phase} budget={obs.budget_steps_remaining} drift_fired={obs.drift_fired}")
if obs.schema_synopsis:
print("\nSchema synopsis:")
print(obs.schema_synopsis)
if obs.baseline_sql:
print("\nBaseline SQL:")
print(obs.baseline_sql)
if obs.learned_hints:
print("\nLearned hints:")
print(obs.learned_hints)
def _print_step(
    step: int, agent: ApiEpisodeAgent, action: SqlDriftAction, obs: SqlDriftObservation
) -> None:
    """Print one step: raw completion, parsed action, resulting observation."""
    tag = f"[step {step:02d}]"
    completion_text = agent.last_completion or "(empty)"
    print("\n" + "-" * 88)
    print(f"{tag} model completion")
    print(textwrap.indent(completion_text, " "))
    print(f"{tag} parsed_ok={agent.last_parsed_ok}")
    print(f"{tag} action={action.model_dump(mode='json')}")
    print(
        f"{tag} phase={obs.phase} drift_fired={obs.drift_fired} "
        f"budget={obs.budget_steps_remaining} reward={obs.reward}"
    )
    result_text = _summarise_tool_result(obs)
    if not result_text:
        result_text = "(no tool result)"
    print(f"{tag} result:")
    print(textwrap.indent(result_text, " "))
    if obs.learned_hints:
        print(f"{tag} learned_hints:")
        print(textwrap.indent(obs.learned_hints, " "))
    if obs.reward_components:
        print(f"{tag} reward_components={obs.reward_components}")
def _print_final(env: SqlDriftEnvironment, obs: SqlDriftObservation) -> None:
print("\n" + "=" * 88)
print("Episode finished")
print("=" * 88)
print(f"done={obs.done} reward={obs.reward}")
print(f"reward_components={obs.reward_components}")
print(f"effective_speedup={env.effective_speedup()}")
print(f"final_state={env.state.model_dump()}")
def main() -> None:
    """Run one episode end to end: configure, reset, step, print, clean up."""
    cli = _parse_args()
    # Resolve the three required settings in the same order as before, so the
    # first missing one is the one reported.
    settings = {
        key: _required_env(key)
        for key in ("SQL_DRIFT_API_MODEL", "SQL_DRIFT_API_BASE_URL", "SQL_DRIFT_API_KEY")
    }
    model_name = settings["SQL_DRIFT_API_MODEL"]
    endpoint = settings["SQL_DRIFT_API_BASE_URL"]
    agent = ApiEpisodeAgent(
        model=model_name,
        api_key=settings["SQL_DRIFT_API_KEY"],
        base_url=endpoint,
        temperature=cli.temperature,
        max_tokens=cli.max_tokens,
        history_turns=cli.history_turns,
    )
    env = SqlDriftEnvironment()
    try:
        obs = env.reset(
            seed=cli.seed,
            scenario_id=cli.scenario,
            difficulty=cli.difficulty,
            budget_steps=cli.budget_steps,
            enable_dba_oracle=cli.enable_dba_oracle,
        )
        agent.reset(scenario_id=cli.scenario)
        _print_header(cli, model_name, endpoint)
        _print_reset(obs)
        # Stop at the step cap or as soon as the environment reports done.
        steps_taken = 0
        while steps_taken < cli.max_steps and not obs.done:
            steps_taken += 1
            action = agent.act(obs)
            obs = env.step(action)
            _print_step(steps_taken, agent, action, obs)
        _print_final(env, obs)
    finally:
        # Always release the environment, even if the rollout raised.
        env.close()
# Run the rollout only when executed as a script, not when imported.
if __name__ == "__main__":
    main()