Add train_grpo.py
train_grpo.py  ADDED  (+171 −0)
"""C27: GRPO (Group Relative Policy Optimization) for math reasoning.

Based on the DeepSeekMath GRPO algorithm and a Gemma-2-2B success recipe from the literature.
"""
import json, re, random, torch, numpy as np, os
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    if torch.cuda.get_device_capability()[0] >= 8:  # Ampere or newer: allow TF32 matmuls
        torch.set_float32_matmul_precision('high')

BASE_MODEL = "outputs/models/c20-2-5x-replay"

# System prompt (Korean). English gist: "Solve the given math problem step by step and
# write your answer. Be sure to output the final answer on the last line in
# \boxed{integer} format. Example: \boxed{42}"
SP = "주어진 수학 문제를 단계별로 풀고 답변을 작성하세요.\n반드시 최종 답변을 \\boxed{정수} 형식으로 마지막 줄에 출력하세요.\n예시: \\boxed{42}"

# === Load questions + ground truth ===
with open("data/GSM8K_full_qwen3_30b.json") as f:
    teacher_data = json.load(f)

def extract_boxed(text):
    """Return the contents of the last \\boxed{...} in text, or None."""
    m = re.findall(r'\\boxed\{([^}]+)\}', text)
    return m[-1].strip() if m else None

def normalize(a):
    """Canonicalize an answer string: strip commas/spaces, collapse integral floats."""
    if a is None: return None
    s = str(a).replace(",", "").replace(" ", "").strip()
    try:
        n = float(s)
        return str(int(n)) if n == int(n) else str(n)
    except (ValueError, OverflowError):  # non-numeric answers pass through unchanged
        return s

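# Quick self-checks for the parsing helpers (illustrative values, not from the dataset;
# they exercise the exact code paths used by the reward function below):
assert extract_boxed("so the answer is \\boxed{42}.") == "42"
assert extract_boxed("no box here") is None
assert normalize("1,234.0") == "1234"   # commas stripped, integral floats collapse to ints
assert normalize("3.5") == "3.5"        # non-integral values keep their decimal form
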
# Build ground truth: majority vote over the teacher's boxed answers per question
gt_by_q = {}
for t in teacher_data:
    ans = extract_boxed(t["answer"])
    if ans is None: continue
    q = t["question"]
    na = normalize(ans)
    if q not in gt_by_q: gt_by_q[q] = {}
    gt_by_q[q][na] = gt_by_q[q].get(na, 0) + 1

ground_truth = {q: max(counts, key=counts.get) for q, counts in gt_by_q.items()}
questions = list(ground_truth.keys())
random.shuffle(questions)
print(f"Total questions: {len(questions)}")

# Build dataset with prompt (conversational format) + answer column
dataset_items = []
for q in questions:
    dataset_items.append({
        "prompt": [{"role": "user", "content": SP + "\n\n" + q}],
        "answer": ground_truth[q],
    })

train_ds = Dataset.from_list(dataset_items)
print(f"Train dataset: {len(train_ds)}")

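# Shape of one training item (illustrative): the "prompt" column is a one-turn chat,
# and "answer" is the normalized majority-vote ground truth, e.g.
#   {"prompt": [{"role": "user", "content": SP + "\n\n" + question}], "answer": "42"}
# TRL passes extra dataset columns such as "answer" to the reward function as kwargs.
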
# === Reward function ===
def math_reward(prompts, completions, answer, **kwargs):
    """Reward: 1.0 if the boxed answer matches ground truth,
    0.1 if a boxed answer is present but wrong (rewards format compliance),
    0.0 if there is no boxed answer at all."""
    rewards = []
    for completion, gt in zip(completions, answer):
        # Handle conversational format (list of dicts) or plain string
        if isinstance(completion, list):
            text = completion[-1]["content"] if completion else ""
        elif isinstance(completion, dict):
            text = completion.get("content", "")
        else:
            text = str(completion)

        pred = extract_boxed(text)
        if pred is None:
            rewards.append(0.0)   # No boxed answer = 0
        elif normalize(pred) == gt:
            rewards.append(1.0)   # Correct = 1.0
        else:
            rewards.append(0.1)   # Wrong but has boxed format = 0.1
    return rewards

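# Smoke test for the reward schedule (hypothetical completions in TRL's conversational
# format; prompts is unused by math_reward, so placeholders suffice):
_demo = [
    [{"role": "assistant", "content": "Step 1 ... so \\boxed{42}"}],   # correct -> 1.0
    [{"role": "assistant", "content": "Therefore \\boxed{41}"}],       # wrong box -> 0.1
    [{"role": "assistant", "content": "The answer is 42"}],            # no box -> 0.0
]
assert math_reward(prompts=[None] * 3, completions=_demo, answer=["42", "42", "42"]) == [1.0, 0.1, 0.0]
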
# === GRPO Config ===
# Literature-informed settings for a ~1B-parameter model
NUM_GEN = 8
BATCH_SIZE = 8   # per device; must be divisible by num_generations
GRAD_ACCUM = 4   # effective batch = 8 * 4 = 32; / 8 generations = 4 prompts per step

config = GRPOConfig(
    output_dir="outputs/c27_grpo_ckpt",
    report_to="none",
    seed=SEED,

    # Generation
    num_generations=NUM_GEN,
    max_completion_length=1024,
    temperature=0.7,

    # GRPO algorithm
    beta=0.04,               # Higher KL penalty to preserve format/quality
    loss_type="grpo",        # Standard GRPO loss
    epsilon=0.2,             # PPO-style clipping
    scale_rewards="group",   # Normalize rewards within each generation group

    # Training
    num_train_epochs=1,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=5e-6,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    max_grad_norm=0.1,       # Very strict gradient clipping (from the literature)
    optim="paged_adamw_8bit",
    bf16=True,
    gradient_checkpointing=True,

    # Logging & saving
    logging_steps=10,
    save_strategy="no",
    max_steps=500,

    # vLLM for fast generation
    use_vllm=True,
    vllm_mode="colocate",
    vllm_gpu_memory_utilization=0.3,
)

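# Note: with vllm_mode="colocate", the vLLM engine shares each training GPU instead of
# running as a separate inference server, so vllm_gpu_memory_utilization caps its share
# (here 30%) to leave room for model weights, gradients, and optimizer state.
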
print(f"Config: num_gen={NUM_GEN}, batch={BATCH_SIZE}, ga={GRAD_ACCUM}")
print(f"  effective_batch={BATCH_SIZE * GRAD_ACCUM}, prompts_per_step={BATCH_SIZE * GRAD_ACCUM // NUM_GEN}")
print(f"  lr={config.learning_rate}, beta={config.beta}, max_steps={config.max_steps}")
print(f"  vllm={config.use_vllm}, mode={config.vllm_mode}")

# === Load model ===
print("\nLoading model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token

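# flash_attention_2 requires the flash-attn package and an Ampere-or-newer GPU; swap in
# attn_implementation="sdpa" if flash-attn is not installed in your environment.
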
# === Train ===
trainer = GRPOTrainer(
    model=model,
    reward_funcs=math_reward,
    args=config,
    train_dataset=train_ds,
    processing_class=tokenizer,
)

print("\n=== Starting GRPO Training ===")
result = trainer.train()
print(f"\nTraining loss: {result.training_loss:.4f}")

# Save
SAVE_DIR = "outputs/models/c27-grpo"
os.makedirs(SAVE_DIR, exist_ok=True)
trainer.save_model(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
print(f"Saved: {SAVE_DIR}")

# Print training metrics summary (last few log entries that carry a reward key)
logs = trainer.state.log_history
reward_logs = [l for l in logs if any("reward" in k for k in l)]
for l in reward_logs[-5:]:
    print(f"  step={l.get('step','?')}: reward={l.get('reward', l.get('rewards/mean','?'))}, "
          f"completion_length={l.get('completion_length', '?')}")

print("\n=== GRPO Training Complete ===")