#!/usr/bin/env python
# train_gsm8k_qwen_grpo.py
#
# End-to-end SFT + GRPO pipeline on GSM8K using a Qwen instruct checkpoint,
# compatible with transformers>=4.57 and trl>=0.25.x.
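#
# Example usage (a minimal smoke test; sample counts are illustrative):
#   python train_gsm8k_qwen_grpo.py --train-samples 256 --eval-samples 64
#   python train_gsm8k_qwen_grpo.py --skip-sft   # run GRPO directly on the base model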
import argparse
import os
import re
from dataclasses import dataclass
from typing import List, Optional

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, PeftModel  # PeftModel: merge the SFT adapter before GRPO
from trl import SFTTrainer, SFTConfig, GRPOTrainer, GRPOConfig

# ================== Model paths & defaults ==================
LOCAL_INSTRUCT_PATH = "models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507"


def _resolve_default_model_id() -> str:
    env_override = os.environ.get("QWEN_INSTRUCT_MODEL")
    if env_override:
        return env_override
    if os.path.isdir(LOCAL_INSTRUCT_PATH):
        return LOCAL_INSTRUCT_PATH
    # Fall back to the Hub id matching the local snapshot above.
    return "Qwen/Qwen3-4B-Instruct-2507"


DEFAULT_MODEL_ID = _resolve_default_model_id()
DEFAULT_OUTPUT_DIR = "./qwen_gsm8k_grpo"
DEFAULT_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
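# These defaults cover Qwen's attention projections (q/k/v/o_proj) and SwiGLU
# MLP projections (gate/up/down_proj); pass --target-modules if your checkpoint
# names its modules differently.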

# ================== Config dataclass ==================
@dataclass
class PipelineConfig:
    model_name_or_path: str = DEFAULT_MODEL_ID
    output_dir: str = DEFAULT_OUTPUT_DIR
    max_seq_length: int = 512
    sft_epochs: int = 1
    grpo_epochs: int = 1
    train_samples: Optional[int] = None
    eval_samples: Optional[int] = None
    bf16: bool = True
    per_device_batch_size: int = 1
    grad_accum_steps: int = 8
    sft_learning_rate: float = 1e-5
    grpo_learning_rate: float = 5e-6
    max_completion_length: int = 64
    num_generations: int = 4
    steps_per_generation: int = 1
    target_modules: Optional[List[str]] = None
    skip_sft: bool = False
    skip_grpo: bool = False

def parse_args() -> PipelineConfig:
    parser = argparse.ArgumentParser(
        description="Run GSM8K fine-tuning (SFT + GRPO) with a Qwen instruct checkpoint."
    )
    parser.add_argument(
        "--model",
        default=DEFAULT_MODEL_ID,
        help=(
            "Model id or local path for the instruct-tuned Qwen checkpoint. "
            "Defaults to models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507 when present."
        ),
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help="Directory for checkpoints and logs.",
    )
    parser.add_argument("--max-seq-length", type=int, default=512)
    parser.add_argument("--sft-epochs", type=int, default=1)
    parser.add_argument("--grpo-epochs", type=int, default=1)
    parser.add_argument(
        "--train-samples",
        type=int,
        default=None,
        help="Optional number of GSM8K training samples (None => full set).",
    )
    parser.add_argument(
        "--eval-samples",
        type=int,
        default=None,
        help="Optional number of GSM8K eval samples.",
    )
    parser.add_argument("--per-device-batch-size", type=int, default=1)
    parser.add_argument("--grad-accum-steps", type=int, default=8)
    parser.add_argument("--sft-learning-rate", type=float, default=1e-5)
    parser.add_argument("--grpo-learning-rate", type=float, default=5e-6)
    parser.add_argument("--max-completion-length", type=int, default=64)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--steps-per-generation", type=int, default=1)
    parser.add_argument(
        "--target-modules",
        default=None,
        help="Comma-separated list of module names for LoRA (defaults to Qwen attn/FFN blocks).",
    )
    parser.add_argument(
        "--disable-bf16",
        action="store_true",
        help="Force fp16/fp32 training if bf16 is not desired or unsupported.",
    )
    parser.add_argument("--skip-sft", action="store_true", help="Skip the SFT phase.")
    parser.add_argument("--skip-grpo", action="store_true", help="Skip the GRPO phase.")
    args = parser.parse_args()

    target_modules = (
        [m.strip() for m in args.target_modules.split(",") if m.strip()]
        if args.target_modules
        else None
    )
    return PipelineConfig(
        model_name_or_path=args.model,
        output_dir=args.output_dir,
        max_seq_length=args.max_seq_length,
        sft_epochs=args.sft_epochs,
        grpo_epochs=args.grpo_epochs,
        train_samples=args.train_samples,
        eval_samples=args.eval_samples,
        bf16=not args.disable_bf16,
        per_device_batch_size=args.per_device_batch_size,
        grad_accum_steps=args.grad_accum_steps,
        sft_learning_rate=args.sft_learning_rate,
        grpo_learning_rate=args.grpo_learning_rate,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        steps_per_generation=args.steps_per_generation,
        target_modules=target_modules,
        skip_sft=args.skip_sft,
        skip_grpo=args.skip_grpo,
    )

# ================== Data: GSM8K formatting ==================
def load_gsm8k(train_limit: Optional[int] = None, eval_limit: Optional[int] = None):
    """
    Load GSM8K and return train/test datasets with:
      - prompt (input to the model)
      - completion (gold text, used for SFT)
      - final_answer (clean integer answer, used for reward)
    """
    raw = load_dataset("openai/gsm8k", "main")
    train_ds = raw["train"]
    test_ds = raw["test"]

    def format_example(ex):
        question = ex["question"]
        full_answer = ex["answer"]
        # GSM8K answers look like "... #### 42"; some gold answers use
        # thousands separators (e.g. "#### 1,000"), so strip commas too.
        final_ans = full_answer.split("####")[-1].strip().replace(",", "")
        prompt = (
            "You are a helpful math solver.\n\n"
            f"Question:\n{question}\n\n"
            "Answer with a single integer.\n"
        )
        completion = final_ans  # gold short answer
        return {
            "prompt": prompt,
            "completion": completion,
            "final_answer": final_ans,
        }

    train_ds = train_ds.map(format_example, remove_columns=train_ds.column_names).shuffle(seed=42)
    test_ds = test_ds.map(format_example, remove_columns=test_ds.column_names)
    if train_limit is not None:
        train_ds = train_ds.select(range(min(train_limit, len(train_ds))))
    if eval_limit is not None:
        test_ds = test_ds.select(range(min(eval_limit, len(test_ds))))
    return train_ds, test_ds
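
# A formatted record looks like this (illustrative values, not real output):
#   prompt:       "You are a helpful math solver.\n\nQuestion:\nNatalia sold clips...\n\nAnswer with a single integer.\n"
#   completion:   "72"
#   final_answer: "72"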

# ================== Reward function for GRPO ==================
INT_REGEX = re.compile(r"-?\d+")


def extract_last_int(text: str):
    # Strip commas first so "1,000" is read as one number, matching the
    # comma-stripped gold answers produced in load_gsm8k.
    matches = INT_REGEX.findall(text.replace(",", ""))
    return matches[-1] if matches else None
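
# Examples (illustrative):
#   extract_last_int("So the answer is 72.")  -> "72"
#   extract_last_int("That costs $1,000.")    -> "1000"
#   extract_last_int("No digits here.")       -> None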

def correctness_reward(completions: List[str], **kwargs) -> List[float]:
    """
    Custom reward function for GRPOTrainer.
    TRL 0.25.x will call this with:
      - completions: list[str]
      - prompts: list[str] (via kwargs["prompts"])
      - plus all dataset columns (except 'prompt') as kwargs,
        e.g. kwargs["final_answer"] is our ground-truth list[str].
    """
    final_answer = kwargs.get("final_answer")
    rewards: List[float] = []
    # If we somehow don't get final_answer, just give a length-based reward (debug fallback).
    if final_answer is None:
        return [float(len(c)) for c in completions]
    for comp, ref in zip(completions, final_answer):
        pred = extract_last_int(comp)
        if pred is not None and pred.strip() == ref.strip():
            rewards.append(1.0)
        else:
            rewards.append(0.0)
    return rewards
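
# Quick sanity check of the reward shaping (illustrative):
#   correctness_reward(["... so the answer is 42."], final_answer=["42"]) -> [1.0]
#   correctness_reward(["I give up."],               final_answer=["42"]) -> [0.0]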

# ================== SFT phase ==================
def run_sft(train_ds, eval_ds, tokenizer, cfg: PipelineConfig):
    """Run a short supervised fine-tuning pass with LoRA adapters (prompt-completion)."""
    target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=target_modules,
        task_type="CAUSAL_LM",
    )
    sft_config = SFTConfig(
        output_dir=os.path.join(cfg.output_dir, "sft"),
        per_device_train_batch_size=cfg.per_device_batch_size,
        per_device_eval_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.grad_accum_steps,
        learning_rate=cfg.sft_learning_rate,
        num_train_epochs=cfg.sft_epochs,
        logging_steps=10,
        save_steps=200,
        eval_steps=200,
        eval_strategy="steps",  # renamed from 'evaluation_strategy' in recent transformers
        save_total_limit=2,
        max_length=cfg.max_seq_length,
        bf16=cfg.bf16,
        fp16=not cfg.bf16,
        report_to=["none"],
    )
    # SFTTrainer picks up the 'prompt' & 'completion' columns; for this
    # prompt-completion format, recent TRL computes the loss on the completion
    # tokens only, so the model learns to emit just the short final answer.
    trainer = SFTTrainer(
        model=cfg.model_name_or_path,  # string path → SFTTrainer will load the model
        args=sft_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,  # new TRL API
        peft_config=peft_config,
    )
    trainer.train()
    save_path = os.path.join(cfg.output_dir, "sft_model")
    trainer.save_model(save_path)
    return trainer.model  # PEFT-wrapped model instance

# ================== GRPO phase ==================
def build_rl_dataset(train_ds):
    """
    For GRPO we just need 'prompt'; we also keep 'final_answer' so the reward
    function can use it.
    """
    return train_ds


def run_grpo(rl_dataset, base_model, tokenizer, cfg: PipelineConfig):
    """Run a short GRPO training loop on top of the (optionally) SFT-initialized model."""
    # If the SFT phase handed us a PEFT-wrapped model, merge the SFT adapter
    # into the base weights first; GRPOTrainer attaches a fresh adapter from
    # peft_config below, and stacking a second adapter on top of the first is
    # not what we want.
    if isinstance(base_model, PeftModel):
        base_model = base_model.merge_and_unload()
    target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        target_modules=target_modules,
        task_type="CAUSAL_LM",
    )
    # TRL derives its internal steps-per-generation value from
    # generation_batch_size (the two GRPOConfig settings are mutually
    # exclusive), so fold our cfg.steps_per_generation knob in here: each
    # sampling round draws num_generations completions per prompt and is
    # reused for that many rounds of optimizer steps.
    generation_batch_size = (
        cfg.per_device_batch_size * cfg.num_generations * cfg.steps_per_generation
    )
    grpo_config = GRPOConfig(
        output_dir=os.path.join(cfg.output_dir, "grpo"),
        num_train_epochs=cfg.grpo_epochs,
        per_device_train_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.grad_accum_steps,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        bf16=cfg.bf16,
        fp16=not cfg.bf16,
        learning_rate=cfg.grpo_learning_rate,
        max_prompt_length=cfg.max_seq_length,
        max_completion_length=cfg.max_completion_length,
        num_generations=cfg.num_generations,
        generation_batch_size=generation_batch_size,
        report_to=["none"],
    )
    trainer = GRPOTrainer(
        model=base_model,  # can be a model instance or a model id
        args=grpo_config,
        processing_class=tokenizer,
        reward_funcs=correctness_reward,  # single custom reward
        train_dataset=rl_dataset,
        peft_config=peft_config,
    )
    trainer.train()
    trainer.save_model(os.path.join(cfg.output_dir, "grpo_model"))
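
# Note on group sizing (illustrative numbers): with per_device_batch_size=1 and
# num_generations=4, GRPO samples 4 completions per prompt and normalizes their
# rewards within the group to form the advantage, roughly
#   A_i = (r_i - mean(r)) / (std(r) + eps)
# so with the 0/1 correctness reward, a prompt where every sample is right (or
# every sample is wrong) yields zero advantage and contributes essentially no
# policy-gradient signal.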

# ================== Main ==================
def main():
    cfg = parse_args()
    os.makedirs(cfg.output_dir, exist_ok=True)
    print(f"Using model: {cfg.model_name_or_path}")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_name_or_path,
        use_fast=True,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # GRPO requires left padding, so the sampled completion tokens line up at
    # the end of each sequence in a generation batch.
    tokenizer.padding_side = "left"

    print("Loading GSM8K dataset...")
    train_ds, eval_ds = load_gsm8k(cfg.train_samples, cfg.eval_samples)

    # ----- SFT -----
    if cfg.skip_sft:
        print("Skipping SFT phase; loading base model directly.")
        dtype = (
            torch.bfloat16 if cfg.bf16 and torch.cuda.is_available()
            else (torch.float16 if torch.cuda.is_available() else torch.float32)
        )
        model_kwargs = {
            "dtype": dtype,  # transformers>=4.56 renamed 'torch_dtype' to 'dtype'
            "trust_remote_code": True,
        }
        if torch.cuda.is_available():
            model_kwargs["device_map"] = "auto"
        base_model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, **model_kwargs)
    else:
        print("Running SFT phase...")
        base_model = run_sft(train_ds, eval_ds, tokenizer, cfg)

    # ----- GRPO -----
    if cfg.skip_grpo:
        print("Skipping GRPO phase; only SFT outputs (if any) were produced.")
    else:
        print("Preparing RL dataset...")
        rl_dataset = build_rl_dataset(train_ds)
        print("Running GRPO phase...")
        run_grpo(rl_dataset, base_model, tokenizer, cfg)

    print("All done. Check outputs under:", cfg.output_dir)


if __name__ == "__main__":
    main()