Spaces:
Sleeping
Sleeping
| """Run one SQLDrift episode with an OpenAI-compatible chat endpoint. | |
| Usage: | |
| uv run python utilities/verbose_api_rollout.py --scenario 07_drift_column_rename | |
| Configuration is read from the repo-local ``.env`` file: | |
| * ``SQL_DRIFT_API_KEY`` | |
| * ``SQL_DRIFT_API_BASE_URL`` | |
| * ``SQL_DRIFT_API_MODEL`` | |
| The environment itself runs in-process, so you do NOT need to start the | |
| OpenEnv server just to watch one episode. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| import textwrap | |
| from pathlib import Path | |
| from typing import cast | |
| # Make the repo root importable when the script is run as | |
| # ``python utilities/verbose_api_rollout.py`` from any cwd. | |
| _REPO_ROOT = Path(__file__).resolve().parent.parent | |
| if str(_REPO_ROOT) not in sys.path: | |
| sys.path.insert(0, str(_REPO_ROOT)) | |
| from openai import OpenAI | |
| from openai.types.chat import ChatCompletionMessageParam | |
| from models import SqlDriftAction, SqlDriftObservation | |
| from server import SqlDriftEnvironment | |
| from training.llm_agent import ( | |
| _TOOL_CONTRACT, | |
| _canonicalise_action, | |
| _parse_completion_as_action, | |
| _summarise_tool_result, | |
| ) | |
| from training.prompt import render_system_prompt | |
| from utilities.env_loader import env_str | |
| from utilities.logger import log_interaction | |
class ApiEpisodeAgent:
    """Minimal chat agent backed by an OpenAI-compatible endpoint.

    The agent keeps a bounded window of prior user/assistant messages and
    builds its system prompt lazily from the first observation passed to
    ``act`` after a ``reset``.

    Attributes:
        last_completion: Text of the most recent model completion.
        last_parsed_ok: Whether ``_parse_completion_as_action`` reported
            that completion as successfully parsed.
    """

    # Notice placed in a user message when the observation reports drift.
    _DRIFT_NOTICE = "Drift has fired."

    def __init__(
        self,
        *,
        model: str,
        api_key: str,
        base_url: str,
        temperature: float,
        max_tokens: int,
        history_turns: int,
    ) -> None:
        """Create the API client and empty per-episode state.

        Args:
            model: Model name sent with every chat completion request.
            api_key: API key passed to the ``OpenAI`` client.
            base_url: Base URL passed to the ``OpenAI`` client.
            temperature: Sampling temperature sent with each request.
            max_tokens: ``max_tokens`` value sent with each request.
            history_turns: Number of user/assistant pairs to retain;
                values below 1 are raised to 1.
        """
        self._client = OpenAI(api_key=api_key, base_url=base_url)
        self._model = model
        self._temperature = temperature
        self._max_tokens = max_tokens
        self._history_turns = max(history_turns, 1)
        self._history: list[dict[str, str]] = []
        self._scenario_id = "unknown"
        self._system_prompt = ""
        self.last_completion = ""
        self.last_parsed_ok = False

    def reset(self, *, scenario_id: str) -> None:
        """Clear episode state and remember ``scenario_id`` for the next run.

        The system prompt is cleared as well, so the next ``act`` call
        rebuilds it from the observation it receives.
        """
        self._history = []
        self._scenario_id = scenario_id
        self._system_prompt = ""
        self.last_completion = ""
        self.last_parsed_ok = False

    def act(self, obs: SqlDriftObservation) -> SqlDriftAction:
        """Query the model for the next action given ``obs``.

        Updates ``last_completion`` / ``last_parsed_ok`` and appends the
        exchange to the rolling history.

        Returns:
            The action produced by ``_parse_completion_as_action``.

        Raises:
            Exception: Whatever the API client raises; it is logged by
                ``_generate`` and re-raised unchanged.
        """
        if not self._system_prompt:
            self._system_prompt = self._initial_system_prompt(obs)
        user_message = self._render_user_message(obs)
        completion = self._generate(user_message)
        action, parsed_ok = _parse_completion_as_action(completion)
        self.last_completion = completion
        self.last_parsed_ok = parsed_ok
        self._history.append({"role": "user", "content": user_message})
        # When parsing failed, record the canonicalised form of the returned
        # action rather than the raw completion (presumably so malformed
        # output is not echoed back to the model -- confirm with the helper).
        assistant_text = completion if parsed_ok else _canonicalise_action(action)
        self._history.append({"role": "assistant", "content": assistant_text})
        self._trim_history()
        return action

    def _initial_system_prompt(self, obs: SqlDriftObservation) -> str:
        """Build the system prompt from the first observation of an episode.

        The result is the output of ``render_system_prompt`` followed by the
        schema synopsis and baseline query (each only when non-empty) and
        finally ``_TOOL_CONTRACT``.
        """
        base = render_system_prompt(
            scenario_id=self._scenario_id,
            learned_hints=obs.learned_hints,
            phase=obs.phase,
            budget_steps_remaining=obs.budget_steps_remaining,
            drift_fired=obs.drift_fired,
        )
        task_block = ""
        if obs.schema_synopsis:
            task_block += f"\n\nSchema synopsis:\n{obs.schema_synopsis}"
        if obs.baseline_sql:
            task_block += f"\n\nBaseline query:\n{obs.baseline_sql}"
        return f"{base}{task_block}\n\n{_TOOL_CONTRACT}"

    def _render_user_message(self, obs: SqlDriftObservation) -> str:
        """Render the per-step user message for ``obs``.

        The message contains, in order: the drift notice (only when drift
        has fired and the retained history does not already carry it), the
        remaining step budget, and either the tool-result summary (followed
        by learned hints when present) or a generic request for the next
        tool call.
        """
        parts: list[str] = []
        if obs.drift_fired and not self._history_mentions_drift():
            parts.append(self._DRIFT_NOTICE)
        parts.append(f"Remaining steps: {obs.budget_steps_remaining}")
        tool_summary = _summarise_tool_result(obs)
        if tool_summary:
            parts.append(tool_summary)
            if obs.learned_hints:
                parts.append("Learned hints:\n" + obs.learned_hints)
        else:
            parts.append("Pick the next tool call.")
        return "\n".join(parts)

    def _history_mentions_drift(self) -> bool:
        """Return True when any retained user message carries the drift notice.

        Every user message in the window is checked, not just the latest
        one: looking only at the latest message would re-announce drift on
        every other turn. Once the announcing message is trimmed out of the
        window this returns False again, so the notice is re-sent.
        """
        return any(
            msg["role"] == "user" and self._DRIFT_NOTICE in msg["content"]
            for msg in self._history
        )

    def _generate(self, user_message: str) -> str:
        """Send the conversation plus ``user_message`` and return the reply.

        Both the failure path and the success path are recorded through
        ``log_interaction`` with ``event_type="llm_call"``.

        Raises:
            Exception: Any error raised by the client call, re-raised after
                it has been logged.
        """
        messages = cast(
            list[ChatCompletionMessageParam],
            [{"role": "system", "content": self._system_prompt}]
            + self._history
            + [{"role": "user", "content": user_message}],
        )
        try:
            response = self._client.chat.completions.create(
                model=self._model,
                messages=messages,
                temperature=self._temperature,
                max_tokens=self._max_tokens,
            )
        except Exception as exc:
            log_interaction(
                event_type="llm_call",
                agent_id=self._agent_id(),
                llm_prompt=messages,
                error=repr(exc),
            )
            raise
        text = self._extract_text(response)
        log_interaction(
            event_type="llm_call",
            agent_id=self._agent_id(),
            llm_prompt=messages,
            llm_response=text,
        )
        return text

    @staticmethod
    def _extract_text(response) -> str:
        """Return the stripped text of the first choice in ``response``.

        An empty ``choices`` list and a ``None`` message content both yield
        an empty string, so callers never hit an ``IndexError`` here.
        Non-string content is converted with ``str`` before stripping.
        """
        choices = response.choices
        if not choices:
            return ""
        content = choices[0].message.content
        if content is None:
            return ""
        if isinstance(content, str):
            return content.strip()
        return str(content).strip()

    def _trim_history(self) -> None:
        """Keep only the newest ``history_turns`` user/assistant pairs."""
        max_messages = self._history_turns * 2
        if len(self._history) > max_messages:
            self._history = self._history[-max_messages:]

    def _agent_id(self) -> str:
        """Return the identifier passed to ``log_interaction`` for this agent."""
        return f"api_agent:{self._scenario_id}:{self._model}"
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the rollout's command-line options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]``, which is the behaviour the
            script entry point relies on; passing a list allows the parser
            to be driven programmatically.

    Returns:
        The populated namespace with ``scenario``, ``seed``, ``difficulty``,
        ``budget_steps``, ``max_steps``, ``temperature``, ``max_tokens``,
        ``history_turns`` and ``enable_dba_oracle``.
    """
    parser = argparse.ArgumentParser(description="Verbose one-episode API rollout for SQLDrift.")
    parser.add_argument("--scenario", default="07_drift_column_rename")
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--difficulty", choices=("easy", "normal", "hard"), default="normal")
    parser.add_argument("--budget-steps", type=int, default=25)
    parser.add_argument("--max-steps", type=int, default=25)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--max-tokens", type=int, default=256)
    parser.add_argument("--history-turns", type=int, default=6)
    parser.add_argument("--enable-dba-oracle", action="store_true")
    return parser.parse_args(argv)
