dkumar15 commited on
Commit
d42a1f3
·
verified ·
1 Parent(s): 72372ef

Upload training_code/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training_code/train.py +257 -0
training_code/train.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Distributed training script for 1B parameter Transformer.
3
+
4
+ Launch: torchrun --nproc_per_node=8 train.py
5
+
6
+ Stack: PyTorch DDP + BF16 autocast + 8x H100 80GB
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import math
12
+ import time
13
+ import json
14
+ import datetime
15
+
16
+ import torch
17
+ import torch.distributed as dist
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+
20
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
21
+ from model.config import ModelConfig, TrainConfig
22
+ from model.transformer import Transformer
23
+ from model.data import get_tokenizer, create_dataloader
24
+
25
+
26
def get_wsd_lr(step, warmup_steps, total_steps, max_lr, min_lr):
    """Warmup-Stable-Decay schedule: linear warmup -> constant plateau -> cosine decay.

    The decay phase covers the final 20% of training (stable phase ends at
    0.8 * total_steps). Returns the learning rate for the given step.
    """
    decay_start = int(total_steps * 0.8)
    if step < warmup_steps:
        # Linear ramp from 0 up to max_lr; max(..., 1) guards warmup_steps == 0.
        return max_lr * step / max(warmup_steps, 1)
    if step >= decay_start:
        # Cosine anneal from max_lr down to min_lr over the last 20% of steps.
        frac = (step - decay_start) / max(total_steps - decay_start, 1)
        cosine = 0.5 * (1 + math.cos(math.pi * frac))
        return min_lr + (max_lr - min_lr) * cosine
    # Stable plateau at peak learning rate.
    return max_lr
36
+
37
+
38
def find_latest_checkpoint(checkpoint_dir):
    """Return (path, step) of the newest step_*.pt checkpoint, or (None, 0) if none exist."""
    import glob

    def _step_of(path):
        # "step_12345.pt" -> 12345
        base = os.path.basename(path)
        return int(base.replace("step_", "").replace(".pt", ""))

    candidates = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
    if not candidates:
        return None, 0
    newest = max(candidates, key=_step_of)
    return newest, _step_of(newest)
