# sql-drift-env / training/config.py
# Uploaded by visheshrathi via huggingface_hub (commit bbf206f, verified).
"""Training configuration dataclasses.
Holds every knob the :mod:`training.grpo_train` script or the eval CLI
needs, as plain, frozen dataclasses so they serialize cleanly to JSON
for experiment manifests.
Deliberately lightweight: do not import ``trl`` / ``unsloth`` /
``transformers`` at module import time. Those libraries are CUDA-heavy
and optional. ``grpo_train.py`` resolves them lazily.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal
from utilities.env_loader import env_str
def _load_all_scenarios() -> tuple[str, ...]:
    """Return the ``scenario_id`` of every spec yielded by the registry.

    Ids are returned in the order ``scenarios.iter_specs()`` yields them.
    The registry is imported inside the function body rather than at
    module level.
    """
    from scenarios import iter_specs

    scenario_ids = [spec.scenario_id for spec in iter_specs()]
    return tuple(scenario_ids)


# Snapshot of every registered scenario id, taken once when this module is
# imported, so training defaults stay in sync with the scenarios actually
# shipped under ``scenarios/``.
ALL_SCENARIOS: tuple[str, ...] = _load_all_scenarios()
@dataclass(frozen=True)
class CurriculumConfig:
    """Scenario sampling policy for GRPO rollouts.

    ``mode="uniform"`` samples each id in :attr:`scenarios` with equal
    probability. ``mode="weighted"`` uses :attr:`weights` (must be the
    same length as :attr:`scenarios`) — useful for over-sampling drift
    scenarios early in training. ``mode="static_order"`` iterates the
    list round-robin (handy for reproducing eval-style runs).

    Attributes:
        scenarios: Scenario ids to sample from; must be non-empty.
            Defaults to every id in ``ALL_SCENARIOS``.
        mode: Sampling policy; one of ``"uniform"``, ``"weighted"`` or
            ``"static_order"``. Checked at runtime, because ``Literal``
            annotations are not enforced by Python.
        weights: Per-scenario weights. Required, and only validated, when
            ``mode == "weighted"``. They must be finite and non-negative,
            with a positive sum.
        seed_range: ``(lo, hi)`` seed bounds; requires ``lo >= 0`` and
            ``hi > lo``. Whether ``hi`` is inclusive is up to the
            sampler, not this config.

    Raises:
        ValueError: From ``__post_init__`` when any constraint above is
            violated.
    """

    scenarios: tuple[str, ...] = ALL_SCENARIOS
    mode: Literal["uniform", "weighted", "static_order"] = "uniform"
    weights: tuple[float, ...] | None = None
    seed_range: tuple[int, int] = (0, 2**31 - 1)

    def __post_init__(self) -> None:
        """Validate the sampling policy; raise ``ValueError`` on the first violation."""
        import math

        if not self.scenarios:
            raise ValueError("CurriculumConfig.scenarios must be non-empty")
        # ``Literal`` is only a static hint, so a typo such as "weigthed"
        # would otherwise be accepted silently.
        valid_modes = ("uniform", "weighted", "static_order")
        if self.mode not in valid_modes:
            raise ValueError(
                f"CurriculumConfig.mode must be one of {valid_modes}, got {self.mode!r}"
            )
        if self.mode == "weighted":
            if self.weights is None or len(self.weights) != len(self.scenarios):
                raise ValueError("mode='weighted' requires weights of the same length as scenarios")
            # NaN compares False against everything, so it would slip past
            # both the ``< 0`` and ``sum <= 0`` checks below; reject it
            # (and +/-inf) explicitly.
            if not all(math.isfinite(w) for w in self.weights):
                raise ValueError("weights must all be finite")
            if any(w < 0 for w in self.weights):
                raise ValueError("weights must all be >= 0")
            if sum(self.weights) <= 0:
                raise ValueError("at least one weight must be > 0")
        lo, hi = self.seed_range
        if lo < 0 or hi <= lo:
            raise ValueError("seed_range must be (lo >= 0, hi > lo)")
@dataclass(frozen=True)
class GRPOConfig:
    """Top-level training config for the GRPO skeleton.

    Defaults pinned to ``Qwen/Qwen3-4B-Instruct-2507`` (Apache-2.0, July 2025
    release, BFCL-v3 = 61.9 native tool-calling) loaded via
    ``transformers.AutoModelForCausalLM`` + ``BitsAndBytesConfig`` 4-bit nf4
    QLoRA + ``peft.LoraConfig`` — the stack used by Hugging Face TRL's own
    reference notebooks (``grpo_trl_lora_qlora.ipynb``, the OpenEnv
    Wordle/Echo examples). Every knob is override-able from the CLI or a
    JSON manifest.

    Raises:
        ValueError: From ``__post_init__`` when a knob is outside its
            valid range (e.g. ``group_size < 2``, ``fp16`` and ``bf16``
            both set, ``top_p`` outside ``(0, 1]``).
    """

    model_name: str = "Qwen/Qwen3-4B-Instruct-2507"
    max_seq_length: int = 4096
    # 4-bit nf4 quantization (QLoRA) — fits a 4B model on a free Colab T4
    # in ~6 GB peak. Flip to False on a >=24 GB GPU for plain LoRA.
    load_in_4bit: bool = True
    # LoRA r=16 / alpha=32 mirrors the TRL grpo_trl_lora_qlora reference.
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.0
    lora_target_modules: tuple[str, ...] = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    )
    # GRPO knobs — sized for a free Colab T4 (16 GB) running multi-turn
    # tool-using rollouts: per_device_train_batch_size=1 in TRL,
    # num_generations=2, gradient_accumulation_steps=2.
    group_size: int = 2
    learning_rate: float = 5e-6
    max_steps: int = 500
    gradient_accumulation_steps: int = 2
    warmup_steps: int = 10
    # NOTE: TRL >=0.25 removed `max_prompt_length` from GRPOConfig — prompt
    # length is dataset-driven, not trainer-configurable. And
    # `max_completion_length` is now the TOTAL token budget across the
    # entire multi-turn conversation (all assistant generations + tool
    # results combined), not a per-turn cap. Size it for the worst-case
    # episode (25-step budget * ~80 tokens per turn ≈ 2k).
    max_completion_length: int = 2048
    # Qwen3-Instruct-2507 recommends temperature=0.7 / top_p=0.8 for
    # non-thinking instruct mode. The 2507 line is non-thinking by default,
    # so no `enable_thinking` toggle is needed.
    temperature: float = 0.7
    top_p: float = 0.8
    seed: int = 0
    # TRL defaults bf16 to True when fp16 is not explicitly set, but Colab
    # T4 supports fp16 only (no bf16 hardware). Keep fp16=True for T4.
    fp16: bool = True
    bf16: bool = False
    # Env wiring.
    # NOTE: this default is evaluated once, when this module is imported.
    # Changing SQL_DRIFT_ENV_URL afterwards does not change the default
    # for new instances; pass ``env_base_url=...`` explicitly instead.
    env_base_url: str = env_str("SQL_DRIFT_ENV_URL", "http://localhost:8000")
    episode_step_budget: int = 25
    dba_oracle_enabled: bool = False
    # IO
    output_dir: str = "outputs/grpo_run"
    logging_steps: int = 1
    save_steps: int = 100
    curriculum: CurriculumConfig = field(default_factory=CurriculumConfig)

    def __post_init__(self) -> None:
        """Validate knob ranges; raise ``ValueError`` on the first violation.

        The original checks run first, in their original order and with
        their original messages. Range checks for the remaining numeric
        knobs follow, so nonsensical values fail here instead of deep
        inside the trainer.
        """
        if self.group_size < 2:
            raise ValueError("GRPO group_size must be >= 2 (GRPO requires groups).")
        if self.max_steps < 1:
            raise ValueError("max_steps must be >= 1")
        if self.seed < 0:
            raise ValueError("seed must be >= 0")
        if self.lora_r < 1:
            raise ValueError("lora_r must be >= 1")
        if self.fp16 and self.bf16:
            raise ValueError("fp16 and bf16 are mutually exclusive")
        if not 0.0 < self.temperature <= 2.0:
            raise ValueError("temperature must be in (0, 2]")
        # Additional range checks for knobs that previously went unvalidated.
        if self.max_seq_length < 1:
            raise ValueError("max_seq_length must be >= 1")
        if self.lora_alpha <= 0:
            raise ValueError("lora_alpha must be > 0")
        if not 0.0 <= self.lora_dropout < 1.0:
            raise ValueError("lora_dropout must be in [0, 1)")
        if not self.lora_target_modules:
            raise ValueError("lora_target_modules must be non-empty")
        if self.learning_rate <= 0:
            raise ValueError("learning_rate must be > 0")
        if self.gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")
        if self.warmup_steps < 0:
            raise ValueError("warmup_steps must be >= 0")
        if self.max_completion_length < 1:
            raise ValueError("max_completion_length must be >= 1")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError("top_p must be in (0, 1]")
        if self.episode_step_budget < 1:
            raise ValueError("episode_step_budget must be >= 1")
        if self.logging_steps < 1:
            raise ValueError("logging_steps must be >= 1")
        if self.save_steps < 1:
            raise ValueError("save_steps must be >= 1")
__all__ = ["ALL_SCENARIOS", "CurriculumConfig", "GRPOConfig"]