Spaces:
Sleeping
Sleeping
| """GRPO training entrypoint wiring for SQLDrift. | |
| This module hosts the building blocks for training a Qwen3-class | |
| model (default: ``Qwen/Qwen3-4B-Instruct-2507``) against the SQLDrift | |
| OpenEnv environment with :class:`trl.GRPOTrainer`: | |
| * :func:`iter_curriculum` — pure, dependency-free scenario sampler | |
| used both by :func:`build_dataset` and by unit tests. | |
| * :func:`build_dataset` — turns the curriculum iterator into a | |
| Hugging Face :class:`datasets.Dataset` whose rows carry the per- | |
| episode seed / scenario id that :class:`.tool_env.SqlDriftToolEnv` | |
| consumes in its ``reset(**kwargs)``. | |
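* :func:`load_model_and_tokenizer` — lazy ``transformers`` loader
  (``AutoModelForCausalLM``, optionally 4-bit nf4 via
  ``BitsAndBytesConfig``) that keeps the CUDA dependency tree out of
  the module top-level so CPU-only CI can still import this file.
* :func:`build_peft_config` — builds the :class:`peft.LoraConfig`
  that :func:`train` passes to ``GRPOTrainer(peft_config=...)``.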
| * :func:`build_env_client` — sanctioned way to obtain the OpenEnv | |
| client bound to a running SQLDrift server (used by the eval | |
| harness and notebooks — the trainer builds its own clients via | |
| ``environment_factory``). | |
| * :func:`reward_from_environments` — a TRL-compatible reward | |
| function that reads the cumulative trajectory return off each | |
| rolled-out :class:`SqlDriftToolEnv` instance. | |
| * :func:`train` — the real entrypoint: builds a curriculum dataset, | |
| loads the model, instantiates :class:`trl.GRPOTrainer` with | |
| an ``environment_factory`` bound to :class:`SqlDriftToolEnv` and the | |
| caller-supplied env URL (the TRL-sanctioned multi-turn OpenEnv | |
| rollout path, see the TRL OpenEnv integration guide) and runs | |
| ``trainer.train()``. All heavy imports are lazy so CPU-only CI can | |
| still import the module. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from collections.abc import Iterator, Sequence | |
| from dataclasses import asdict | |
| from functools import partial | |
| from pathlib import Path | |
| from random import Random | |
| from typing import Any | |
| from training.config import GRPOConfig | |
| from training.prompt import render_system_prompt | |
| from training.seeding import set_seed | |
| from training.tool_env import SqlDriftToolEnv | |
| from utilities.logger import get_module_logger | |
# Module-scoped logger from the project's utilities.logger helper
# (presumably a thin wrapper around logging.getLogger — confirm).
_LOG = get_module_logger(__name__)
def iter_curriculum(config: GRPOConfig, *, seed: int = 0) -> Iterator[tuple[str, int]]:
    """Yield an infinite stream of ``(scenario_id, episode_seed)`` tuples.

    Scenario selection is driven by ``config.curriculum.mode``:

    * ``"uniform"``  — each scenario is drawn uniformly at random.
    * ``"weighted"`` — drawn with ``curriculum.weights`` (one weight per
      scenario, in the same order as ``curriculum.scenarios``).
    * any other value — deterministic round-robin over ``scenarios``.

    Episode seeds are drawn uniformly from the half-open range
    ``curriculum.seed_range = (lo, hi)``, i.e. ``lo <= episode_seed < hi``.

    The stream depends only on ``seed``: a private :class:`random.Random`
    is used, so global RNG state is neither read nor modified.

    Args:
        config: Training config; only ``config.curriculum`` is read.
        seed: Seed for this iterator's private RNG.

    Yields:
        ``(scenario_id, episode_seed)`` pairs, forever.

    Raises:
        ValueError: on the first ``next()`` (this is a generator) if
            ``scenarios`` is empty, ``seed_range`` is empty, or weighted
            mode does not supply exactly one weight per scenario.
    """
    rng = Random(seed)
    curr = config.curriculum
    scenarios = curr.scenarios
    mode = curr.mode
    lo, hi = curr.seed_range

    if not scenarios:
        raise ValueError("curriculum.scenarios must contain at least one scenario")
    if hi <= lo:
        raise ValueError(
            f"curriculum.seed_range must satisfy lo < hi, got {curr.seed_range!r}"
        )

    weights: list[float] | None = None
    if mode == "weighted":
        # Materialized once, not per draw. A missing/short weights list would
        # otherwise surface as an opaque error from Random.choices.
        weights = list(curr.weights or ())
        if len(weights) != len(scenarios):
            raise ValueError(
                "curriculum.weights must provide one weight per scenario "
                f"(got {len(weights)} weights for {len(scenarios)} scenarios)"
            )

    n_scenarios = len(scenarios)
    i = 0
    while True:
        if mode == "uniform":
            scenario = rng.choice(scenarios)
        elif mode == "weighted":
            scenario = rng.choices(scenarios, weights=weights, k=1)[0]
        else:
            scenario = scenarios[i % n_scenarios]
        # randrange(lo, hi) is exactly randint(lo, hi - 1): same values and
        # same RNG consumption, but states the half-open range directly.
        yield scenario, rng.randrange(lo, hi)
        i += 1
def build_dataset(config: GRPOConfig, *, num_rows: int, seed: int = 0) -> Any:
    """Materialize ``num_rows`` curriculum slots as a GRPO prompt dataset.

    Draws ``(scenario_id, episode_seed)`` pairs from :func:`iter_curriculum`
    and turns each into one row. A row holds:

    * a chat-format ``prompt`` (a scenario-specific system prompt plus a
      fixed user instruction), and
    * the side columns ``scenario_id``, ``seed``, ``budget_steps`` and
      ``enable_dba_oracle``.

    TRL forwards the side columns as ``**kwargs`` to
    :meth:`SqlDriftToolEnv.reset`. That is how each rollout is
    reproducibly pinned to its curriculum slot.

    :mod:`datasets` is imported inside the function so the module's import
    surface stays stdlib-only for CPU-only CI.

    Args:
        config: Training config supplying the curriculum, step budget and
            DBA-oracle flag.
        num_rows: Number of rows to generate; must be at least 1.
        seed: Seed for the curriculum sampler.

    Returns:
        A :class:`datasets.Dataset` with the five columns above.

    Raises:
        ValueError: if ``num_rows < 1``.
    """
    if num_rows < 1:
        raise ValueError("build_dataset requires num_rows >= 1")
    from datasets import Dataset

    instruction = (
        "Diagnose, adapt to any drift, and submit a correct "
        "rewrite of the baseline query using the tools "
        "provided. Call submit_rewrite when confident."
    )
    dba_enabled = config.dba_oracle_enabled
    step_budget = config.episode_step_budget

    draws = iter_curriculum(config, seed=seed)
    slots = [next(draws) for _ in range(num_rows)]

    conversations = [
        [
            {
                "role": "system",
                "content": render_system_prompt(
                    scenario_id=scenario, dba_enabled=dba_enabled
                ),
            },
            # A fresh dict per row so rows never alias one another.
            {"role": "user", "content": instruction},
        ]
        for scenario, _ in slots
    ]

    return Dataset.from_dict(
        {
            "prompt": conversations,
            "scenario_id": [scenario for scenario, _ in slots],
            "seed": [episode_seed for _, episode_seed in slots],
            "budget_steps": [step_budget] * num_rows,
            "enable_dba_oracle": [dba_enabled] * num_rows,
        }
    )
def load_model_and_tokenizer(config: GRPOConfig) -> tuple[Any, Any]:
    """Load the base causal LM (optionally 4-bit nf4) and its tokenizer.

    The model is a plain :class:`transformers.AutoModelForCausalLM`. When
    ``config.load_in_4bit`` is set, it is quantized with a
    :class:`BitsAndBytesConfig`: nf4, double quantization, and compute
    dtype fp16 if ``config.fp16`` else bf16. This follows the setup of
    TRL's reference notebooks (``grpo_trl_lora_qlora.ipynb`` and the
    OpenEnv examples).

    No LoRA adapters are attached here. ``GRPOTrainer`` applies PEFT itself
    when given ``peft_config`` (see :func:`build_peft_config`). Returning a
    bare ``(model, tokenizer)`` pair therefore also suits callers that set
    up LoRA on their own.

    Args:
        config: Supplies ``model_name``, ``load_in_4bit`` and ``fp16``.

    Returns:
        ``(model, tokenizer)``.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    load_kwargs: dict[str, Any] = {
        "attn_implementation": "sdpa",
        "dtype": "auto",
        "quantization_config": None,
        "device_map": "auto",
    }
    if config.load_in_4bit:
        from transformers import BitsAndBytesConfig

        compute_dtype = torch.float16 if config.fp16 else torch.bfloat16
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=compute_dtype,
        )
        # With quantization on, the non-quantized weights are loaded in
        # float32, matching the TRL QLoRA reference setup (presumably for
        # numerical stability — confirm before changing).
        load_kwargs["dtype"] = "float32"

    model = AutoModelForCausalLM.from_pretrained(config.model_name, **load_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    return model, tokenizer
def build_peft_config(config: GRPOConfig) -> Any:
    """Build the :class:`peft.LoraConfig` matching ``GRPOConfig`` knobs.

    Returned separately so callers can pass it to
    ``GRPOTrainer(peft_config=...)``, which is the TRL-recommended path for
    LoRA/QLoRA training.

    ``config.lora_target_modules`` may be:

    * a sequence of module names (converted to a list), or
    * a single string, e.g. PEFT's ``"all-linear"`` shortcut or a regex,
      passed through unchanged, or
    * ``None``, letting PEFT pick modules from the model architecture.

    Args:
        config: Supplies ``lora_r``, ``lora_alpha``, ``lora_dropout`` and
            ``lora_target_modules``.

    Returns:
        A :class:`peft.LoraConfig` for a causal-LM task with no bias
        training.
    """
    from peft import LoraConfig

    target_modules = config.lora_target_modules
    # list() on a str would explode e.g. "all-linear" into single characters,
    # so only sequence-like values are materialized as a list.
    if target_modules is not None and not isinstance(target_modules, str):
        target_modules = list(target_modules)

    return LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )
def build_env_client(config: GRPOConfig) -> Any:
    """Return a one-shot OpenEnv client for the SQLDrift server.

    The client targets ``config.env_base_url`` and is returned via
    ``SqlDriftEnv(...).sync()``, presumably a synchronous wrapper; confirm
    in ``client``.

    Intended for the eval harness and notebooks. The GRPO trainer does not
    call this: it creates its own clients through
    ``environment_factory=SqlDriftToolEnv``.

    Args:
        config: Supplies ``env_base_url``.

    Returns:
        The object returned by ``SqlDriftEnv(...).sync()``.
    """
    from client import SqlDriftEnv

    env = SqlDriftEnv(base_url=config.env_base_url)
    return env.sync()
def reward_from_environments(
    environments: Sequence[SqlDriftToolEnv],
    **_: Any,
) -> list[float]:
    """Score each rolled-out environment by its cumulative episode return.

    This is a TRL reward function. Any extra keyword arguments TRL passes
    (prompts, completions, ...) are accepted and ignored.

    The running :attr:`SqlDriftToolEnv.episode_return` is used rather than
    the last step's reward. SQLDrift's shaping terms all accrue before the
    final submit step: step tax, tool-error penalties, repeat-failing-query
    penalties and DBA consult penalties. Reading only the final reward would
    throw most of that signal away.

    Args:
        environments: One environment per completion, in completion order.

    Returns:
        One float reward per environment, in the same order.
    """
    rewards: list[float] = []
    for rolled_out in environments:
        rewards.append(float(rolled_out.episode_return))
    return rewards
def _build_flush_log_history_callback(out_path: Path) -> Any:
    """Construct a TrainerCallback that appends each log dict to JSONL.

    ``trainer.state.log_history`` lives only in memory, so a crash mid-run
    loses the curves that ``utilities/plot_curves.py`` needs. This callback
    appends one JSON line per ``on_log`` event. A crash at step N therefore
    still leaves the records logged up to that point on disk for a
    post-mortem plot.

    The callback class is defined inside this factory because
    ``transformers.TrainerCallback`` is a [train]-extra import.

    Only the main process (``state.is_world_process_zero``) writes. In
    distributed runs ``on_log`` fires on every rank, and unguarded writes
    would duplicate and interleave records in the shared file.

    Args:
        out_path: JSONL file to append to. Its parent directory is created
            if missing, and existing content is kept.

    Returns:
        A ready-to-register ``TrainerCallback`` instance.
    """
    from transformers import TrainerCallback

    class _FlushLogHistory(TrainerCallback):
        """Append every non-empty ``on_log`` payload to a JSONL file."""

        def __init__(self, target: Path) -> None:
            self._target = target
            self._target.parent.mkdir(parents=True, exist_ok=True)

        def on_log(
            self,
            args: Any,
            state: Any,
            control: Any,
            logs: dict[str, Any] | None = None,
            **kwargs: Any,
        ) -> None:
            if not logs:
                return
            # Same rank-zero gate the built-in reporting callbacks use;
            # default to True if the attribute is absent.
            if not getattr(state, "is_world_process_zero", True):
                return
            record = {"step": state.global_step, **logs}
            # default=str keeps non-JSON values (e.g. tensors) from aborting
            # the write; open/close per event so each line is flushed.
            with self._target.open("a", encoding="utf-8") as f:
                f.write(json.dumps(record, default=str) + "\n")

    return _FlushLogHistory(out_path)
def train(config: GRPOConfig) -> Any:
    """Run the full GRPO training loop against SQLDrift.

    All heavy imports (``trl``, ``datasets``, ``transformers``, ``peft``,
    ``bitsandbytes``) are deferred into this function body. A CPU-only
    checkout that imports :mod:`training.grpo_train` only for, say,
    :func:`iter_curriculum` never triggers them.

    Stack: :class:`transformers.AutoModelForCausalLM` (+ optional
    ``BitsAndBytesConfig`` 4-bit nf4) → :class:`trl.GRPOTrainer` with
    ``environment_factory=SqlDriftToolEnv`` and ``peft_config`` for LoRA.
    This mirrors Hugging Face TRL's official reference notebooks
    (``grpo_trl_lora_qlora.ipynb``, ``openenv_wordle_grpo.ipynb``).

    Args:
        config: Project training config. ``config.seed`` seeds the global
            RNGs, the curriculum dataset, and the TRL trainer itself.

    Returns:
        The :class:`trl.GRPOTrainer`, after ``train()`` and ``save_model()``.

    Side effects:
        Writes ``grpo_config.json``, ``log_history.jsonl``, TensorBoard logs
        under ``tb/``, periodic checkpoints (every ``save_steps``) and the
        final saved model into ``config.output_dir``.
    """
    set_seed(config.seed)
    out = Path(config.output_dir)
    out.mkdir(parents=True, exist_ok=True)
    (out / "grpo_config.json").write_text(
        json.dumps(asdict(config), default=str, indent=2),
    )
    from trl import GRPOConfig as TRLGRPOConfig
    from trl import GRPOTrainer

    dataset = build_dataset(
        config,
        num_rows=max(config.max_steps * config.group_size, config.group_size),
        seed=config.seed,
    )
    # TRL >=0.25 removed `max_prompt_length` (prompt length is now
    # dataset-driven). `max_completion_length` is the TOTAL token budget
    # across the entire multi-turn conversation, not a per-turn cap —
    # see the OpenEnv integration doc:
    # https://huggingface.co/docs/trl/openenv#max_completion_length-in-multi-turn-episodes
    # Do NOT re-add `max_prompt_length` here; the kwarg no longer exists
    # and its presence raises TypeError on construction.
    #
    # No chat_template_kwargs are passed (e.g. to toggle a thinking mode).
    # The default instruct checkpoint is not expected to emit a <think>
    # block; revisit this if switching to a thinking-capable model.
    #
    # per_device_train_batch_size=1 with num_generations=group_size: TRL
    # requires the generation batch (per-device batch ×
    # gradient_accumulation_steps × world size) to be divisible by
    # num_generations. For example, 1 × 2 × 1 = 2 works for group_size=2.
    trl_args = TRLGRPOConfig(
        output_dir=str(out),
        # Pin TRL/transformers to the project seed. Trainer re-runs
        # set_seed(args.seed) in its constructor, and GRPO's train sampler
        # is seeded from args.seed. Left unset, both fall back to the
        # TrainingArguments default (42) regardless of config.seed.
        seed=config.seed,
        learning_rate=config.learning_rate,
        max_steps=config.max_steps,
        per_device_train_batch_size=1,
        num_generations=config.group_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        warmup_steps=config.warmup_steps,
        max_completion_length=config.max_completion_length,
        temperature=config.temperature,
        top_p=config.top_p,
        fp16=config.fp16,
        bf16=config.bf16,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        log_completions=True,
        report_to=["tensorboard"],
        logging_dir=str(out / "tb"),
    )
    model, tokenizer = load_model_and_tokenizer(config)
    peft_config = build_peft_config(config)
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=reward_from_environments,
        args=trl_args,
        # A fresh SqlDriftToolEnv per rollout slot, bound to the caller's server.
        environment_factory=partial(SqlDriftToolEnv, env_url=config.env_base_url),
        peft_config=peft_config,
    )
    trainer.add_callback(_build_flush_log_history_callback(out / "log_history.jsonl"))
    _LOG.info(
        "Starting GRPO: scenarios=%d, max_steps=%d, group_size=%d, env_url=%s",
        len(config.curriculum.scenarios),
        config.max_steps,
        config.group_size,
        config.env_base_url,
    )
    trainer.train()
    trainer.save_model(str(out))
    return trainer
def _parse_args(argv: list[str] | None = None) -> GRPOConfig:
    """Build a :class:`GRPOConfig` from command-line arguments.

    Precedence, lowest to highest:

    1. ``GRPOConfig`` defaults.
    2. Keys from the optional ``--config`` JSON manifest.
    3. The explicit ``--max-steps`` / ``--output-dir`` / ``--env-base-url``
       flags.

    A ``curriculum`` object in the manifest is converted into a
    :class:`training.config.CurriculumConfig`.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The assembled :class:`GRPOConfig`.

    Raises:
        SystemExit: via :meth:`argparse.ArgumentParser.error` when the
            manifest cannot be read, is not valid UTF-8 JSON, or is not a
            JSON object.
    """
    import argparse

    ap = argparse.ArgumentParser(description="GRPO training for SQLDrift.")
    ap.add_argument(
        "--config",
        type=Path,
        default=None,
        help="Optional JSON manifest overriding GRPOConfig defaults.",
    )
    ap.add_argument("--max-steps", type=int, default=None)
    ap.add_argument("--output-dir", type=str, default=None)
    ap.add_argument("--env-base-url", type=str, default=None)
    ns = ap.parse_args(argv)

    overrides: dict[str, Any] = {}
    if ns.config is not None:
        try:
            # JSON is UTF-8 by spec; don't depend on the platform locale.
            manifest = json.loads(ns.config.read_text(encoding="utf-8"))
        except (OSError, ValueError) as err:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            ap.error(f"--config: cannot load {ns.config}: {err}")
        if not isinstance(manifest, dict):
            ap.error(
                f"--config: {ns.config} must contain a JSON object, "
                f"got {type(manifest).__name__}"
            )
        overrides.update(manifest)
    if ns.max_steps is not None:
        overrides["max_steps"] = ns.max_steps
    if ns.output_dir is not None:
        overrides["output_dir"] = ns.output_dir
    if ns.env_base_url is not None:
        overrides["env_base_url"] = ns.env_base_url

    from training.config import CurriculumConfig

    # JSON yields a plain dict for nested objects; GRPOConfig expects the
    # dataclass.
    if "curriculum" in overrides and isinstance(overrides["curriculum"], dict):
        overrides["curriculum"] = CurriculumConfig(**overrides["curriculum"])
    return GRPOConfig(**overrides)
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: parse ``argv`` into a config and run :func:`train`.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    train(_parse_args(argv))
# Public API of this module.
__all__ = [
    "build_dataset",
    "build_env_client",
    "build_peft_config",
    "iter_curriculum",
    "load_model_and_tokenizer",
    "reward_from_environments",
    "train",
]

# Keep the script guard last so every module-level name (including
# ``__all__``) is bound before ``main()`` runs.
if __name__ == "__main__":
    main()