# sql-drift-env / training / grpo_train.py
# visheshrathi's picture
# Upload folder using huggingface_hub
# bbf206f verified
"""GRPO training entrypoint wiring for SQLDrift.
This module hosts the building blocks for training a Qwen3-class
model (default: ``Qwen/Qwen3-4B-Instruct-2507``) against the SQLDrift
OpenEnv environment with :class:`trl.GRPOTrainer`:
* :func:`iter_curriculum` — pure, dependency-free scenario sampler
used both by :func:`build_dataset` and by unit tests.
* :func:`build_dataset` — turns the curriculum iterator into a
Hugging Face :class:`datasets.Dataset` whose rows carry the per-
episode seed / scenario id that :class:`.tool_env.SqlDriftToolEnv`
consumes in its ``reset(**kwargs)``.
* :func:`load_model_and_tokenizer` — lazy Unsloth loader that keeps
the CUDA dependency tree out of the module top-level so CPU-only
CI can still import this file.
* :func:`build_env_client` — sanctioned way to obtain the OpenEnv
client bound to a running SQLDrift server (used by the eval
harness and notebooks — the trainer builds its own clients via
``environment_factory``).
* :func:`reward_from_environments` — a TRL-compatible reward
function that reads the cumulative trajectory return off each
rolled-out :class:`SqlDriftToolEnv` instance.
* :func:`train` — the real entrypoint: builds a curriculum dataset,
loads the model, instantiates :class:`trl.GRPOTrainer` with
an ``environment_factory`` bound to :class:`SqlDriftToolEnv` and the
caller-supplied env URL (the TRL-sanctioned multi-turn OpenEnv
rollout path, see the TRL OpenEnv integration guide) and runs
``trainer.train()``. All heavy imports are lazy so CPU-only CI can
still import the module.
"""
from __future__ import annotations
import json
from collections.abc import Iterator, Sequence
from dataclasses import asdict
from functools import partial
from pathlib import Path
from random import Random
from typing import Any
from training.config import GRPOConfig
from training.prompt import render_system_prompt
from training.seeding import set_seed
from training.tool_env import SqlDriftToolEnv
from utilities.logger import get_module_logger
_LOG = get_module_logger(__name__)
def iter_curriculum(config: GRPOConfig, *, seed: int = 0) -> Iterator[tuple[str, int]]:
    """Generate an endless sequence of ``(scenario_id, episode_seed)`` pairs.

    The scenario is selected according to ``config.curriculum.mode``:

    * ``"uniform"`` — a uniform random pick from ``curriculum.scenarios``;
    * ``"weighted"`` — a random pick weighted by ``curriculum.weights``;
    * anything else — the scenarios are walked in order, wrapping around.

    The episode seed is drawn from the half-open interval ``[lo, hi)``
    described by ``curriculum.seed_range``.

    Every draw comes from a private :class:`random.Random` seeded with
    ``seed``, so the same ``(config, seed)`` reproduces the same stream.
    """
    generator = Random(seed)
    plan = config.curriculum
    seed_lo, seed_hi = plan.seed_range
    position = 0
    while True:
        mode = plan.mode
        if mode == "weighted":
            (picked,) = generator.choices(
                plan.scenarios,
                weights=list(plan.weights or ()),
                k=1,
            )
        elif mode == "uniform":
            picked = generator.choice(plan.scenarios)
        else:
            # Any other mode: deterministic round-robin over the scenarios.
            picked = plan.scenarios[position % len(plan.scenarios)]
        # randint is inclusive on both ends, hence ``hi - 1``.
        episode_seed = generator.randint(seed_lo, seed_hi - 1)
        yield picked, episode_seed
        position += 1