48
+
49
+
50
def main():
    """Entry point: initialize DDP, build model/optimizer, resume from the latest
    checkpoint if one exists, and run the token-budgeted training loop.

    Expected to be launched via torchrun (see module docstring), which sets
    RANK / LOCAL_RANK / WORLD_SIZE in the environment.
    """
    # NCCL backend for GPU collectives; generous timeout covers slow first-step
    # compilation / dataloader warmup across ranks.
    dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    model_config = ModelConfig()
    train_config = TrainConfig()

    # Effective batch = per-GPU batch x ranks x grad-accumulation micro-steps.
    # total_steps is derived from the token budget, not set directly.
    eff_batch = train_config.batch_size_per_gpu * world_size * train_config.gradient_accumulation_steps
    tokens_per_step = eff_batch * model_config.max_seq_len
    total_steps = train_config.total_tokens // tokens_per_step

    # Rank 0 owns all filesystem setup and console logging.
    if rank == 0:
        os.makedirs(train_config.log_dir, exist_ok=True)
        os.makedirs(train_config.checkpoint_dir, exist_ok=True)
        print("=" * 70)
        print(f" TRAINING 1B TRANSFORMER FROM SCRATCH")
        print(f" Arch: {model_config.num_layers}L / {model_config.hidden_dim}D / "
              f"{model_config.num_attention_heads}H / GQA-{model_config.num_kv_heads}KV / "
              f"SwiGLU-{model_config.intermediate_dim}")
        print(f" Seq: {model_config.max_seq_len} | Vocab: {model_config.vocab_size}")
        print(f" GPUs: {world_size}x H100 80GB | Backend: DDP + BF16 autocast")
        print(f" Batch: {eff_batch} seqs = {tokens_per_step:,} tok/step")
        print(f" Steps: {total_steps:,} | Target: {train_config.total_tokens:,} tokens")
        print("=" * 70)

    # Tokenizer
    tokenizer = get_tokenizer()

    # Model — seed before construction so all ranks initialize identical weights
    # (DDP also broadcasts rank 0's params at wrap time, so this is belt-and-braces).
    torch.manual_seed(train_config.seed)
    model = Transformer(model_config).to(device)

    if rank == 0:
        n = sum(p.numel() for p in model.parameters())
        print(f"[Init] Params: {n:,} ({n/1e9:.3f}B)")

    model = DDP(model, device_ids=[local_rank])

    # Optimizer — weight decay only on matrix-shaped params (dim >= 2);
    # biases, norms, and other vectors/scalars are excluded.
    decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2 and p.requires_grad]
    nodecay_params = [p for n, p in model.named_parameters() if p.dim() < 2 and p.requires_grad]
    optimizer = torch.optim.AdamW([
        {"params": decay_params, "weight_decay": train_config.weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ], lr=train_config.learning_rate, betas=(train_config.beta1, train_config.beta2), fused=True)

    if rank == 0:
        dp = sum(p.numel() for p in decay_params)
        ndp = sum(p.numel() for p in nodecay_params)
        print(f"[Init] Optimizer: {dp:,} decay + {ndp:,} no-decay params")

    # Resume from checkpoint — every rank loads the same file independently.
    resume_step = 0
    ckpt_path, ckpt_step = find_latest_checkpoint(train_config.checkpoint_dir)
    if ckpt_path is not None:
        if rank == 0:
            print(f"[Resume] Loading checkpoint: {ckpt_path} (step {ckpt_step})")
        # NOTE(review): weights_only=False uses full pickle deserialization —
        # acceptable for our own checkpoints, unsafe for untrusted files.
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        # .module: unwrap DDP before loading the raw model state dict.
        model.module.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        resume_step = ckpt["step"]
        if rank == 0:
            print(f"[Resume] Restored model + optimizer at step {resume_step}, "
                  f"loss was {ckpt.get('loss', 'N/A')}")
        # Free the CPU/GPU copy of the checkpoint before allocating activations.
        del ckpt
        torch.cuda.empty_cache()
    else:
        if rank == 0:
            print("[Init] No checkpoint found, starting from scratch")

    # Data — use (seed + resume_step) so resumed runs see different shuffled data
    # (cheaper than fast-forwarding a streaming dataset to the exact position).
    effective_seed = train_config.seed + resume_step
    dataloader = create_dataloader(tokenizer, train_config, rank=rank, world_size=world_size,
                                   seed_override=effective_seed)
    data_iter = iter(dataloader)

    if rank == 0:
        print(f"[Init] Dataloader ready (streaming FineWeb-Edu 10BT)")
        print(f"[Schedule] WSD: warmup {train_config.warmup_steps} -> "
              f"stable {int(total_steps*0.8)} -> decay {total_steps}")
        if resume_step > 0:
            remaining = total_steps - resume_step
            print(f"[Resume] Continuing from step {resume_step}, {remaining:,} steps remaining")
        print("-" * 70)
        sys.stdout.flush()

    # ===== TRAINING LOOP =====
    model.train()
    global_step = resume_step
    running_loss = 0.0          # sum of per-step losses since the last log line
    best_loss = float("inf")    # best logged average loss seen this run
    tokens_done = resume_step * tokens_per_step
    t0 = time.time()            # wall-clock start (for ETA / throughput)
    step_t0 = time.time()       # start of the current logging window

    # Append mode so a resumed run extends the same JSONL log.
    log_file = open(os.path.join(train_config.log_dir, "train_log.jsonl"), "a") if rank == 0 else None

    while global_step < total_steps:
        optimizer.zero_grad(set_to_none=True)
        micro_loss = 0.0  # accumulated (already-scaled) loss for this optimizer step

        # NOTE(review): no model.no_sync() around the first N-1 micro-steps, so
        # DDP all-reduces gradients on every micro-batch instead of once per
        # step — correct results, but extra communication. Worth confirming.
        for micro in range(train_config.gradient_accumulation_steps):
            try:
                input_ids, labels = next(data_iter)
            except StopIteration:
                # Streaming dataset exhausted an epoch — restart the iterator.
                data_iter = iter(dataloader)
                input_ids, labels = next(data_iter)

            input_ids = input_ids.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # BF16 autocast — no scaler needed (BF16 has enough dynamic range)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                # NOTE(review): first return value discarded — presumably the
                # logits; confirm against Transformer.forward's signature.
                _, loss = model(input_ids, labels)
                # Scale so the accumulated gradient equals the full-batch mean.
                loss = loss / train_config.gradient_accumulation_steps

            loss.backward()
            micro_loss += loss.item()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.grad_clip)

        # LR schedule — computed per step and written into every param group.
        lr = get_wsd_lr(global_step, train_config.warmup_steps, total_steps,
                        train_config.learning_rate, train_config.min_lr)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        optimizer.step()
        global_step += 1
        running_loss += micro_loss
        tokens_done += tokens_per_step

        # Log — counters are maintained on all ranks; only rank 0 prints/writes.
        if global_step % train_config.log_interval == 0:
            dt = time.time() - step_t0
            tps = (train_config.log_interval * tokens_per_step) / max(dt, 1e-9)
            avg = running_loss / train_config.log_interval
            elapsed = time.time() - t0
            pct = 100.0 * global_step / total_steps
            # ETA extrapolates from elapsed-per-step averaged over the whole run
            # (includes pre-resume steps in the denominator after a resume).
            eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)

            if rank == 0:
                gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                print(
                    f"[Step {global_step:>6d}/{total_steps}] "
                    f"loss={avg:.4f} | lr={lr:.2e} | "
                    f"tok/s={tps:,.0f} | GPU={gpu_mem:.1f}GB | "
                    f"{pct:.1f}% | ETA={eta/3600:.1f}h",
                    flush=True,
                )
                if log_file:
                    log_file.write(json.dumps({
                        "step": global_step, "loss": round(avg, 4), "lr": lr,
                        "tps": round(tps), "tokens": tokens_done,
                        "gpu_gb": round(gpu_mem, 1), "elapsed_s": round(elapsed, 1),
                    }) + "\n")
                    log_file.flush()

            if avg < best_loss:
                best_loss = avg
            # Reset the logging window on every rank so avg stays correct.
            running_loss = 0.0
            step_t0 = time.time()

        # Checkpoint — barriers keep other ranks from racing ahead while rank 0
        # serializes the (potentially multi-GB) state dicts.
        if global_step % train_config.save_interval == 0:
            dist.barrier()
            if rank == 0:
                ckpt_path = os.path.join(train_config.checkpoint_dir, f"step_{global_step}.pt")
                torch.save({
                    "step": global_step,
                    "model": model.module.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    # Prefer the windowed average when this step is also a log
                    # step (avg was just assigned above); otherwise fall back to
                    # the last single-step loss.
                    "loss": avg if global_step % train_config.log_interval == 0 else micro_loss,
                    "config": {"model": model_config.__dict__, "train": train_config.__dict__},
                }, ckpt_path)
                print(f" >> Checkpoint: {ckpt_path}", flush=True)
            dist.barrier()

    # Final — model-only checkpoint (no optimizer state) for downstream use.
    dist.barrier()
    if rank == 0:
        final_path = os.path.join(train_config.checkpoint_dir, "final.pt")
        torch.save({
            "step": global_step,
            "model": model.module.state_dict(),
            "config": {"model": model_config.__dict__, "train": train_config.__dict__},
        }, final_path)
        total_time = time.time() - t0
        print("=" * 70)
        print(f" TRAINING COMPLETE")
        print(f" Steps: {global_step:,} | Tokens: {tokens_done:,}")
        print(f" Time: {total_time/3600:.2f}h | Throughput: {tokens_done/total_time:,.0f} tok/s")
        print(f" Best loss: {best_loss:.4f}")
        print(f" Final model: {final_path}")
        print("=" * 70)
        if log_file:
            log_file.close()

    dist.destroy_process_group()
254
+
255
+
256
# Script entry point — launch with torchrun (sets RANK/LOCAL_RANK/WORLD_SIZE),
# e.g.: torchrun --nproc_per_node=8 train.py
if __name__ == "__main__":
    main()