#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.12.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "peft>=0.7.0",
#     "trackio",
#     "datasets>=2.14.0",
# ]
# ///
"""
GRPO training with Qwen2.5-7B-Instruct + LoRA on a math reasoning dataset.
"""

from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Load dataset — GRPOTrainer only needs the "prompt" column; take a demo subset.
# Drop the other columns so they aren't forwarded to the reward function as
# extra keyword arguments (GRPO generates its own completions during training).
dataset = load_dataset("trl-lib/math_shepherd", split="train[:3000]")
dataset = dataset.remove_columns([c for c in dataset.column_names if c != "prompt"])
print(f"✅ Dataset loaded: {len(dataset)} prompts")

# Reward function: GRPOTrainer requires at least one. NOTE: this is a
# deliberately simple placeholder that rewards completions stating a final
# answer; swap in a real verifier (e.g. exact-answer checking) before
# training for real.
def reward_answer_format(completions, **kwargs):
    return [1.0 if "The answer is" in completion else 0.0 for completion in completions]

# LoRA config — necessary for a 7B model to fit in GPU memory
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

# Training configuration
config = GRPOConfig(
    # Hub settings — CRITICAL: the training environment is ephemeral,
    # so checkpoints must be pushed to the Hub as they are saved
    output_dir="qwen2.5-7b-grpo-math",
    push_to_hub=True,
    hub_model_id="Conna/qwen2.5-7b-grpo-math",
    hub_strategy="every_save",
    # Training parameters
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # effective batch size = 16
    learning_rate=1e-6,
    gradient_checkpointing=True,  # trade compute for VRAM
    # Checkpointing
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    # LR schedule
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    # Trackio monitoring
    report_to="trackio",
    project="qwen-grpo-training",
    run_name="qwen2.5-7b-grpo-math-lora",
)

# GRPO is typically run on an instruct-tuned base model
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-7B-Instruct",
    reward_funcs=reward_answer_format,
    args=config,
    train_dataset=dataset,
    peft_config=lora_config,
)

print("🚀 Starting GRPO training...")
trainer.train()

print("💾 Pushing final model to Hub...")
trainer.push_to_hub()

print("✅ Done! Model: https://huggingface.co/Conna/qwen2.5-7b-grpo-math")
print("📊 Metrics: https://huggingface.co/spaces/Conna/trackio")
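
# ---------------------------------------------------------------------------
# Usage note (assumptions: the file name train_grpo.py is arbitrary, and the
# adapter repo id matches hub_model_id above). The PEP 723 metadata block at
# the top lets `uv` resolve the dependencies on the fly, so no pre-built
# environment is needed; HF_TOKEN must be a Hub token with write access so
# push_to_hub() can upload checkpoints:
#
#   HF_TOKEN=<write-token> uv run train_grpo.py
#
# After training, a quick smoke-test sketch for loading the pushed adapter:
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
#
#   base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
#   model = PeftModel.from_pretrained(base, "Conna/qwen2.5-7b-grpo-math")
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
# ---------------------------------------------------------------------------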