#!/usr/bin/env python
# train_gsm8k_qwen_grpo.py
#
# End-to-end SFT + GRPO pipeline on GSM8K using a Qwen instruct checkpoint,
# compatible with transformers>=4.57 and trl>=0.25.x.

import argparse
import os
import re
from dataclasses import dataclass
from typing import List, Optional

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig, GRPOTrainer, GRPOConfig

# ================== Model paths & defaults ==================

LOCAL_INSTRUCT_PATH = "models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507"


def _resolve_default_model_id() -> str:
    env_override = os.environ.get("QWEN_INSTRUCT_MODEL")
    if env_override:
        return env_override
    if os.path.isdir(LOCAL_INSTRUCT_PATH):
        return LOCAL_INSTRUCT_PATH
    # Hub fallback: the published instruct checkpoint carries the -2507 suffix.
    return "Qwen/Qwen3-4B-Instruct-2507"


DEFAULT_MODEL_ID = _resolve_default_model_id()
DEFAULT_OUTPUT_DIR = "./qwen_gsm8k_grpo"
DEFAULT_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# ================== Config dataclass ==================


@dataclass
class PipelineConfig:
    model_name_or_path: str = DEFAULT_MODEL_ID
    output_dir: str = DEFAULT_OUTPUT_DIR
    max_seq_length: int = 512
    sft_epochs: int = 1
    grpo_epochs: int = 1
    train_samples: Optional[int] = None
    eval_samples: Optional[int] = None
    bf16: bool = True
    per_device_batch_size: int = 1
    grad_accum_steps: int = 8
    sft_learning_rate: float = 1e-5
    grpo_learning_rate: float = 5e-6
    max_completion_length: int = 64
    num_generations: int = 4
    steps_per_generation: int = 1
    target_modules: Optional[List[str]] = None
    skip_sft: bool = False
    skip_grpo: bool = False


def parse_args() -> PipelineConfig:
    parser = argparse.ArgumentParser(
        description="Run GSM8K fine-tuning (SFT + GRPO) with a Qwen instruct checkpoint."
    )
    parser.add_argument(
        "--model",
        default=DEFAULT_MODEL_ID,
        help=(
            "Model id or local path for the instruct-tuned Qwen checkpoint. "
            "Defaults to models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507 when present."
        ),
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help="Directory for checkpoints and logs.",
    )
    parser.add_argument("--max-seq-length", type=int, default=512)
    parser.add_argument("--sft-epochs", type=int, default=1)
    parser.add_argument("--grpo-epochs", type=int, default=1)
    parser.add_argument(
        "--train-samples",
        type=int,
        default=None,
        help="Optional number of GSM8K training samples (None => full set).",
    )
    parser.add_argument(
        "--eval-samples",
        type=int,
        default=None,
        help="Optional number of GSM8K eval samples.",
    )
    parser.add_argument("--per-device-batch-size", type=int, default=1)
    parser.add_argument("--grad-accum-steps", type=int, default=8)
    parser.add_argument("--sft-learning-rate", type=float, default=1e-5)
    parser.add_argument("--grpo-learning-rate", type=float, default=5e-6)
    parser.add_argument("--max-completion-length", type=int, default=64)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--steps-per-generation", type=int, default=1)
    parser.add_argument(
        "--target-modules",
        default=None,
        help="Comma-separated list of module names for LoRA (defaults to Qwen attn/FFN blocks).",
    )
    parser.add_argument(
        "--disable-bf16",
        action="store_true",
        help="Force fp16/fp32 training if bf16 is not desired or unsupported.",
    )
    parser.add_argument("--skip-sft", action="store_true", help="Skip the SFT phase.")
    parser.add_argument("--skip-grpo", action="store_true", help="Skip the GRPO phase.")

    args = parser.parse_args()

    target_modules = (
        [m.strip() for m in args.target_modules.split(",") if m.strip()]
        if args.target_modules
        else None
    )

    return PipelineConfig(
        model_name_or_path=args.model,
        output_dir=args.output_dir,
        max_seq_length=args.max_seq_length,
        sft_epochs=args.sft_epochs,
        grpo_epochs=args.grpo_epochs,
        train_samples=args.train_samples,
        eval_samples=args.eval_samples,
        bf16=not args.disable_bf16,
        per_device_batch_size=args.per_device_batch_size,
        grad_accum_steps=args.grad_accum_steps,
        sft_learning_rate=args.sft_learning_rate,
        grpo_learning_rate=args.grpo_learning_rate,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        steps_per_generation=args.steps_per_generation,
        target_modules=target_modules,
        skip_sft=args.skip_sft,
        skip_grpo=args.skip_grpo,
    )
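

# Example invocations (single GPU; the sample counts below are illustrative,
# all flags are defined in parse_args above):
#
#   python train_gsm8k_qwen_grpo.py --train-samples 2000 --eval-samples 200
#   python train_gsm8k_qwen_grpo.py --skip-sft --num-generations 8 --max-completion-length 128
#   QWEN_INSTRUCT_MODEL=Qwen/Qwen3-4B-Instruct-2507 python train_gsm8k_qwen_grpo.py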

# ================== Data: GSM8K formatting ==================


def load_gsm8k(train_limit: Optional[int] = None, eval_limit: Optional[int] = None):
    """
    Load GSM8K and return train/test datasets with:
      - prompt (input to the model)
      - completion (gold text, used for SFT)
      - final_answer (clean integer answer, used for reward)
    """
    raw = load_dataset("openai/gsm8k", "main")
    train_ds = raw["train"]
    test_ds = raw["test"]

    def format_example(ex):
        question = ex["question"]
        full_answer = ex["answer"]
        # GSM8K answers look like "... #### 42"; some gold answers contain
        # thousands separators (e.g. "#### 1,000"), so strip commas as well.
        final_ans = full_answer.split("####")[-1].strip().replace(",", "")
        prompt = (
            "You are a helpful math solver.\n\n"
            f"Question:\n{question}\n\n"
            "Answer with a single integer.\n"
        )
        completion = final_ans  # gold short answer
        return {
            "prompt": prompt,
            "completion": completion,
            "final_answer": final_ans,
        }

    train_ds = train_ds.map(format_example, remove_columns=train_ds.column_names).shuffle(seed=42)
    test_ds = test_ds.map(format_example, remove_columns=test_ds.column_names)

    if train_limit is not None:
        train_ds = train_ds.select(range(min(train_limit, len(train_ds))))
    if eval_limit is not None:
        test_ds = test_ds.select(range(min(eval_limit, len(test_ds))))

    return train_ds, test_ds
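

# Shape of one mapped record (toy values, not a real GSM8K row):
#
#   {
#       "prompt": "You are a helpful math solver.\n\nQuestion:\n...\n\nAnswer with a single integer.\n",
#       "completion": "42",
#       "final_answer": "42",
#   }
#
# SFTTrainer consumes the 'prompt'/'completion' pair (prompt-completion format),
# while GRPO only needs 'prompt' and receives 'final_answer' via reward kwargs.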
#### 42" final_ans = full_answer.split("####")[-1].strip() prompt = ( "You are a helpful math solver.\n\n" f"Question:\n{question}\n\n" "Answer with a single integer.\n" ) completion = final_ans # gold short answer return { "prompt": prompt, "completion": completion, "final_answer": final_ans, } train_ds = train_ds.map(format_example, remove_columns=train_ds.column_names).shuffle(seed=42) test_ds = test_ds.map(format_example, remove_columns=test_ds.column_names) if train_limit is not None: train_ds = train_ds.select(range(min(train_limit, len(train_ds)))) if eval_limit is not None: test_ds = test_ds.select(range(min(eval_limit, len(test_ds)))) return train_ds, test_ds # ================== Reward function for GRPO ================== INT_REGEX = re.compile(r"-?\d+") def extract_last_int(text: str): matches = INT_REGEX.findall(text) return matches[-1] if matches else None def correctness_reward(completions: List[str], **kwargs) -> List[float]: """ Custom reward function for GRPOTrainer. TRL 0.25.x will call this with: - completions: list[str] - prompts: list[str] (via kwargs["prompts"]) - plus all dataset columns (except 'prompt') as kwargs e.g. kwargs["final_answer"] is our ground truth list[str]. """ final_answer = kwargs.get("final_answer") rewards: List[float] = [] # If we somehow don't get final_answer, just give length-based reward (debug fallback) if final_answer is None: return [float(len(c)) for c in completions] for comp, ref in zip(completions, final_answer): pred = extract_last_int(comp) if pred is not None and pred.strip() == ref.strip(): rewards.append(1.0) else: rewards.append(0.0) return rewards # ================== SFT phase ================== def run_sft(train_ds, eval_ds, tokenizer, cfg: PipelineConfig): """Run a short supervised fine-tuning pass with LoRA adapters (prompt-completion).""" target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES peft_config = LoraConfig( r=16, lora_alpha=32, lora_dropout=0.05, bias="none", target_modules=target_modules, task_type="CAUSAL_LM", ) sft_config = SFTConfig( output_dir=os.path.join(cfg.output_dir, "sft"), per_device_train_batch_size=cfg.per_device_batch_size, per_device_eval_batch_size=cfg.per_device_batch_size, gradient_accumulation_steps=cfg.grad_accum_steps, learning_rate=cfg.sft_learning_rate, num_train_epochs=cfg.sft_epochs, logging_steps=10, save_steps=200, eval_steps=200, eval_strategy="steps", # transformers>=4.57 uses 'eval_strategy' save_total_limit=2, max_length=cfg.max_seq_length, bf16=cfg.bf16, fp16=not cfg.bf16, report_to=["none"], ) # SFTTrainer will look for 'prompt' & 'completion' columns in the dataset. trainer = SFTTrainer( model=cfg.model_name_or_path, # string path → SFTTrainer will load the model args=sft_config, train_dataset=train_ds, eval_dataset=eval_ds, processing_class=tokenizer, # new TRL API peft_config=peft_config, ) trainer.train() save_path = os.path.join(cfg.output_dir, "sft_model") trainer.save_model(save_path) return trainer.model # PEFT-wrapped model instance # ================== GRPO phase ================== def build_rl_dataset(train_ds): """ For GRPO we just need 'prompt'; we also keep 'final_answer' so reward_fn can use it. 
""" return train_ds def run_grpo(rl_dataset, base_model, tokenizer, cfg: PipelineConfig): """Run a short GRPO training loop on top of the (optionally) SFT-initialized model.""" target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES peft_config = LoraConfig( r=8, lora_alpha=16, lora_dropout=0.05, bias="none", target_modules=target_modules, task_type="CAUSAL_LM", ) generation_batch_size = cfg.per_device_batch_size * cfg.num_generations grpo_config = GRPOConfig( output_dir=os.path.join(cfg.output_dir, "grpo"), num_train_epochs=cfg.grpo_epochs, per_device_train_batch_size=cfg.per_device_batch_size, gradient_accumulation_steps=cfg.grad_accum_steps, logging_steps=10, save_steps=200, save_total_limit=2, bf16=cfg.bf16, fp16=not cfg.bf16, learning_rate=cfg.grpo_learning_rate, max_prompt_length=cfg.max_seq_length, max_completion_length=cfg.max_completion_length, num_generations=cfg.num_generations, generation_batch_size=generation_batch_size, report_to=["none"], ) trainer = GRPOTrainer( model=base_model, # can be model instance or model id args=grpo_config, processing_class=tokenizer, reward_funcs=correctness_reward, # single custom reward train_dataset=rl_dataset, peft_config=peft_config, ) trainer.train() trainer.save_model(os.path.join(cfg.output_dir, "grpo_model")) # ================== Main ================== def main(): cfg = parse_args() os.makedirs(cfg.output_dir, exist_ok=True) print(f"Using model: {cfg.model_name_or_path}") print("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained( cfg.model_name_or_path, use_fast=True, trust_remote_code=True, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # GRPO 要求 left padding tokenizer.padding_side = "left" print("Loading GSM8K dataset...") train_ds, eval_ds = load_gsm8k(cfg.train_samples, cfg.eval_samples) # ----- SFT ----- if cfg.skip_sft: print("Skipping SFT phase; loading base model directly.") dtype = ( torch.bfloat16 if cfg.bf16 and torch.cuda.is_available() else (torch.float16 if torch.cuda.is_available() else torch.float32) ) model_kwargs = { "torch_dtype": dtype, "trust_remote_code": True, } if torch.cuda.is_available(): model_kwargs["device_map"] = "auto" base_model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, **model_kwargs) else: print("Running SFT phase...") base_model = run_sft(train_ds, eval_ds, tokenizer, cfg) # ----- GRPO ----- if cfg.skip_grpo: print("Skipping GRPO phase; only SFT outputs (if any) were produced.") else: print("Preparing RL dataset...") rl_dataset = build_rl_dataset(train_ds) print("Running GRPO phase...") run_grpo(rl_dataset, base_model, tokenizer, cfg) print("All done. Check outputs under:", cfg.output_dir) if __name__ == "__main__": main()