deepbattler / RL /train_gsm8k_qwen_grpo.py
wyksdsg's picture
Upload folder using huggingface_hub
787c99c verified
#!/usr/bin/env python
# train_gsm8k_qwen_grpo.py
#
# End-to-end SFT + GRPO pipeline on GSM8K using a Qwen instruct checkpoint,
# compatible with transformers>=4.57 and trl>=0.25.x.
import argparse
import os
import re
from dataclasses import dataclass
from typing import List, Optional
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig, GRPOTrainer, GRPOConfig
# ================== Model paths & defaults ==================
LOCAL_INSTRUCT_PATH = "models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507"
# Hugging Face Hub repo used when neither the env override nor the local
# snapshot is available; it names the same 2507 checkpoint as the local path.
HUB_INSTRUCT_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"


def _resolve_default_model_id() -> str:
    """Pick the default checkpoint for ``--model``.

    Resolution order:
      1. ``$QWEN_INSTRUCT_MODEL`` when it is set to a non-empty value;
      2. ``LOCAL_INSTRUCT_PATH`` when that directory exists;
      3. the Hub repo ``HUB_INSTRUCT_MODEL_ID``.
    """
    env_override = os.environ.get("QWEN_INSTRUCT_MODEL")
    if env_override:
        return env_override
    if os.path.isdir(LOCAL_INSTRUCT_PATH):
        return LOCAL_INSTRUCT_PATH
    return HUB_INSTRUCT_MODEL_ID


DEFAULT_MODEL_ID = _resolve_default_model_id()
# Parent directory for every run artifact (sft/, sft_model/, grpo/, grpo_model/).
DEFAULT_OUTPUT_DIR = "./qwen_gsm8k_grpo"
# LoRA target layers used when --target-modules is not given: the attention
# projections (q/k/v/o) followed by the feed-forward projections (gate/up/down).
DEFAULT_TARGET_MODULES = "q_proj k_proj v_proj o_proj gate_proj up_proj down_proj".split()
# ================== Config dataclass ==================
@dataclass
class PipelineConfig:
    """Every setting for one SFT + GRPO run; ``parse_args`` builds one from the CLI.

    Fields are declared in constructor order, so positional construction
    depends on this ordering; keep it stable.
    """

    # --- model & output locations ---
    model_name_or_path: str = DEFAULT_MODEL_ID  # hub id or local checkpoint dir
    output_dir: str = DEFAULT_OUTPUT_DIR  # parent of the sft/ and grpo/ outputs

    # --- lengths & schedule ---
    max_seq_length: int = 512  # SFT max_length and GRPO max_prompt_length
    sft_epochs: int = 1
    grpo_epochs: int = 1

    # --- dataset subsetting (None => whole split) ---
    train_samples: Optional[int] = None
    eval_samples: Optional[int] = None

    # --- precision, batching & optimisation ---
    bf16: bool = True  # False => the fp16/fp32 path of --disable-bf16
    per_device_batch_size: int = 1
    grad_accum_steps: int = 8
    sft_learning_rate: float = 1e-5
    grpo_learning_rate: float = 5e-6

    # --- GRPO sampling ---
    max_completion_length: int = 64
    num_generations: int = 4
    steps_per_generation: int = 1

    # --- LoRA targets & phase switches ---
    target_modules: Optional[List[str]] = None  # None => DEFAULT_TARGET_MODULES
    skip_sft: bool = False
    skip_grpo: bool = False