def _required_env(name: str) -> str:
    """Return the value ``env_str`` yields for ``name``, or fail loudly.

    Raises:
        RuntimeError: If ``env_str(name, "")`` returns an empty value.
    """
    resolved = env_str(name, "")
    if resolved:
        return resolved
    raise RuntimeError(
        f"Missing {name}. Set it in `.env` (see `.env.example`) or export it before running."
    )
def _print_header(args: argparse.Namespace, model: str, base_url: str) -> None:
    """Print the run banner: title, episode settings, model and base URL."""
    rule = "=" * 88
    run_summary = (
        f"scenario={args.scenario} seed={args.seed} difficulty={args.difficulty} "
        f"budget_steps={args.budget_steps}"
    )
    banner_lines = (
        rule,
        "SQLDrift verbose API rollout",
        rule,
        run_summary,
        f"model={model}",
        f"base_url={base_url}",
    )
    for line in banner_lines:
        print(line)
def _print_reset(obs: SqlDriftObservation) -> None:
    """Print the observation returned by the environment reset.

    Always prints the phase, remaining budget and drift flag; the schema
    synopsis, baseline SQL and learned hints are each printed under their
    own heading only when non-empty.
    """
    print("\n[reset]")
    print(f"phase={obs.phase} budget={obs.budget_steps_remaining} drift_fired={obs.drift_fired}")
    optional_sections = (
        ("Schema synopsis", obs.schema_synopsis),
        ("Baseline SQL", obs.baseline_sql),
        ("Learned hints", obs.learned_hints),
    )
    for heading, body in optional_sections:
        if not body:
            continue
        print(f"\n{heading}:")
        print(body)
