dkumar15 committed
Commit 3f47090 · verified · 1 Parent(s): d42a1f3

Upload training_code/train_sft.py with huggingface_hub

Files changed (1)
  1. training_code/train_sft.py +272 -0
training_code/train_sft.py ADDED
@@ -0,0 +1,272 @@
"""
SFT (Supervised Fine-Tuning) script for the 1B Transformer.

Takes the pretrained base model and fine-tunes it on instruction-response
conversations from UltraChat 200K.

Launch: torchrun --nproc_per_node=8 train_sft.py
"""

import os
import sys
import math
import time
import json
import datetime

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer
from model.sft_data import SFTDataset, sft_collate_fn


# === Config ===
BASE_CHECKPOINT = "/jfs/deepak-kumar/checkpoints/step_19000.pt"
SFT_CHECKPOINT_DIR = "/jfs/deepak-kumar/checkpoints_sft"
LOG_DIR = "/home/jovyan/training/logs"
DATA_CACHE = "/jfs/deepak-kumar/data"

NUM_EPOCHS = 2
BATCH_SIZE_PER_GPU = 4
GRADIENT_ACCUMULATION = 4  # effective batch = 4 * 8 * 4 = 128
MAX_SEQ_LEN = 2048
LEARNING_RATE = 2e-5  # much lower than pretraining — we're fine-tuning
MIN_LR = 2e-6
WARMUP_STEPS = 200
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
LOG_INTERVAL = 10
SAVE_INTERVAL = 500

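# LR schedule: linear warmup for WARMUP_STEPS optimizer steps, then cosine
# decay from LEARNING_RATE down to MIN_LR over the remaining steps.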
def get_cosine_lr(step, warmup_steps, total_steps, max_lr, min_lr):
    if step < warmup_steps:
        return max_lr * step / max(warmup_steps, 1)
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


def main():
    dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    if rank == 0:
        os.makedirs(SFT_CHECKPOINT_DIR, exist_ok=True)
        os.makedirs(LOG_DIR, exist_ok=True)
        print("=" * 70)
        print(" SFT: INSTRUCTION FINE-TUNING 1B TRANSFORMER")
        print("=" * 70)

    # Tokenizer
    tokenizer = get_tokenizer()

    # Load base model
    model_config = ModelConfig()
    torch.manual_seed(42)
    model = Transformer(model_config)

    if rank == 0:
        print(f"[Init] Loading base model from {BASE_CHECKPOINT}")
    ckpt = torch.load(BASE_CHECKPOINT, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"])
    base_step = ckpt.get("step", 0)
    base_loss = ckpt.get("loss", "?")
    if rank == 0:
        print(f"[Init] Base model: step={base_step}, pretrain_loss={base_loss}")
    del ckpt

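    # The SFT data presumably wraps each turn in the role markers below; if the
    # pretraining tokenizer does not already contain them, both the input
    # embedding table and the output projection must be resized to match.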
    # Add chat tokens to embedding — expand vocab if needed
    special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
    vocab = tokenizer.get_vocab()
    new_tokens = [t for t in special_tokens if t not in vocab]
    if new_tokens:
        tokenizer.add_tokens(new_tokens, special_tokens=True)

    new_vocab_size = len(tokenizer)
    if new_vocab_size > model_config.vocab_size:
        if rank == 0:
            print(f"[Init] Expanding vocab: {model_config.vocab_size} -> {new_vocab_size}")

        old_emb_weight = model.tok_embeddings.weight.data
        model.tok_embeddings = torch.nn.Embedding(new_vocab_size, model_config.hidden_dim)
        model.tok_embeddings.weight.data[:model_config.vocab_size] = old_emb_weight
        # Init new token embeddings as mean of existing (better than random)
        mean_emb = old_emb_weight.mean(dim=0)
        for i in range(model_config.vocab_size, new_vocab_size):
            model.tok_embeddings.weight.data[i] = mean_emb

        old_output_weight = model.output.weight.data
        model.output = torch.nn.Linear(model_config.hidden_dim, new_vocab_size, bias=False)
        model.output.weight.data[:model_config.vocab_size] = old_output_weight

        model.config.vocab_size = new_vocab_size

    model = model.to(device)
    model = DDP(model, device_ids=[local_rank])

    if rank == 0:
        n = sum(p.numel() for p in model.parameters())
        print(f"[Init] Params: {n:,} | GPUs: {world_size}x H100")

    # Dataset (each rank loads its own copy; the sampler shards the indices)
    dataset = SFTDataset(
        tokenizer=tokenizer,
        max_seq_len=MAX_SEQ_LEN,
        split="train_sft",
        cache_dir=DATA_CACHE,
    )

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE_PER_GPU,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        collate_fn=lambda b: sft_collate_fn(b, pad_id=tokenizer.pad_token_id),
    )

    steps_per_epoch = len(dataloader) // GRADIENT_ACCUMULATION
    total_steps = steps_per_epoch * NUM_EPOCHS

    if rank == 0:
        eff_batch = BATCH_SIZE_PER_GPU * world_size * GRADIENT_ACCUMULATION
        print(f"[Init] Dataset: {len(dataset):,} examples")
        print(f"[Init] Effective batch: {eff_batch} | Steps/epoch: {steps_per_epoch}")
        print(f"[Init] Total steps: {total_steps} | Epochs: {NUM_EPOCHS}")
        print(f"[Init] LR: {LEARNING_RATE} → {MIN_LR} (cosine)")
        print("-" * 70)

    # Optimizer — lower LR for fine-tuning
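    # Weight decay is applied only to matrix-shaped parameters (embeddings,
    # linear weights); 1-D tensors such as norm gains and biases are undecayed.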
    decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2 and p.requires_grad]
    nodecay_params = [p for n, p in model.named_parameters() if p.dim() < 2 and p.requires_grad]
    optimizer = torch.optim.AdamW([
        {"params": decay_params, "weight_decay": WEIGHT_DECAY},
        {"params": nodecay_params, "weight_decay": 0.0},
    ], lr=LEARNING_RATE, betas=(0.9, 0.95), fused=True)

    # Training
    model.train()
    global_step = 0
    running_loss = 0.0
    t0 = time.time()
    step_t0 = time.time()

    log_file = open(os.path.join(LOG_DIR, "sft_log.jsonl"), "w") if rank == 0 else None

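    # Each optimizer step accumulates GRADIENT_ACCUMULATION micro-batches; the
    # forward pass runs under bf16 autocast and the loss is pre-divided so the
    # accumulated gradient averages over the effective batch.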
    for epoch in range(NUM_EPOCHS):
        sampler.set_epoch(epoch)
        data_iter = iter(dataloader)
        micro_step = 0

        if rank == 0:
            print(f"\n[Epoch {epoch + 1}/{NUM_EPOCHS}]")

        while True:
            optimizer.zero_grad(set_to_none=True)
            batch_loss = 0.0

            for _ in range(GRADIENT_ACCUMULATION):
                try:
                    input_ids, labels = next(data_iter)
                except StopIteration:
                    break

                input_ids = input_ids.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    _, loss = model(input_ids, labels)
                    loss = loss / GRADIENT_ACCUMULATION

                loss.backward()
                batch_loss += loss.item()
                micro_step += 1

            if batch_loss == 0:
                break

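            # Gradients from all micro-batches have accumulated in .grad (DDP
            # all-reduces them on every backward), so clip and step once per
            # effective batch.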
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)

            lr = get_cosine_lr(global_step, WARMUP_STEPS, total_steps, LEARNING_RATE, MIN_LR)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            optimizer.step()
            global_step += 1
            running_loss += batch_loss

            if global_step % LOG_INTERVAL == 0:
                dt = time.time() - step_t0
                avg = running_loss / LOG_INTERVAL
                elapsed = time.time() - t0
                pct = 100.0 * global_step / total_steps

                if rank == 0:
                    gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                    eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
                    print(
                        f" [Step {global_step:>5d}/{total_steps}] "
                        f"loss={avg:.4f} | lr={lr:.2e} | "
                        f"GPU={gpu_mem:.1f}GB | {pct:.1f}% | ETA={eta/60:.0f}m",
                        flush=True,
                    )
                    if log_file:
                        log_file.write(json.dumps({
                            "step": global_step, "epoch": epoch + 1,
                            "loss": round(avg, 4), "lr": lr,
                            "elapsed_s": round(elapsed, 1),
                        }) + "\n")
                        log_file.flush()

                running_loss = 0.0
                step_t0 = time.time()

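            # Periodic checkpoint: rank 0 saves the unwrapped weights
            # (model.module, without the DDP wrapper) so the file can be loaded
            # into a bare Transformer; barriers keep ranks in sync around the save.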
            if global_step % SAVE_INTERVAL == 0:
                dist.barrier()
                if rank == 0:
                    path = os.path.join(SFT_CHECKPOINT_DIR, f"sft_step_{global_step}.pt")
                    torch.save({
                        "step": global_step,
                        "model": model.module.state_dict(),
                        "config": model_config.__dict__,
                        "vocab_size": new_vocab_size,
                    }, path)
                    print(f" >> Checkpoint: {path}", flush=True)
                dist.barrier()

    # Final save
    dist.barrier()
    if rank == 0:
        final_path = os.path.join(SFT_CHECKPOINT_DIR, "sft_final.pt")
        torch.save({
            "step": global_step,
            "model": model.module.state_dict(),
            "config": model_config.__dict__,
            "vocab_size": new_vocab_size,
        }, final_path)
        total_time = time.time() - t0
        print("=" * 70)
        print(" SFT COMPLETE")
        print(f" Steps: {global_step:,} | Epochs: {NUM_EPOCHS}")
        print(f" Time: {total_time/60:.1f} minutes")
        print(f" Final model: {final_path}")
        print("=" * 70)
        if log_file:
            log_file.close()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()