Conna committed on
Commit
9345dd0
·
verified ·
1 Parent(s): 8b77271

Upload train_grpo_qwen7b.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_grpo_qwen7b.py +81 -0
train_grpo_qwen7b.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = [
5
+ # "trl>=0.12.0",
6
+ # "transformers>=4.36.0",
7
+ # "accelerate>=0.24.0",
8
+ # "peft>=0.7.0",
9
+ # "trackio",
10
+ # "datasets>=2.14.0",
11
+ # ]
12
+ # ///
13
+
14
+ """
15
+ GRPO training with Qwen2.5-7B-Instruct + LoRA on math reasoning dataset.
16
+ """
17
+
18
+ from datasets import load_dataset
19
+ from peft import LoraConfig
20
+ from trl import GRPOTrainer, GRPOConfig
21
+
22
# GRPO consumes prompt-only rows; a small slice keeps this demo run fast.
demo_split = "train[:3000]"
dataset = load_dataset("trl-lib/math_shepherd", split=demo_split)
print(f"✅ Dataset loaded: {len(dataset)} prompts")
25
+
26
# Low-rank adapters on the attention projections keep the 7B base model
# trainable within GPU memory limits (full fine-tuning would not fit).
_ATTENTION_PROJECTIONS = ["q_proj", "k_proj", "v_proj", "o_proj"]

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=_ATTENTION_PROJECTIONS,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)
34
+
35
# Training configuration, grouped by concern. Hub pushing is mandatory here:
# the execution environment is ephemeral, so anything not pushed is lost.
_hub_settings = dict(
    output_dir="qwen2.5-7b-grpo-math",
    push_to_hub=True,
    hub_model_id="Conna/qwen2.5-7b-grpo-math",
    hub_strategy="every_save",
)

_optimization_settings = dict(
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # effective batch size = 2 * 8 = 16
    learning_rate=1e-6,
    gradient_checkpointing=True,  # trade recompute for VRAM headroom
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
)

_checkpoint_and_logging_settings = dict(
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    # Trackio monitoring
    report_to="trackio",
    project="qwen-grpo-training",
    run_name="qwen2.5-7b-grpo-math-lora",
)

config = GRPOConfig(
    **_hub_settings,
    **_optimization_settings,
    **_checkpoint_and_logging_settings,
)
65
+
66
# GRPO needs two things: an instruct-tuned base model AND at least one reward
# function. `reward_funcs` is a required GRPOTrainer argument — omitting it
# (as the original did) raises a TypeError before training even starts.
def reward_num_unique_chars(completions, **kwargs):
    """Toy reward: score each completion by its count of unique characters.

    GRPO only compares rewards *within* a group of sampled completions, so a
    simple deterministic heuristic is enough to exercise the training loop.
    NOTE(review): replace with a task-specific verifier (e.g. exact-answer
    checking against the dataset labels) for a real math-reasoning run.
    """
    return [float(len(set(completion))) for completion in completions]


trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-7B-Instruct",
    reward_funcs=reward_num_unique_chars,
    peft_config=lora_config,
    train_dataset=dataset,
    args=config,
)

print("🚀 Starting GRPO training...")
trainer.train()

print("💾 Pushing final model to Hub...")
trainer.push_to_hub()

print("✅ Done! Model: https://huggingface.co/Conna/qwen2.5-7b-grpo-math")
print("📊 Metrics: https://huggingface.co/spaces/Conna/trackio")