dkumar15 committed on
Commit
f099982
·
verified ·
1 Parent(s): 3f47090

Upload training_code/train_dpo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training_code/train_dpo.py +327 -0
training_code/train_dpo.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DPO (Direct Preference Optimization) training for the 1B Transformer.
3
+
4
+ Takes the SFT model and aligns it with human preferences using
5
+ UltraFeedback preference pairs.
6
+
7
+ DPO Loss:
8
+ L = -log sigma(beta * (log[pi(yw|x) / pi_ref(yw|x)] - log[pi(yl|x) / pi_ref(yl|x)]))
9
+
10
+ Launch: torchrun --nproc_per_node=8 train_dpo.py
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import math
16
+ import time
17
+ import json
18
+ import datetime
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ import torch.distributed as dist
23
+ from torch.nn.parallel import DistributedDataParallel as DDP
24
+ from torch.utils.data.distributed import DistributedSampler
25
+
26
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
27
+ from model.config import ModelConfig
28
+ from model.transformer import Transformer
29
+ from model.data import get_tokenizer
30
+ from model.dpo_data import DPODataset, dpo_collate_fn
31
+
32
+
33
# === Config ===
# Filesystem locations. The defaults are the original cluster paths; each one
# can be overridden through the environment so the script runs elsewhere
# without editing the source.
SFT_CHECKPOINT = os.environ.get(
    "DPO_SFT_CHECKPOINT", "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"
)
DPO_CHECKPOINT_DIR = os.environ.get(
    "DPO_CHECKPOINT_DIR", "/jfs/deepak-kumar/checkpoints_dpo"
)
LOG_DIR = os.environ.get("DPO_LOG_DIR", "/home/jovyan/training/logs")
DATA_CACHE = os.environ.get("DPO_DATA_CACHE", "/jfs/deepak-kumar/data")

NUM_EPOCHS = 1
BATCH_SIZE_PER_GPU = 2
# Effective batch = BATCH_SIZE_PER_GPU * world_size * GRADIENT_ACCUMULATION,
# i.e. 2 * 8 * 4 = 64 with the 8-process launch command in the module docstring.
GRADIENT_ACCUMULATION = 4
MAX_SEQ_LEN = 1024
LEARNING_RATE = 5e-7  # very low LR for DPO
MIN_LR = 1e-7  # floor of the cosine schedule
WARMUP_STEPS = 100  # optimizer steps of linear warmup
WEIGHT_DECAY = 0.01  # applied only to parameters with dim >= 2
GRAD_CLIP = 1.0  # max gradient norm
BETA = 0.1  # DPO temperature
LOG_INTERVAL = 10  # optimizer steps between log lines
SAVE_INTERVAL = 200  # optimizer steps between checkpoints
51
+
52
+
53
def get_cosine_lr(step, warmup_steps, total_steps, max_lr, min_lr):
    """
    Learning rate for ``step``: linear warmup followed by cosine decay.

    step: zero-based optimizer step index
    warmup_steps: number of steps over which the LR ramps linearly from 0 to max_lr
    total_steps: step at which the cosine decay reaches min_lr
    max_lr: peak learning rate (reached at the end of warmup)
    min_lr: final learning rate
    Returns: the learning rate as a float
    """
    if step < warmup_steps:
        return max_lr * step / max(warmup_steps, 1)

    decay_steps = max(total_steps - warmup_steps, 1)
    # Clamp to 1.0: the cosine is periodic, so without the clamp any step past
    # total_steps would make the LR climb back up toward max_lr.
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
58
+
59
+
60
+ def get_per_token_logps(model, input_ids, prompt_lens):
61
+ """
62
+ Compute sum of log probabilities for response tokens only.
63
+ input_ids: [B, S] full sequence (prompt + response)
64
+ prompt_lens: [B] where response starts
65
+ Returns: [B] sum of log probs over response tokens
66
+ """
67
+ # Clone input to avoid inplace issues with shared RoPE buffers
68
+ inp = input_ids[:, :-1].contiguous()
69
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
70
+ logits, _ = model(inp)
71
+
72
+ labels = input_ids[:, 1:].contiguous()
73
+ log_probs = F.log_softmax(logits.float(), dim=-1)
74
+ token_logps = log_probs.gather(2, labels.unsqueeze(2)).squeeze(2)
75
+
76
+ B, S = token_logps.shape
77
+ mask = torch.zeros_like(token_logps)
78
+ for b in range(B):
79
+ pl = prompt_lens[b].item()
80
+ response_start = max(0, pl - 1)
81
+ seq_len = (labels[b] != 0).sum().item()
82
+ mask[b, response_start:seq_len] = 1.0
83
+
84
+ return (token_logps * mask).sum(dim=1)
85
+
86
+
87
def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """
    Compute the DPO loss together with two monitoring metrics.

    Each ``*_logps`` argument is a tensor of summed response log-probabilities,
    one entry per preference pair. The implicit reward of a response is
    ``beta * (policy_logp - ref_logp)``.

    Returns: ``(loss, accuracy, margin)`` where ``loss`` is a scalar tensor
    (mean of ``-logsigmoid(reward_chosen - reward_rejected)``), ``accuracy``
    is the fraction of pairs whose chosen reward is strictly greater than the
    rejected reward (Python float), and ``margin`` is the mean reward
    difference (Python float).
    """
    reward_w = beta * (policy_chosen_logps - ref_chosen_logps)
    reward_l = beta * (policy_rejected_logps - ref_rejected_logps)

    preference_logits = reward_w - reward_l
    loss = -F.logsigmoid(preference_logits).mean()

    # Metrics are for logging only; keep them out of the autograd graph.
    with torch.no_grad():
        win_rate = torch.gt(reward_w, reward_l).float().mean().item()
        margin = (reward_w - reward_l).mean().item()

    return loss, win_rate, margin