def build_dataset(config: GRPOConfig, *, num_rows: int, seed: int = 0) -> Any:
    """Materialise ``num_rows`` curriculum slots as a :class:`datasets.Dataset`.

    Each row holds a chat-format ``prompt`` (a system message rendered for
    the row's scenario plus a fixed user instruction) together with the
    ``scenario_id``, ``seed``, ``budget_steps`` and ``enable_dba_oracle``
    columns. Those extra columns travel with the row; TRL is expected to
    forward them as ``**kwargs`` to :meth:`SqlDriftToolEnv.reset`, which
    pins every episode to its curriculum slot.

    :mod:`datasets` is imported inside the function so that merely
    importing this module stays free of that dependency.

    Raises:
        ValueError: if ``num_rows`` is smaller than 1.
    """
    if num_rows < 1:
        raise ValueError("build_dataset requires num_rows >= 1")

    from datasets import Dataset

    # The user turn is the same for every row; only the system prompt varies.
    user_turn = {
        "role": "user",
        "content": (
            "Diagnose, adapt to any drift, and submit a correct "
            "rewrite of the baseline query using the tools "
            "provided. Call submit_rewrite when confident."
        ),
    }
    columns: dict[str, list[Any]] = {
        "prompt": [],
        "scenario_id": [],
        "seed": [],
        "budget_steps": [],
        "enable_dba_oracle": [],
    }
    curriculum = iter_curriculum(config, seed=seed)
    # zip() against range() bounds the infinite curriculum stream to num_rows.
    for _, (scenario_id, episode_seed) in zip(range(num_rows), curriculum):
        system_text = render_system_prompt(
            scenario_id=scenario_id,
            dba_enabled=config.dba_oracle_enabled,
        )
        columns["prompt"].append(
            [{"role": "system", "content": system_text}, dict(user_turn)]
        )
        columns["scenario_id"].append(scenario_id)
        columns["seed"].append(episode_seed)
        columns["budget_steps"].append(config.episode_step_budget)
        columns["enable_dba_oracle"].append(config.dba_oracle_enabled)
    return Dataset.from_dict(columns)
