"""Configuration objects for GRPO training."""
from __future__ import annotations
import logging
import math
import os
from dataclasses import dataclass, field
from pathlib import Path
# Module-level logger used by the device helpers below.
_logger: logging.Logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Device policies
# ---------------------------------------------------------------------------
# Accepted values for the ``device`` setting:
#   "auto" - use GPU/MPS when available, otherwise fall back to CPU
#   "cpu"  - force CPU (use on a Mac, where MPS runs out of memory during GRPO)
#   "cuda" - force CUDA (use on Colab / a cloud GPU)
#   "mps"  - force MPS (only if the model fits; unlikely for GRPO)
DEVICE_AUTO = "auto"  # let PyTorch pick the backend
DEVICE_CPU = "cpu"  # hide CUDA and MPS
DEVICE_CUDA = "cuda"  # hide MPS only
DEVICE_MPS = "mps"  # no overrides applied
def find_project_root(
    start: str | os.PathLike[str] | None = None,
    *,
    marker: str = "pyproject.toml",
) -> Path:
    """Return the nearest directory at or above *start* that contains *marker*.

    Parameters
    ----------
    start
        Directory to begin the search from. Relative paths are resolved
        against the current working directory. Defaults to ``Path.cwd()``.
    marker
        File name that identifies the project root.
        Defaults to ``"pyproject.toml"``.

    Returns
    -------
    Path
        The first directory, checking *start* itself and then each parent
        up to the filesystem root, that contains *marker*.

    Raises
    ------
    FileNotFoundError
        If no directory from *start* up to the filesystem root contains
        *marker*.
    """
    # resolve() makes the path absolute so that ``.parents`` walks all the
    # way to the filesystem root, even when *start* is relative.
    origin = Path.cwd() if start is None else Path(start).resolve()
    for candidate in (origin, *origin.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(f"Could not locate project root (no {marker} found)")
def _disable_mps() -> None:
    """Make ``torch.backends.mps.is_available()`` report ``False``.

    Best-effort: when PyTorch is not installed there is nothing to hide, so
    the ImportError is ignored.
    """
    try:
        import torch
    except ImportError:
        return
    # Patch the availability probe so that code checking for MPS before
    # picking a device treats it as absent.
    torch.backends.mps.is_available = lambda: False  # type: ignore[assignment]


def apply_device_overrides(device: str) -> None:
    """Set environment/backend flags so PyTorch and HuggingFace respect *device*.

    Call this before importing transformers or loading models.

    Why this exists: GRPO generates multiple completions per prompt, so peak
    memory is several times the model size. On Mac (MPS, typically 16 GB
    shared), even a 0.6B model OOMs. Forcing CPU avoids the crash at the
    cost of speed. On Colab/cloud, "auto" or "cuda" is the right choice.

    Parameters
    ----------
    device
        Device policy: "auto", "cpu", "cuda", or "mps".
        "auto" and "mps" apply no overrides. "cpu" hides both CUDA and MPS.
        "cuda" hides MPS only.

    Raises
    ------
    ValueError
        If *device* is not one of the recognised policies.
    """
    if device in (DEVICE_AUTO, DEVICE_MPS):
        # "auto": PyTorch picks the backend itself.
        # "mps": PyTorch uses MPS if available, so nothing needs hiding.
        return
    if device == DEVICE_CPU:
        # An empty CUDA_VISIBLE_DEVICES hides every GPU. The CUDA runtime
        # reads it at initialisation, hence "call before loading models".
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        _disable_mps()
        _logger.info("Device forced to CPU β MPS and CUDA disabled")
        return
    if device == DEVICE_CUDA:
        _disable_mps()
        _logger.info("Device forced to CUDA β MPS disabled")
        return
    # Previously an unknown value (e.g. a typo) was silently a no-op.
    valid = (DEVICE_AUTO, DEVICE_CPU, DEVICE_CUDA, DEVICE_MPS)
    raise ValueError(f"device must be one of {valid}, got '{device}'")
@dataclass
class GRPOConfig:
    """Configuration for GRPO training on SQLEnv.

    All numeric and enumerated fields are validated in ``__post_init__``.

    Parameters
    ----------
    questions_path
        Path to the training questions JSON file.
    db_dir
        Directory containing SQLite databases.
    output_dir
        Directory where checkpoints and outputs are written.
    model_name
        Model identifier to train (default ``"Qwen/Qwen3-0.6B"``).
    device
        Device policy: "auto", "cpu", "cuda", or "mps".
        Use "cpu" on Mac (MPS OOMs with GRPO).
        Use "auto" or "cuda" on Colab / cloud GPU.
    max_new_tokens
        Maximum number of tokens generated per completion. Must be > 0.
    num_train_epochs
        Number of training epochs. Must be > 0.
    per_device_train_batch_size
        Training batch size per device. Must be > 0.
    gradient_accumulation_steps
        Number of steps to accumulate gradients over. Must be > 0.
    learning_rate
        Optimizer learning rate. Must be finite and > 0.
    num_generations
        Number of completions sampled per prompt. Must be > 0.
    step_budget
        Step budget (presumably the maximum number of environment steps per
        episode; confirm against the environment). Must be >= 0.
    difficulty_filter
        Difficulty labels of the questions to train on. Must be a non-empty
        list, not a bare string.
    seed
        Random seed.
    logging_steps
        Logging interval, in training steps. Must be > 0.
    beta
        KL penalty coefficient against the reference model. Must be finite
        and >= 0.
    precision
        One of "auto", "fp16", "bf16", "fp32".
    gradient_checkpointing
        Enable gradient checkpointing (trades compute for memory).
    enable_thinking
        Enable Qwen3 thinking mode (``<think>`` blocks before tool calls).

    Raises
    ------
    ValueError
        If any field fails validation.
    """

    questions_path: str
    db_dir: str
    output_dir: str
    model_name: str = "Qwen/Qwen3-0.6B"
    device: str = DEVICE_AUTO
    max_new_tokens: int = 256
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    learning_rate: float = 5e-6
    num_generations: int = 4
    step_budget: int = 10
    difficulty_filter: list[str] = field(default_factory=lambda: ["easy", "medium"])
    seed: int = 42
    logging_steps: int = 10
    # KL penalty against reference model (prevents format drift during GRPO)
    beta: float = 0.04
    # Precision: "auto", "fp16", "bf16", "fp32"
    precision: str = "auto"
    # Enable gradient checkpointing to reduce VRAM (trades compute for memory)
    gradient_checkpointing: bool = False
    # Enable Qwen3 thinking mode (<think> blocks before tool calls).
    # When False (default), /no_think is prepended to the system prompt
    # and TRL's chat_template_kwargs disables thinking. When True,
    # the model can reason before acting - requires higher max_new_tokens.
    enable_thinking: bool = False

    def __post_init__(self) -> None:
        """Validate field values, raising ``ValueError`` on the first failure."""
        # A tuple, not a set, so the error message lists values in a
        # deterministic order.
        valid_devices = (DEVICE_AUTO, DEVICE_CPU, DEVICE_CUDA, DEVICE_MPS)
        if self.device not in valid_devices:
            msg = f"device must be one of {valid_devices}, got '{self.device}'"
            raise ValueError(msg)

        # Counts that must be strictly positive.
        for name in (
            "max_new_tokens",
            "num_train_epochs",
            "per_device_train_batch_size",
            "gradient_accumulation_steps",
            "num_generations",
            "logging_steps",
        ):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be > 0")

        if self.learning_rate <= 0:
            raise ValueError("learning_rate must be > 0")
        # NaN compares False against everything, so it slips past ``<= 0``.
        if not math.isfinite(self.learning_rate):
            raise ValueError("learning_rate must be finite")

        if self.step_budget < 0:
            raise ValueError("step_budget must be >= 0")

        # A negative KL coefficient would reward drifting from the reference.
        if self.beta < 0:
            raise ValueError("beta must be >= 0")
        if not math.isfinite(self.beta):
            raise ValueError("beta must be finite")

        valid_precisions = ("auto", "fp16", "bf16", "fp32")
        if self.precision not in valid_precisions:
            msg = (
                f"precision must be one of {valid_precisions}, "
                f"got '{self.precision}'"
            )
            raise ValueError(msg)

        if not self.difficulty_filter:
            raise ValueError("difficulty_filter must not be empty")
        # A bare string such as "easy" is non-empty but would later be
        # iterated character by character.
        if isinstance(self.difficulty_filter, str):
            raise ValueError(
                "difficulty_filter must be a list of strings, not a single string"
            )
|