""" DPO (Direct Preference Optimization) training for the 1B Transformer. Takes the SFT model and aligns it with human preferences using UltraFeedback preference pairs. DPO Loss: L = -log sigma(beta * (log pi(yw|x)/pi_ref(yw|x) - log pi(yl|x)/pi_ref(yl|x))) Launch: torchrun --nproc_per_node=8 train_dpo.py """ import os import sys import math import time import json import datetime import torch import torch.nn.functional as F import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data.distributed import DistributedSampler sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from model.config import ModelConfig from model.transformer import Transformer from model.data import get_tokenizer from model.dpo_data import DPODataset, dpo_collate_fn # === Config === SFT_CHECKPOINT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt" DPO_CHECKPOINT_DIR = "/jfs/deepak-kumar/checkpoints_dpo" LOG_DIR = "/home/jovyan/training/logs" DATA_CACHE = "/jfs/deepak-kumar/data" NUM_EPOCHS = 1 BATCH_SIZE_PER_GPU = 2 GRADIENT_ACCUMULATION = 4 # effective batch = 2 * 8 * 4 = 64 MAX_SEQ_LEN = 1024 LEARNING_RATE = 5e-7 # very low LR for DPO MIN_LR = 1e-7 WARMUP_STEPS = 100 WEIGHT_DECAY = 0.01 GRAD_CLIP = 1.0 BETA = 0.1 # DPO temperature LOG_INTERVAL = 10 SAVE_INTERVAL = 200 def get_cosine_lr(step, warmup_steps, total_steps, max_lr, min_lr): if step < warmup_steps: return max_lr * step / max(warmup_steps, 1) progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1) return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress)) def get_per_token_logps(model, input_ids, prompt_lens): """ Compute sum of log probabilities for response tokens only. input_ids: [B, S] full sequence (prompt + response) prompt_lens: [B] where response starts Returns: [B] sum of log probs over response tokens """ # Clone input to avoid inplace issues with shared RoPE buffers inp = input_ids[:, :-1].contiguous() with torch.autocast(device_type="cuda", dtype=torch.bfloat16): logits, _ = model(inp) labels = input_ids[:, 1:].contiguous() log_probs = F.log_softmax(logits.float(), dim=-1) token_logps = log_probs.gather(2, labels.unsqueeze(2)).squeeze(2) B, S = token_logps.shape mask = torch.zeros_like(token_logps) for b in range(B): pl = prompt_lens[b].item() response_start = max(0, pl - 1) seq_len = (labels[b] != 0).sum().item() mask[b, response_start:seq_len] = 1.0 return (token_logps * mask).sum(dim=1) def dpo_loss(policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1): """Compute DPO loss and metrics.""" chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps) rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps) logits = chosen_rewards - rejected_rewards loss = -F.logsigmoid(logits).mean() with torch.no_grad(): chosen_better = (chosen_rewards > rejected_rewards).float().mean() reward_margin = (chosen_rewards - rejected_rewards).mean() return loss, chosen_better.item(), reward_margin.item() def main(): dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30)) rank = int(os.environ.get("RANK", 0)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) world_size = int(os.environ.get("WORLD_SIZE", 1)) torch.cuda.set_device(local_rank) device = torch.device(f"cuda:{local_rank}") if rank == 0: os.makedirs(DPO_CHECKPOINT_DIR, exist_ok=True) os.makedirs(LOG_DIR, exist_ok=True) print("=" * 70) print(" DPO: PREFERENCE ALIGNMENT FOR 1B TRANSFORMER") print("=" * 70) tokenizer = get_tokenizer() special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"] vocab = tokenizer.get_vocab() new_tokens = [t for t in special_tokens if t not in vocab] if new_tokens: tokenizer.add_tokens(new_tokens, special_tokens=True) model_config = ModelConfig() model_config.vocab_size = len(tokenizer) if rank == 0: print(f"[Init] Loading SFT model from {SFT_CHECKPOINT}") # Policy model (trainable) policy = Transformer(model_config) ckpt = torch.load(SFT_CHECKPOINT, map_location="cpu", weights_only=False) policy.load_state_dict(ckpt["model"]) sft_step = ckpt.get("step", 0) if rank == 0: print(f"[Init] SFT model loaded (step {sft_step})") # Reference model (frozen copy) ref_model = Transformer(model_config) ref_model.load_state_dict(ckpt["model"]) del ckpt policy = policy.to(device) ref_model = ref_model.to(device).bfloat16() ref_model.eval() for p in ref_model.parameters(): p.requires_grad = False policy = DDP(policy, device_ids=[local_rank]) if rank == 0: n = sum(p.numel() for p in policy.parameters()) print(f"[Init] Params: {n:,} | GPUs: {world_size}x H100") print(f"[Init] Beta: {BETA} | LR: {LEARNING_RATE}") # Dataset dataset = DPODataset( tokenizer=tokenizer, max_seq_len=MAX_SEQ_LEN, split="train", cache_dir=DATA_CACHE, ) sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True) dataloader = torch.utils.data.DataLoader( dataset, batch_size=BATCH_SIZE_PER_GPU, sampler=sampler, num_workers=4, pin_memory=True, collate_fn=lambda b: dpo_collate_fn(b, pad_id=tokenizer.pad_token_id), ) steps_per_epoch = len(dataloader) // GRADIENT_ACCUMULATION total_steps = steps_per_epoch * NUM_EPOCHS if rank == 0: eff_batch = BATCH_SIZE_PER_GPU * world_size * GRADIENT_ACCUMULATION print(f"[Init] Dataset: {len(dataset):,} preference pairs") print(f"[Init] Effective batch: {eff_batch} | Steps/epoch: {steps_per_epoch}") print(f"[Init] Total steps: {total_steps}") print("-" * 70) decay_params = [p for n, p in policy.named_parameters() if p.dim() >= 2 and p.requires_grad] nodecay_params = [p for n, p in policy.named_parameters() if p.dim() < 2 and p.requires_grad] optimizer = torch.optim.AdamW([ {"params": decay_params, "weight_decay": WEIGHT_DECAY}, {"params": nodecay_params, "weight_decay": 0.0}, ], lr=LEARNING_RATE, betas=(0.9, 0.95), fused=True) policy.train() global_step = 0 running_loss = 0.0 running_acc = 0.0 running_margin = 0.0 t0 = time.time() log_file = open(os.path.join(LOG_DIR, "dpo_log.jsonl"), "w") if rank == 0 else None for epoch in range(NUM_EPOCHS): sampler.set_epoch(epoch) data_iter = iter(dataloader) if rank == 0: print(f"\n[Epoch {epoch + 1}/{NUM_EPOCHS}]") while True: optimizer.zero_grad(set_to_none=True) batch_loss = 0.0 batch_acc = 0.0 batch_margin = 0.0 valid_micros = 0 for _ in range(GRADIENT_ACCUMULATION): try: batch = next(data_iter) except StopIteration: break chosen_ids = batch["chosen_ids"].to(device, non_blocking=True) rejected_ids = batch["rejected_ids"].to(device, non_blocking=True) prompt_lens = batch["prompt_lens"].to(device, non_blocking=True) policy_chosen_logps = get_per_token_logps(policy, chosen_ids, prompt_lens) policy_rejected_logps = get_per_token_logps(policy, rejected_ids, prompt_lens) with torch.no_grad(): ref_chosen_logps = get_per_token_logps(ref_model, chosen_ids, prompt_lens) ref_rejected_logps = get_per_token_logps(ref_model, rejected_ids, prompt_lens) loss, acc, margin = dpo_loss( policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=BETA, ) loss = loss / GRADIENT_ACCUMULATION loss.backward() batch_loss += loss.item() batch_acc += acc batch_margin += margin valid_micros += 1 if valid_micros == 0: break torch.nn.utils.clip_grad_norm_(policy.parameters(), GRAD_CLIP) lr = get_cosine_lr(global_step, WARMUP_STEPS, total_steps, LEARNING_RATE, MIN_LR) for pg in optimizer.param_groups: pg["lr"] = lr optimizer.step() global_step += 1 running_loss += batch_loss running_acc += batch_acc / valid_micros running_margin += batch_margin / valid_micros if global_step % LOG_INTERVAL == 0: avg_loss = running_loss / LOG_INTERVAL avg_acc = running_acc / LOG_INTERVAL avg_margin = running_margin / LOG_INTERVAL elapsed = time.time() - t0 pct = 100.0 * global_step / total_steps eta = (elapsed / max(global_step, 1)) * (total_steps - global_step) if rank == 0: gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9 print( f" [Step {global_step:>5d}/{total_steps}] " f"loss={avg_loss:.4f} | acc={avg_acc:.1%} | " f"margin={avg_margin:.3f} | lr={lr:.2e} | " f"GPU={gpu_mem:.1f}GB | {pct:.1f}% | ETA={eta/60:.0f}m", flush=True, ) if log_file: log_file.write(json.dumps({ "step": global_step, "loss": round(avg_loss, 4), "accuracy": round(avg_acc, 4), "reward_margin": round(avg_margin, 4), "lr": lr, "elapsed_s": round(elapsed, 1), }) + "\n") log_file.flush() running_loss = 0.0 running_acc = 0.0 running_margin = 0.0 if global_step % SAVE_INTERVAL == 0: dist.barrier() if rank == 0: path = os.path.join(DPO_CHECKPOINT_DIR, f"dpo_step_{global_step}.pt") torch.save({ "step": global_step, "model": policy.module.state_dict(), "config": model_config.__dict__, "vocab_size": model_config.vocab_size, }, path) print(f" >> Checkpoint: {path}", flush=True) dist.barrier() # Final save dist.barrier() if rank == 0: final_path = os.path.join(DPO_CHECKPOINT_DIR, "dpo_final.pt") torch.save({ "step": global_step, "model": policy.module.state_dict(), "config": model_config.__dict__, "vocab_size": model_config.vocab_size, }, final_path) total_time = time.time() - t0 print("=" * 70) print(f" DPO COMPLETE") print(f" Steps: {global_step:,} | Epochs: {NUM_EPOCHS}") print(f" Time: {total_time/60:.1f} minutes") print(f" Final model: {final_path}") print("=" * 70) if log_file: log_file.close() dist.destroy_process_group() if __name__ == "__main__": main()