NotoriousH2 committed
Commit b50d571 · verified · 1 Parent(s): 12dd0e7

Add train_grpo.py

Files changed (1):
  1. train_grpo.py +171 -0
train_grpo.py ADDED
@@ -0,0 +1,171 @@
+ """C27: GRPO (Group Relative Policy Optimization) for math reasoning.
+ Based on the DeepSeekMath GRPO algorithm and a Gemma-2-2B success recipe from the literature.
+ """
+ import json, re, random, torch, numpy as np, os
+ from datasets import Dataset
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from trl import GRPOConfig, GRPOTrainer
+
+ SEED = 42
+ random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
+ if torch.cuda.is_available():
+     torch.cuda.manual_seed_all(SEED)
+     if torch.cuda.get_device_capability()[0] >= 8:
+         torch.set_float32_matmul_precision('high')
+
+ BASE_MODEL = "outputs/models/c20-2-5x-replay"
+
+ # System prompt (Korean): "Solve the given math problem step by step and write
+ # your answer. Always print the final answer in \boxed{integer} format on the
+ # last line. Example: \boxed{42}"
+ SP = "주어진 수학 문제를 단계별로 풀고 답변을 작성하세요.\n반드시 최종 답변을 \\boxed{정수} 형식으로 마지막 줄에 출력하세요.\n예시: \\boxed{42}"
+
+ # === Load questions + ground truth ===
+ with open("data/GSM8K_full_qwen3_30b.json") as f:
+     teacher_data = json.load(f)
+ def extract_boxed(text):
+     """Return the payload of the last \\boxed{...} in the text, or None."""
+     m = re.findall(r'\\boxed\{([^}]+)\}', text)
+     return m[-1].strip() if m else None
+
+ def normalize(a):
+     """Canonicalize an answer: strip commas/spaces and drop a trailing .0."""
+     if a is None: return None
+     s = str(a).replace(",", "").replace(" ", "").strip()
+     try:
+         n = float(s)
+         return str(int(n)) if n == int(n) else str(n)
+     except ValueError:
+         return s
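+
+ # Illustration (hypothetical values): extract_boxed("... \\boxed{1,000}") -> "1,000";
+ # normalize then maps "1,000", "1000", and "1000.0" all to "1000" before comparison.
+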
+ # Build ground truth: majority vote over the teacher's boxed answers per question
+ gt_by_q = {}
+ for t in teacher_data:
+     ans = extract_boxed(t["answer"])
+     if ans is None: continue
+     q = t["question"]
+     na = normalize(ans)
+     if q not in gt_by_q: gt_by_q[q] = {}
+     gt_by_q[q][na] = gt_by_q[q].get(na, 0) + 1
+
+ ground_truth = {q: max(counts, key=counts.get) for q, counts in gt_by_q.items()}
+ questions = list(ground_truth.keys())
+ random.shuffle(questions)
+ print(f"Total questions: {len(questions)}")
+
+ # Build dataset with prompt (conversational format) + answer column
+ dataset_items = []
+ for q in questions:
+     dataset_items.append({
+         "prompt": [{"role": "user", "content": SP + "\n\n" + q}],
+         "answer": ground_truth[q],
+     })
+
+ train_ds = Dataset.from_list(dataset_items)
+ print(f"Train dataset: {len(train_ds)}")
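+
+ # Each row looks like this (illustrative values):
+ #   {"prompt": [{"role": "user", "content": SP + "\n\n" + question}], "answer": "42"}
+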
+ # === Reward function ===
+ def math_reward(prompts, completions, answer, **kwargs):
+     """Reward: 1.0 if the boxed answer matches ground truth, 0.1 if a boxed
+     answer is present but wrong (rewards format compliance), 0.0 otherwise."""
+     rewards = []
+     for completion, gt in zip(completions, answer):
+         # Handle conversational format (list of dicts) or plain string
+         if isinstance(completion, list):
+             text = completion[-1]["content"] if completion else ""
+         elif isinstance(completion, dict):
+             text = completion.get("content", "")
+         else:
+             text = str(completion)
+
+         pred = extract_boxed(text)
+         if pred is None:
+             rewards.append(0.0)  # no boxed answer
+         elif normalize(pred) == gt:
+             rewards.append(1.0)  # correct
+         else:
+             rewards.append(0.1)  # wrong, but boxed format present
+     return rewards
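+
+ # Illustrative check (assuming conversational completions, i.e. lists of message dicts):
+ # math_reward(None, [[{"role": "assistant", "content": "so \\boxed{42}"}]], ["42"])  # -> [1.0]
+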
+ # === GRPO Config ===
+ # Literature-informed settings for a 1B model
+ NUM_GEN = 8
+ BATCH_SIZE = 8   # per device; must be divisible by num_generations
+ GRAD_ACCUM = 4   # effective batch = 8 * 4 = 32 completions = 4 prompts per step at 8 generations
+
+ config = GRPOConfig(
+     output_dir="outputs/c27_grpo_ckpt",
+     report_to="none",
+     seed=SEED,
+
+     # Generation
+     num_generations=NUM_GEN,
+     max_completion_length=1024,
+     temperature=0.7,
+
+     # GRPO algorithm
+     beta=0.04,               # higher KL penalty to preserve format/quality
+     loss_type="grpo",        # standard GRPO loss
+     epsilon=0.2,             # PPO-style clipping range
+     scale_rewards="group",   # normalize rewards within each group
+
+     # Training
+     num_train_epochs=1,
+     per_device_train_batch_size=BATCH_SIZE,
+     gradient_accumulation_steps=GRAD_ACCUM,
+     learning_rate=5e-6,
+     lr_scheduler_type="cosine",
+     warmup_ratio=0.05,
+     max_grad_norm=0.1,       # very strict gradient clipping (from literature)
+     optim="paged_adamw_8bit",
+     bf16=True,
+     gradient_checkpointing=True,
+
+     # Logging & saving
+     logging_steps=10,
+     save_strategy="no",
+     max_steps=500,           # overrides num_train_epochs
+
+     # vLLM for fast generation
+     use_vllm=True,
+     vllm_mode="colocate",
+     vllm_gpu_memory_utilization=0.3,
+ )
+
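+ # Note: "colocate" runs the vLLM engine on the same GPU as training;
+ # vllm_gpu_memory_utilization=0.3 reserves roughly 30% of VRAM for generation.
+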
+ print(f"Config: num_gen={NUM_GEN}, batch={BATCH_SIZE}, ga={GRAD_ACCUM}")
+ print(f"  effective_batch={BATCH_SIZE * GRAD_ACCUM}, prompts_per_step={BATCH_SIZE * GRAD_ACCUM // NUM_GEN}")
+ print(f"  lr={config.learning_rate}, beta={config.beta}, max_steps={config.max_steps}")
+ print(f"  vllm={config.use_vllm}, mode={config.vllm_mode}")
+
+ # === Load model ===
+ print("\nLoading model...")
+ model = AutoModelForCausalLM.from_pretrained(
+     BASE_MODEL, torch_dtype=torch.bfloat16,
+     attn_implementation="flash_attention_2",
+ )
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # === Train ===
+ trainer = GRPOTrainer(
+     model=model,
+     reward_funcs=math_reward,
+     args=config,
+     train_dataset=train_ds,
+     processing_class=tokenizer,
+ )
+
+ print("\n=== Starting GRPO Training ===")
+ result = trainer.train()
+ print(f"\nTraining loss: {result.training_loss:.4f}")
+
+ # Save
+ SAVE_DIR = "outputs/models/c27-grpo"
+ os.makedirs(SAVE_DIR, exist_ok=True)
+ trainer.save_model(SAVE_DIR)
+ tokenizer.save_pretrained(SAVE_DIR)
+ print(f"Saved: {SAVE_DIR}")
+
+ # Print training metrics summary
+ logs = trainer.state.log_history
+ reward_logs = [l for l in logs if "reward" in str(l)]
+ for l in reward_logs[-5:]:
+     print(f"  step={l.get('step','?')}: reward={l.get('reward', l.get('rewards/mean','?'))}, "
+           f"completion_length={l.get('completion_length', '?')}")
+
+ print("\n=== GRPO Training Complete ===")