def parse_args(argv: Optional[List[str]] = None) -> PipelineConfig:
    """Build a ``PipelineConfig`` from command-line flags.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]`` as before; passing a list lets
            tests and notebooks drive the pipeline without touching sys.argv.

    Returns:
        The populated pipeline configuration.
    """
    # Dataclass fields with plain defaults are also class attributes, so the
    # CLI reads its numeric defaults from PipelineConfig instead of repeating
    # them (the two can no longer drift apart).
    defaults = PipelineConfig
    parser = argparse.ArgumentParser(
        description="Run GSM8K fine-tuning (SFT + GRPO) with a Qwen instruct checkpoint."
    )
    parser.add_argument(
        "--model",
        default=DEFAULT_MODEL_ID,
        help=(
            "Model id or local path for the instruct-tuned Qwen checkpoint. "
            "Defaults to models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507 when present."
        ),
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help="Directory for checkpoints and logs.",
    )
    parser.add_argument("--max-seq-length", type=int, default=defaults.max_seq_length)
    parser.add_argument("--sft-epochs", type=int, default=defaults.sft_epochs)
    parser.add_argument("--grpo-epochs", type=int, default=defaults.grpo_epochs)
    parser.add_argument(
        "--train-samples",
        type=int,
        default=None,
        help="Optional number of GSM8K training samples (None => full set).",
    )
    parser.add_argument(
        "--eval-samples",
        type=int,
        default=None,
        help="Optional number of GSM8K eval samples.",
    )
    parser.add_argument(
        "--per-device-batch-size", type=int, default=defaults.per_device_batch_size
    )
    parser.add_argument("--grad-accum-steps", type=int, default=defaults.grad_accum_steps)
    parser.add_argument(
        "--sft-learning-rate", type=float, default=defaults.sft_learning_rate
    )
    parser.add_argument(
        "--grpo-learning-rate", type=float, default=defaults.grpo_learning_rate
    )
    parser.add_argument(
        "--max-completion-length", type=int, default=defaults.max_completion_length
    )
    parser.add_argument("--num-generations", type=int, default=defaults.num_generations)
    parser.add_argument(
        "--steps-per-generation", type=int, default=defaults.steps_per_generation
    )
    parser.add_argument(
        "--target-modules",
        default=None,
        help="Comma-separated list of module names for LoRA (defaults to Qwen attn/FFN blocks).",
    )
    parser.add_argument(
        "--disable-bf16",
        action="store_true",
        help="Force fp16/fp32 training if bf16 is not desired or unsupported.",
    )
    parser.add_argument("--skip-sft", action="store_true", help="Skip the SFT phase.")
    parser.add_argument("--skip-grpo", action="store_true", help="Skip the GRPO phase.")
    args = parser.parse_args(argv)

    # "q_proj, k_proj,,v_proj" -> ["q_proj", "k_proj", "v_proj"]. An absent flag
    # or one containing only separators yields None (=> DEFAULT_TARGET_MODULES).
    target_modules = None
    if args.target_modules:
        target_modules = [
            m.strip() for m in args.target_modules.split(",") if m.strip()
        ] or None

    return PipelineConfig(
        model_name_or_path=args.model,
        output_dir=args.output_dir,
        max_seq_length=args.max_seq_length,
        sft_epochs=args.sft_epochs,
        grpo_epochs=args.grpo_epochs,
        train_samples=args.train_samples,
        eval_samples=args.eval_samples,
        bf16=not args.disable_bf16,
        per_device_batch_size=args.per_device_batch_size,
        grad_accum_steps=args.grad_accum_steps,
        sft_learning_rate=args.sft_learning_rate,
        grpo_learning_rate=args.grpo_learning_rate,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        steps_per_generation=args.steps_per_generation,
        target_modules=target_modules,
        skip_sft=args.skip_sft,
        skip_grpo=args.skip_grpo,
    )
# ================== Data: GSM8K formatting ==================
def format_gsm8k_example(ex):
    """Turn one raw GSM8K row into ``prompt`` / ``completion`` / ``final_answer``.

    Args:
        ex: A GSM8K record with ``question`` and ``answer`` keys; the answer
            ends with a ``#### <number>`` line.

    Returns:
        A dict with the model input (``prompt``), the gold short answer used
        as the SFT target (``completion``), and the same string used by the
        GRPO reward (``final_answer``).
    """
    question = ex["question"]
    # GSM8K answers look like "... #### 42". Keep only the number and drop
    # thousands separators so the stored answer is a clean integer string.
    final_ans = ex["answer"].split("####")[-1].strip().replace(",", "")
    prompt = (
        "You are a helpful math solver.\n\n"
        f"Question:\n{question}\n\n"
        "Answer with a single integer.\n"
    )
    return {
        "prompt": prompt,
        "completion": final_ans,  # gold short answer
        "final_answer": final_ans,
    }


def load_gsm8k(train_limit: Optional[int] = None, eval_limit: Optional[int] = None):
    """
    Load GSM8K (``openai/gsm8k``, ``main`` config) and return (train, test) splits
    whose rows have:
      - prompt (input to the model)
      - completion (gold text, used for SFT)
      - final_answer (clean integer answer, used for reward)

    Args:
        train_limit: Keep at most this many training rows, taken after a
            seed-42 shuffle. ``None`` keeps the whole split.
        eval_limit: Keep at most this many test rows, in original order.
            ``None`` keeps the whole split.
    """
    raw = load_dataset("openai/gsm8k", "main")
    train_ds = raw["train"]
    test_ds = raw["test"]
    train_ds = train_ds.map(
        format_gsm8k_example, remove_columns=train_ds.column_names
    ).shuffle(seed=42)
    test_ds = test_ds.map(format_gsm8k_example, remove_columns=test_ds.column_names)
    if train_limit is not None:
        train_ds = train_ds.select(range(min(train_limit, len(train_ds))))
    if eval_limit is not None:
        test_ds = test_ds.select(range(min(eval_limit, len(test_ds))))
    return train_ds, test_ds