def load_model_and_tokenizer(config: GRPOConfig) -> tuple[Any, Any]:
    """Load ``config.model_name`` as a causal LM together with its tokenizer.

    When ``config.load_in_4bit`` is set, the model is loaded with a
    :class:`transformers.BitsAndBytesConfig` (nf4, double quantization,
    compute dtype float16 if ``config.fp16`` else bfloat16) and
    ``dtype="float32"``; otherwise no quantization config is passed and
    ``dtype="auto"`` is used. Attention is always ``"sdpa"`` and
    ``device_map`` is always ``"auto"``.

    No LoRA adapter is attached here (``get_peft_model`` is not called);
    :func:`train` hands a ``peft_config`` to ``GRPOTrainer`` instead.

    ``torch`` and ``transformers`` are imported inside the function so the
    module can be imported without them.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if not config.load_in_4bit:
        bnb_config = None
        model_dtype = "auto"
    else:
        from transformers import BitsAndBytesConfig

        compute_dtype = torch.bfloat16
        if config.fp16:
            compute_dtype = torch.float16
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=compute_dtype,
        )
        model_dtype = "float32"

    load_kwargs = {
        "attn_implementation": "sdpa",
        "dtype": model_dtype,
        "quantization_config": bnb_config,
        "device_map": "auto",
    }
    model = AutoModelForCausalLM.from_pretrained(config.model_name, **load_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    return model, tokenizer
def build_peft_config(config: GRPOConfig) -> Any:
    """Translate the LoRA knobs on ``config`` into a :class:`peft.LoraConfig`.

    Rank, alpha, dropout and target modules are read from ``config``;
    ``bias`` is fixed to ``"none"`` and ``task_type`` to ``"CAUSAL_LM"``.
    The result is meant to be passed as ``GRPOTrainer(peft_config=...)``.

    :mod:`peft` is imported inside the function so the module can be
    imported without it.
    """
    from peft import LoraConfig

    lora_settings = {
        "r": config.lora_r,
        "lora_alpha": config.lora_alpha,
        "lora_dropout": config.lora_dropout,
        # Copy into a fresh list so the config's own sequence is not shared.
        "target_modules": list(config.lora_target_modules),
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    return LoraConfig(**lora_settings)
def build_env_client(config: GRPOConfig) -> Any:
    """Create a one-shot OpenEnv client for the SQLDrift server.

    The client is constructed against ``config.env_base_url`` and the
    result of its ``.sync()`` call is returned.

    The GRPO trainer does not go through this helper: it manages its own
    clients via ``environment_factory``. This function is kept for the
    eval harness and notebooks.
    """
    from client import SqlDriftEnv

    env_client = SqlDriftEnv(base_url=config.env_base_url)
    return env_client.sync()
def reward_from_environments(
    environments: Sequence[SqlDriftToolEnv],
    **_: Any,
) -> list[float]:
    """Return one scalar reward per rolled-out environment (TRL reward hook).

    The reward is each environment's ``episode_return`` converted to
    ``float``, in the same order as ``environments``. Any additional
    keyword arguments are accepted and ignored.

    Why the whole-episode return: SQLDrift shapes reward along the
    trajectory (step tax, tool-error, repeated-failing-query and DBA
    consult penalties all accrue before the final submit), so using only
    the last step's reward would drop most of that signal.
    """
    returns: list[float] = []
    for rollout_env in environments:
        returns.append(float(rollout_env.episode_return))
    return returns
def _build_flush_log_history_callback(out_path: Path) -> Any:
    """Construct a TrainerCallback that appends each log dict to JSONL.

    Why: `trainer.state.log_history` is in-memory only; a crash mid-run
    wipes the curves needed by `utilities/plot_curves.py`. This callback
    flushes per `on_log` so a step-N crash still leaves N records on
    disk for the post-mortem plot.

    The callback class is defined lazily inside this factory because
    `transformers.TrainerCallback` is a [train]-extra import.

    Args:
        out_path: JSONL file to append to. Its parent directory is
            created when the callback is instantiated.

    Returns:
        A ready-to-register callback instance.
    """
    from transformers import TrainerCallback

    class _FlushLogHistory(TrainerCallback):
        """Append one JSON line per ``on_log`` event to a target file."""

        def __init__(self, target: Path) -> None:
            self._target = target
            self._target.parent.mkdir(parents=True, exist_ok=True)

        def on_log(
            self,
            args: Any,
            state: Any,
            control: Any,
            logs: dict[str, Any] | None = None,
            **kwargs: Any,
        ) -> None:
            if not logs:
                return
            # ``on_log`` fires on every process of a multi-process launch.
            # Only the global main process writes, otherwise each rank would
            # append a duplicate of the same record to the shared file.
            # ``getattr`` keeps the callback usable with state objects that
            # lack the flag (treated as the main process).
            if not getattr(state, "is_world_process_zero", True):
                return
            record = {"step": state.global_step, **logs}
            # Open/append/close per event so every record reaches the file
            # as soon as it is logged; the encoding is pinned so the output
            # does not depend on the platform locale.
            with self._target.open("a", encoding="utf-8") as f:
                f.write(json.dumps(record, default=str) + "\n")

    return _FlushLogHistory(out_path)
def train(config: GRPOConfig) -> Any:
    """Execute a complete GRPO run against SQLDrift and return the trainer.

    Steps, in order:

    1. Seed via :func:`set_seed` and create ``config.output_dir``.
    2. Snapshot the config to ``grpo_config.json`` in that directory.
    3. Build the curriculum dataset (``max_steps * group_size`` rows, but
       never fewer than ``group_size``).
    4. Build the TRL ``GRPOConfig``, load the model/tokenizer and the
       LoRA config, and assemble :class:`trl.GRPOTrainer` with an
       ``environment_factory`` that creates :class:`SqlDriftToolEnv`
       instances bound to ``config.env_base_url``.
    5. Register the JSONL log-flushing callback, train, and save the
       model into the output directory.

    ``trl`` is imported inside the function body (and the helpers it calls
    import ``datasets``, ``transformers`` and ``peft`` the same way), so
    importing this module alone does not pull in the training stack.
    """
    set_seed(config.seed)

    run_dir = Path(config.output_dir)
    run_dir.mkdir(parents=True, exist_ok=True)
    config_snapshot = json.dumps(asdict(config), default=str, indent=2)
    (run_dir / "grpo_config.json").write_text(config_snapshot)

    from trl import GRPOConfig as TRLGRPOConfig
    from trl import GRPOTrainer

    prompts_needed = max(config.max_steps * config.group_size, config.group_size)
    dataset = build_dataset(config, num_rows=prompts_needed, seed=config.seed)

    # TRL >=0.25 removed `max_prompt_length` (prompt length is now
    # dataset-driven); passing it raises TypeError on construction, so it
    # must NOT be re-added here. `max_completion_length` is the TOTAL token
    # budget across the whole multi-turn conversation, not a per-turn cap:
    # https://huggingface.co/docs/trl/openenv#max_completion_length-in-multi-turn-episodes
    #
    # per_device_train_batch_size is fixed at 1 and num_generations follows
    # group_size. TRL requires per-device batch * grad_accum to be divisible
    # by num_generations, so gradient_accumulation_steps has to be a
    # multiple of group_size.
    #
    # NOTE(review): no chat_template_kwargs are passed; confirm the
    # configured model needs no flag to suppress a thinking wrapper.
    trl_kwargs = {
        "output_dir": str(run_dir),
        "learning_rate": config.learning_rate,
        "max_steps": config.max_steps,
        "per_device_train_batch_size": 1,
        "num_generations": config.group_size,
        "gradient_accumulation_steps": config.gradient_accumulation_steps,
        "warmup_steps": config.warmup_steps,
        "max_completion_length": config.max_completion_length,
        "temperature": config.temperature,
        "top_p": config.top_p,
        "fp16": config.fp16,
        "bf16": config.bf16,
        "logging_steps": config.logging_steps,
        "save_steps": config.save_steps,
        "log_completions": True,
        "report_to": ["tensorboard"],
        "logging_dir": str(run_dir / "tb"),
    }
    trl_args = TRLGRPOConfig(**trl_kwargs)

    model, tokenizer = load_model_and_tokenizer(config)
    lora = build_peft_config(config)
    # Each rollout environment is created by calling this factory.
    make_env = partial(SqlDriftToolEnv, env_url=config.env_base_url)
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=reward_from_environments,
        args=trl_args,
        environment_factory=make_env,
        peft_config=lora,
    )

    log_sink = _build_flush_log_history_callback(run_dir / "log_history.jsonl")
    trainer.add_callback(log_sink)

    _LOG.info(
        "Starting GRPO: scenarios=%d, max_steps=%d, group_size=%d, env_url=%s",
        len(config.curriculum.scenarios),
        config.max_steps,
        config.group_size,
        config.env_base_url,
    )
    trainer.train()
    trainer.save_model(str(run_dir))
    return trainer
def _parse_args(argv: list[str] | None = None) -> GRPOConfig:
    """Parse command-line arguments into a :class:`GRPOConfig`.

    Overrides are layered in this order (later wins):

    1. ``GRPOConfig`` defaults;
    2. the optional ``--config`` JSON manifest, which must be a JSON
       object whose keys are ``GRPOConfig`` field names;
    3. the explicit ``--max-steps`` / ``--output-dir`` /
       ``--env-base-url`` flags, when given.

    A ``"curriculum"`` entry that is a JSON object is converted into a
    ``CurriculumConfig`` before ``GRPOConfig`` is built.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv``.

    Raises:
        SystemExit: via argparse, on malformed arguments or when the
            manifest's top-level JSON value is not an object.
    """
    import argparse

    ap = argparse.ArgumentParser(description="GRPO training for SQLDrift.")
    ap.add_argument(
        "--config",
        type=Path,
        default=None,
        help="Optional JSON manifest overriding GRPOConfig defaults.",
    )
    ap.add_argument("--max-steps", type=int, default=None)
    ap.add_argument("--output-dir", type=str, default=None)
    ap.add_argument("--env-base-url", type=str, default=None)
    ns = ap.parse_args(argv)

    overrides: dict[str, Any] = {}
    if ns.config is not None:
        # JSON text is UTF-8 by specification; do not rely on the locale.
        manifest = json.loads(ns.config.read_text(encoding="utf-8"))
        if not isinstance(manifest, dict):
            # Without this check a list/null/string manifest fails inside
            # dict.update() with an error that does not mention --config.
            ap.error(
                "--config must contain a JSON object, "
                f"got {type(manifest).__name__}"
            )
        overrides.update(manifest)

    # Explicit CLI flags take precedence over the manifest.
    for field_name in ("max_steps", "output_dir", "env_base_url"):
        cli_value = getattr(ns, field_name)
        if cli_value is not None:
            overrides[field_name] = cli_value

    curriculum = overrides.get("curriculum")
    if isinstance(curriculum, dict):
        from training.config import CurriculumConfig

        overrides["curriculum"] = CurriculumConfig(**curriculum)
    return GRPOConfig(**overrides)
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: turn ``argv`` into a config and launch training.

    ``argv=None`` defers to argparse's default of reading ``sys.argv``.
    """
    train(_parse_args(argv))
# Names exported by ``from training.grpo_train import *``; helpers whose
# names start with an underscore are intentionally left out.
__all__ = [
    "build_dataset",
    "build_env_client",
    "build_peft_config",
    "iter_curriculum",
    "load_model_and_tokenizer",
    "reward_from_environments",
    "train",
]

# Script entry point: run training when the module is executed directly.
if __name__ == "__main__":
    main()