101
+
102
+
103
def _save_checkpoint(policy, model_config, step, path):
    """Write the unwrapped (non-DDP) policy weights and config metadata to ``path``."""
    torch.save({
        "step": step,
        "model": policy.module.state_dict(),
        "config": model_config.__dict__,
        "vocab_size": model_config.vocab_size,
    }, path)


def main():
    """
    Run DPO training of the SFT checkpoint on every rank started by torchrun.

    Steps: initialise the NCCL process group, load the SFT weights into a
    trainable policy (wrapped in DDP) and a frozen bfloat16 reference model,
    build the preference dataloader, then train with gradient accumulation,
    cosine LR scheduling, gradient clipping, periodic logging and
    checkpointing. Rank 0 writes ``dpo_log.jsonl`` and all checkpoints.

    Logged loss / accuracy / margin are the values computed on rank 0 only;
    they are not reduced across ranks.
    """
    # Function-scope imports keep the new stdlib dependencies local to this block.
    from contextlib import nullcontext
    from functools import partial
    from itertools import islice

    dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    if rank == 0:
        os.makedirs(DPO_CHECKPOINT_DIR, exist_ok=True)
        os.makedirs(LOG_DIR, exist_ok=True)
        print("=" * 70)
        print(" DPO: PREFERENCE ALIGNMENT FOR 1B TRANSFORMER")
        print("=" * 70)

    tokenizer = get_tokenizer()
    special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
    vocab = tokenizer.get_vocab()
    new_tokens = [t for t in special_tokens if t not in vocab]
    if new_tokens:
        tokenizer.add_tokens(new_tokens, special_tokens=True)

    model_config = ModelConfig()
    model_config.vocab_size = len(tokenizer)

    if rank == 0:
        print(f"[Init] Loading SFT model from {SFT_CHECKPOINT}")

    # Policy model (trainable)
    policy = Transformer(model_config)
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects. Only load
    # checkpoints from a trusted source; switch to weights_only=True if the
    # checkpoint is known to contain nothing but tensors and plain containers.
    ckpt = torch.load(SFT_CHECKPOINT, map_location="cpu", weights_only=False)
    policy.load_state_dict(ckpt["model"])
    sft_step = ckpt.get("step", 0)
    if rank == 0:
        print(f"[Init] SFT model loaded (step {sft_step})")

    # Reference model (frozen copy)
    ref_model = Transformer(model_config)
    ref_model.load_state_dict(ckpt["model"])
    del ckpt  # release the CPU copy of the weights before moving models to GPU

    policy = policy.to(device)
    ref_model = ref_model.to(device).bfloat16()
    ref_model.eval()
    for p in ref_model.parameters():
        p.requires_grad = False

    policy = DDP(policy, device_ids=[local_rank])

    if rank == 0:
        n_params = sum(p.numel() for p in policy.parameters())
        print(f"[Init] Params: {n_params:,} | GPUs: {world_size}x H100")
        print(f"[Init] Beta: {BETA} | LR: {LEARNING_RATE}")

    # Dataset
    dataset = DPODataset(
        tokenizer=tokenizer,
        max_seq_len=MAX_SEQ_LEN,
        split="train",
        cache_dir=DATA_CACHE,
    )

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    # NOTE(review): get_per_token_logps is called below without a pad id and
    # treats token id 0 as padding when building the response mask -- confirm
    # tokenizer.pad_token_id is 0.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE_PER_GPU,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        # partial (unlike a lambda) is picklable, so worker processes also
        # work under the "spawn" start method.
        collate_fn=partial(dpo_collate_fn, pad_id=tokenizer.pad_token_id),
    )

    # The loop below also runs a final, partially filled accumulation step, so
    # round up: otherwise the LR schedule, progress and ETA are computed
    # against a step count one smaller than what is actually executed.
    steps_per_epoch = math.ceil(len(dataloader) / GRADIENT_ACCUMULATION)
    total_steps = steps_per_epoch * NUM_EPOCHS

    if rank == 0:
        eff_batch = BATCH_SIZE_PER_GPU * world_size * GRADIENT_ACCUMULATION
        print(f"[Init] Dataset: {len(dataset):,} preference pairs")
        print(f"[Init] Effective batch: {eff_batch} | Steps/epoch: {steps_per_epoch}")
        print(f"[Init] Total steps: {total_steps}")
        print("-" * 70)

    # Weight decay only on tensors with >= 2 dims (matrices/embeddings);
    # 1-D parameters such as biases and norm gains are left undecayed.
    trainable = [p for p in policy.parameters() if p.requires_grad]
    decay_params = [p for p in trainable if p.dim() >= 2]
    nodecay_params = [p for p in trainable if p.dim() < 2]
    optimizer = torch.optim.AdamW([
        {"params": decay_params, "weight_decay": WEIGHT_DECAY},
        {"params": nodecay_params, "weight_decay": 0.0},
    ], lr=LEARNING_RATE, betas=(0.9, 0.95), fused=True)

    policy.train()
    global_step = 0
    running_loss = 0.0
    running_acc = 0.0
    running_margin = 0.0
    t0 = time.time()

    # Only rank 0 logs; the context manager guarantees the file is closed even
    # if training raises. nullcontext() yields None on the other ranks.
    log_path = os.path.join(LOG_DIR, "dpo_log.jsonl")
    with (open(log_path, "w") if rank == 0 else nullcontext()) as log_file:
        for epoch in range(NUM_EPOCHS):
            sampler.set_epoch(epoch)
            data_iter = iter(dataloader)

            if rank == 0:
                print(f"\n[Epoch {epoch + 1}/{NUM_EPOCHS}]")

            while True:
                # Pull the micro-batches for this optimizer step up front so we
                # know which one is last (needed for no_sync and normalisation).
                micro_batches = list(islice(data_iter, GRADIENT_ACCUMULATION))
                if not micro_batches:
                    break
                num_micros = len(micro_batches)

                optimizer.zero_grad(set_to_none=True)
                batch_loss = 0.0
                batch_acc = 0.0
                batch_margin = 0.0

                for idx, batch in enumerate(micro_batches):
                    chosen_ids = batch["chosen_ids"].to(device, non_blocking=True)
                    rejected_ids = batch["rejected_ids"].to(device, non_blocking=True)
                    prompt_lens = batch["prompt_lens"].to(device, non_blocking=True)

                    # Gradients only need to be all-reduced once per optimizer
                    # step; skip the DDP sync on all but the last micro-batch.
                    is_last = idx == num_micros - 1
                    sync_ctx = nullcontext() if is_last else policy.no_sync()
                    with sync_ctx:
                        policy_chosen_logps = get_per_token_logps(policy, chosen_ids, prompt_lens)
                        policy_rejected_logps = get_per_token_logps(policy, rejected_ids, prompt_lens)

                        with torch.no_grad():
                            ref_chosen_logps = get_per_token_logps(ref_model, chosen_ids, prompt_lens)
                            ref_rejected_logps = get_per_token_logps(ref_model, rejected_ids, prompt_lens)

                        loss, acc, margin = dpo_loss(
                            policy_chosen_logps, policy_rejected_logps,
                            ref_chosen_logps, ref_rejected_logps,
                            beta=BETA,
                        )
                        # Normalise by the micro-batches actually seen so a
                        # short final step is still a proper mean.
                        loss = loss / num_micros
                        loss.backward()

                    batch_loss += loss.item()
                    batch_acc += acc
                    batch_margin += margin

                torch.nn.utils.clip_grad_norm_(policy.parameters(), GRAD_CLIP)

                lr = get_cosine_lr(global_step, WARMUP_STEPS, total_steps, LEARNING_RATE, MIN_LR)
                for pg in optimizer.param_groups:
                    pg["lr"] = lr

                optimizer.step()
                global_step += 1
                running_loss += batch_loss
                running_acc += batch_acc / num_micros
                running_margin += batch_margin / num_micros

                if global_step % LOG_INTERVAL == 0:
                    if rank == 0:
                        avg_loss = running_loss / LOG_INTERVAL
                        avg_acc = running_acc / LOG_INTERVAL
                        avg_margin = running_margin / LOG_INTERVAL
                        elapsed = time.time() - t0
                        pct = 100.0 * global_step / max(total_steps, 1)
                        eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
                        gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                        print(
                            f" [Step {global_step:>5d}/{total_steps}] "
                            f"loss={avg_loss:.4f} | acc={avg_acc:.1%} | "
                            f"margin={avg_margin:.3f} | lr={lr:.2e} | "
                            f"GPU={gpu_mem:.1f}GB | {pct:.1f}% | ETA={eta/60:.0f}m",
                            flush=True,
                        )
                        if log_file:
                            log_file.write(json.dumps({
                                "step": global_step, "loss": round(avg_loss, 4),
                                "accuracy": round(avg_acc, 4),
                                "reward_margin": round(avg_margin, 4),
                                "lr": lr, "elapsed_s": round(elapsed, 1),
                            }) + "\n")
                            log_file.flush()

                    running_loss = 0.0
                    running_acc = 0.0
                    running_margin = 0.0

                if global_step % SAVE_INTERVAL == 0:
                    # Barriers keep the other ranks from racing ahead while
                    # rank 0 is writing the checkpoint.
                    dist.barrier()
                    if rank == 0:
                        path = os.path.join(DPO_CHECKPOINT_DIR, f"dpo_step_{global_step}.pt")
                        _save_checkpoint(policy, model_config, global_step, path)
                        print(f" >> Checkpoint: {path}", flush=True)
                    dist.barrier()

    # Final save
    dist.barrier()
    if rank == 0:
        final_path = os.path.join(DPO_CHECKPOINT_DIR, "dpo_final.pt")
        _save_checkpoint(policy, model_config, global_step, final_path)
        total_time = time.time() - t0
        print("=" * 70)
        print(" DPO COMPLETE")
        print(f" Steps: {global_step:,} | Epochs: {NUM_EPOCHS}")
        print(f" Time: {total_time/60:.1f} minutes")
        print(f" Final model: {final_path}")
        print("=" * 70)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()