def _print_step(
    step: int, agent: ApiEpisodeAgent, action: SqlDriftAction, obs: SqlDriftObservation
) -> None:
    """Print one step: model completion, parsed action and resulting observation.

    Learned hints and reward components are printed only when the
    observation carries a non-empty value for them.
    """
    tag = f"[step {step:02d}]"
    completion_text = agent.last_completion or "(empty)"

    print("\n" + "-" * 88)
    print(f"{tag} model completion")
    print(textwrap.indent(completion_text, " "))
    print(f"{tag} parsed_ok={agent.last_parsed_ok}")
    print(f"{tag} action={action.model_dump(mode='json')}")
    print(
        f"{tag} phase={obs.phase} drift_fired={obs.drift_fired} "
        f"budget={obs.budget_steps_remaining} reward={obs.reward}"
    )

    result_text = _summarise_tool_result(obs)
    if not result_text:
        # Placeholder so the "result:" heading is never left empty.
        result_text = "(no tool result)"
    print(f"{tag} result:")
    print(textwrap.indent(result_text, " "))

    if obs.learned_hints:
        print(f"{tag} learned_hints:")
        print(textwrap.indent(obs.learned_hints, " "))
    if obs.reward_components:
        print(f"{tag} reward_components={obs.reward_components}")
def _print_final(env: SqlDriftEnvironment, obs: SqlDriftObservation) -> None:
    """Print the end-of-episode summary.

    Shows the final observation's ``done``, ``reward`` and
    ``reward_components``, then the environment's ``effective_speedup()``
    and a dump of ``env.state``.
    """
    rule = "=" * 88
    print(f"\n{rule}")
    print("Episode finished")
    print(rule)
    print(f"done={obs.done} reward={obs.reward}")
    print(f"reward_components={obs.reward_components}")
    # Environment calls stay after the observation lines, one per print.
    speedup = env.effective_speedup()
    print(f"effective_speedup={speedup}")
    state_dump = env.state.model_dump()
    print(f"final_state={state_dump}")
def main() -> None:
    """Run a single in-process episode and print every step.

    Reads the model, base URL and API key through ``_required_env``, resets
    the environment with the CLI options, then alternates agent and
    environment steps until the episode reports ``done`` or ``--max-steps``
    is reached. The environment is closed even if the rollout raises.
    """
    args = _parse_args()
    # Resolved in this order so the first missing variable is the one reported.
    model, base_url, api_key = (
        _required_env(name)
        for name in ("SQL_DRIFT_API_MODEL", "SQL_DRIFT_API_BASE_URL", "SQL_DRIFT_API_KEY")
    )
    agent = ApiEpisodeAgent(
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        history_turns=args.history_turns,
    )
    env = SqlDriftEnvironment()
    try:
        reset_options = {
            "seed": args.seed,
            "scenario_id": args.scenario,
            "difficulty": args.difficulty,
            "budget_steps": args.budget_steps,
            "enable_dba_oracle": args.enable_dba_oracle,
        }
        obs = env.reset(**reset_options)
        agent.reset(scenario_id=args.scenario)
        _print_header(args, model, base_url)
        _print_reset(obs)

        step = 0
        while step < args.max_steps and not obs.done:
            step += 1
            action = agent.act(obs)
            obs = env.step(action)
            _print_step(step, agent, action, obs)

        _print_final(env, obs)
    finally:
        env.close()
# Run the rollout only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()