# ================== Reward function for GRPO ==================
# An optionally negative integer, optionally written with comma thousands
# separators ("1,234"), so "1,000" is captured whole instead of as "1" and "000".
# The trailing lookahead rejects partial groups such as the "3,141" prefix of
# "3,14159".
INT_REGEX = re.compile(r"-?\d+(?:,\d{3})*(?!\d)")


def extract_last_int(text: str) -> Optional[str]:
    """Return the last integer in *text* with thousands separators removed.

    Returns:
        The integer literal as a string (``"1,000"`` -> ``"1000"``), or
        ``None`` when *text* contains no digits.
    """
    matches = INT_REGEX.findall(text)
    if not matches:
        return None
    return matches[-1].replace(",", "")


def _answers_match(pred: str, gold: str) -> bool:
    """Compare answers as integers when both parse (so "042" == "42"), else as strings."""
    try:
        return int(pred) == int(gold)
    except ValueError:
        return pred == gold


def correctness_reward(completions: List[str], **kwargs) -> List[float]:
    """
    Custom reward function for GRPOTrainer.
    TRL 0.25.x will call this with:
    - completions: list[str]
    - prompts: list[str] (via kwargs["prompts"])
    - plus all dataset columns (except 'prompt') as kwargs
      e.g. kwargs["final_answer"] is our ground truth list[str].

    Each completion scores 1.0 when the last integer it contains equals its
    gold answer (commas ignored, compared numerically), otherwise 0.0.

    Raises:
        ValueError: if ``final_answer`` is missing or does not have one entry
            per completion. Substituting some other reward here would silently
            optimise the wrong objective, so misconfiguration fails loudly.
    """
    final_answer = kwargs.get("final_answer")
    if final_answer is None:
        raise ValueError(
            "correctness_reward requires a 'final_answer' column in the GRPO dataset."
        )
    if len(final_answer) != len(completions):
        raise ValueError(
            f"Got {len(completions)} completions but {len(final_answer)} final answers."
        )
    rewards: List[float] = []
    for comp, ref in zip(completions, final_answer):
        pred = extract_last_int(comp)
        gold = str(ref).strip().replace(",", "")
        rewards.append(1.0 if pred is not None and _answers_match(pred, gold) else 0.0)
    return rewards
# ================== SFT phase ==================
def run_sft(train_ds, eval_ds, tokenizer, cfg: PipelineConfig):
    """Run a short supervised fine-tuning pass with LoRA adapters (prompt-completion).

    Args:
        train_ds: Training split with ``prompt`` and ``completion`` columns
            (as built by ``load_gsm8k``).
        eval_ds: Evaluation split in the same format; evaluated every 200 steps.
        tokenizer: Tokenizer handed to SFTTrainer as ``processing_class``.
        cfg: Pipeline configuration (model path, batch sizes, precision, ...).

    Returns:
        ``trainer.model`` after training (the PEFT-wrapped model). The trained
        model is also saved to ``<output_dir>/sft_model``.
    """
    target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=target_modules,
        task_type="CAUSAL_LM",
    )
    # --disable-bf16 promises fp16/fp32 training: request fp16 mixed precision
    # only when CUDA is available and stay in fp32 otherwise, matching the
    # dtype main() chooses when it loads the model itself.
    use_fp16 = not cfg.bf16 and torch.cuda.is_available()
    sft_config = SFTConfig(
        output_dir=os.path.join(cfg.output_dir, "sft"),
        per_device_train_batch_size=cfg.per_device_batch_size,
        per_device_eval_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.grad_accum_steps,
        learning_rate=cfg.sft_learning_rate,
        num_train_epochs=cfg.sft_epochs,
        logging_steps=10,
        save_steps=200,
        eval_steps=200,
        eval_strategy="steps",  # transformers>=4.57 uses 'eval_strategy'
        save_total_limit=2,
        max_length=cfg.max_seq_length,
        bf16=cfg.bf16,
        fp16=use_fp16,
        report_to=["none"],
    )
    # SFTTrainer will look for 'prompt' & 'completion' columns in the dataset.
    trainer = SFTTrainer(
        model=cfg.model_name_or_path,  # string path -> SFTTrainer loads the model
        args=sft_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,  # new TRL API
        peft_config=peft_config,
    )
    trainer.train()
    trainer.save_model(os.path.join(cfg.output_dir, "sft_model"))
    return trainer.model  # PEFT-wrapped model instance
# ================== GRPO phase ==================
def build_rl_dataset(train_ds):
    """Return the dataset handed to GRPOTrainer.

    The formatted split already has the ``prompt`` column GRPO needs and the
    ``final_answer`` column that ``correctness_reward`` reads, so no
    transformation is applied. The SFT ``completion`` column is carried
    along unchanged as well.
    """
    rl_dataset = train_ds
    return rl_dataset
def run_grpo(rl_dataset, base_model, tokenizer, cfg: PipelineConfig):
    """Run a short GRPO training loop on top of the (optionally) SFT-initialized model.

    Args:
        rl_dataset: Dataset with a ``prompt`` column and a ``final_answer``
            column that ``correctness_reward`` consumes.
        base_model: Model instance (e.g. the PEFT model returned by
            ``run_sft``) or a model id/path.
        tokenizer: Tokenizer used as ``processing_class``; main() configures
            it with left padding.
        cfg: Pipeline configuration.

    Returns:
        ``trainer.model`` after training; the model is also saved to
        ``<output_dir>/grpo_model``.
    """
    target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        target_modules=target_modules,
        task_type="CAUSAL_LM",
    )
    # Wire --steps-per-generation into TRL (previously the flag was ignored).
    # TRL derives generation_batch_size = per_device_batch_size * num_processes
    # * steps_per_generation. Scaling by num_generations keeps that batch
    # divisible by num_generations. Each generation round therefore covers
    # per_device_batch_size * num_processes * cfg.steps_per_generation prompts,
    # with num_generations completions each. With the default of 1 on a single
    # process this reproduces the former fixed generation batch of
    # per_device_batch_size * num_generations.
    trl_steps_per_generation = cfg.num_generations * cfg.steps_per_generation
    # fp16 mixed precision only on CUDA; fp32 otherwise (see --disable-bf16).
    use_fp16 = not cfg.bf16 and torch.cuda.is_available()
    grpo_config = GRPOConfig(
        output_dir=os.path.join(cfg.output_dir, "grpo"),
        num_train_epochs=cfg.grpo_epochs,
        per_device_train_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.grad_accum_steps,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        bf16=cfg.bf16,
        fp16=use_fp16,
        learning_rate=cfg.grpo_learning_rate,
        max_prompt_length=cfg.max_seq_length,
        max_completion_length=cfg.max_completion_length,
        num_generations=cfg.num_generations,
        steps_per_generation=trl_steps_per_generation,
        report_to=["none"],
    )
    trainer = GRPOTrainer(
        model=base_model,  # can be model instance or model id
        args=grpo_config,
        processing_class=tokenizer,
        reward_funcs=correctness_reward,  # single custom reward
        train_dataset=rl_dataset,
        peft_config=peft_config,
    )
    trainer.train()
    trainer.save_model(os.path.join(cfg.output_dir, "grpo_model"))
    return trainer.model
# ================== Main ==================
def _load_base_model(cfg: PipelineConfig):
    """Load the untouched checkpoint when the SFT phase is skipped.

    dtype: bf16 when requested and CUDA is available, fp16 on CUDA otherwise,
    fp32 without CUDA. On CUDA the weights are placed with device_map="auto".
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda and cfg.bf16:
        dtype = torch.bfloat16
    elif has_cuda:
        dtype = torch.float16
    else:
        dtype = torch.float32
    model_kwargs = {
        # Recent transformers (this script targets >=4.57) take ``dtype``;
        # the older ``torch_dtype`` spelling is deprecated.
        "dtype": dtype,
        "trust_remote_code": True,
    }
    if has_cuda:
        model_kwargs["device_map"] = "auto"
    return AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, **model_kwargs)


def main():
    """Parse flags, prepare tokenizer and GSM8K data, then run SFT and/or GRPO."""
    cfg = parse_args()
    os.makedirs(cfg.output_dir, exist_ok=True)
    print(f"Using model: {cfg.model_name_or_path}")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_name_or_path,
        use_fast=True,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # GRPO requires left padding (prompts are batched for generation).
    tokenizer.padding_side = "left"

    print("Loading GSM8K dataset...")
    train_ds, eval_ds = load_gsm8k(cfg.train_samples, cfg.eval_samples)

    # ----- SFT -----
    if cfg.skip_sft:
        print("Skipping SFT phase; loading base model directly.")
        base_model = _load_base_model(cfg)
    else:
        print("Running SFT phase...")
        base_model = run_sft(train_ds, eval_ds, tokenizer, cfg)

    # ----- GRPO -----
    if cfg.skip_grpo:
        print("Skipping GRPO phase; only SFT outputs (if any) were produced.")
    else:
        print("Preparing RL dataset...")
        rl_dataset = build_rl_dataset(train_ds)
        print("Running GRPO phase...")
        run_grpo(rl_dataset, base_model, tokenizer, cfg)

    print("All done. Check outputs under:", cfg.output_dir)


if __name__ == "__main__":
    main()