import os
import math
import time
import torch
from datasets import load_dataset
from model import GPT, GPTConfig
import tiktoken
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
BATCH_SIZE = 64 # High batch size since we have an A100
BLOCK_SIZE = 256  # context window length (tokens per training sequence)
MAX_STEPS = 5000  # total optimizer steps; also the end of the cosine LR decay
LEARNING_RATE = 6e-4  # peak LR; schedule decays to 10% of this (see get_lr)
WARMUP_STEPS = 100  # linear LR warmup duration in steps
DATASET_NAME = "HuggingFaceFW/fineweb-edu"  # streamed via HF datasets
CHECKPOINT_DIR = "./checkpoints_continuous"
EVAL_INTERVAL = 250  # NOTE(review): not referenced in this file — confirm intended use
SAVE_INTERVAL = 500  # checkpoint every N steps
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# -----------------------------------------------------------------------------
# Optimization Settings for A100/H200
# -----------------------------------------------------------------------------
# Enable Tensor Cores
torch.set_float32_matmul_precision('high')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# -----------------------------------------------------------------------------
# Cosine Learning Rate Scheduler (Karpathy's exact implementation)
# -----------------------------------------------------------------------------
def get_lr(it, learning_rate=None, warmup_steps=None, max_steps=None, min_lr_ratio=0.1):
    """Cosine-decay learning-rate schedule with linear warmup.

    Generalized so the schedule can be reused/tested with explicit
    parameters; with no extra arguments it behaves exactly as before,
    reading the module-level LEARNING_RATE / WARMUP_STEPS / MAX_STEPS.

    Args:
        it: current optimizer step (0-indexed).
        learning_rate: peak LR (defaults to LEARNING_RATE).
        warmup_steps: linear warmup duration (defaults to WARMUP_STEPS).
        max_steps: step at which decay ends (defaults to MAX_STEPS).
        min_lr_ratio: floor LR as a fraction of the peak (0.1, as before).

    Returns:
        The learning rate for step `it`.
    """
    learning_rate = LEARNING_RATE if learning_rate is None else learning_rate
    warmup_steps = WARMUP_STEPS if warmup_steps is None else warmup_steps
    max_steps = MAX_STEPS if max_steps is None else max_steps
    min_lr = learning_rate * min_lr_ratio
    # 1) linear warmup for warmup_steps steps (it+1 so step 0 has a nonzero LR)
    if it < warmup_steps:
        return learning_rate * (it + 1) / warmup_steps
    # 2) past the end of the schedule: hold at the minimum learning rate
    if it > max_steps:
        return min_lr
    # 3) in between, cosine decay from learning_rate down to min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1 -> 0 over the decay
    return min_lr + coeff * (learning_rate - min_lr)
# -----------------------------------------------------------------------------
# Main Training Loop
# -----------------------------------------------------------------------------
def main():
    """Continuously train a small GPT on a streaming FineWeb-Edu sample.

    Streams documents from the dataset, tokenizes them with the GPT-2 BPE,
    samples BATCH_SIZE random BLOCK_SIZE-token windows per step, and runs a
    standard AdamW + cosine-LR training loop, checkpointing every
    SAVE_INTERVAL steps into CHECKPOINT_DIR.
    """
    print(f"Initializing NanoGPT on {device}...")

    # 1. Initialize model. vocab_size=50304 pads GPT-2's 50257 up to a
    # multiple of 128 — presumably for tensor-core-friendly matmul shapes.
    config = GPTConfig(block_size=BLOCK_SIZE, vocab_size=50304, n_layer=4, n_head=4, n_embd=256)
    model = GPT(config)
    model.to(device)

    # 2. Compile model for a large speedup (guarded for very old torch)
    if hasattr(torch, 'compile'):
        print("Compiling model (this takes a minute)...")
        model = torch.compile(model)

    # 3. Setup optimizer.
    # BUG FIX: the fused AdamW kernel is CUDA-only; `fused=True` crashed
    # whenever the device fell back to CPU (see the device pick above).
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=0.1,
        betas=(0.9, 0.95),
        eps=1e-8,
        fused=(device == 'cuda'),
    )

    # 4. Load dataset as an infinite-ish stream (no download up front)
    print(f"Streaming dataset: {DATASET_NAME}...")
    ds = load_dataset(DATASET_NAME, name="sample-10BT", split="train", streaming=True)
    ds_iter = iter(ds)
    enc = tiktoken.get_encoding("gpt2")

    # 5. Training loop
    print("Starting continuous training loop...")
    t0 = time.time()
    for step in range(MAX_STEPS):
        # Apply the scheduled learning rate to every param group
        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Fetch the next document, restarting the stream when exhausted.
        try:
            row = next(ds_iter)
        except StopIteration:
            ds_iter = iter(ds)
            row = next(ds_iter)
        # BUG FIX: the empty-text fallback previously ran only on the
        # non-restart path; now both paths share it.
        text = row.get("text") or " "

        tokens = enc.encode(text, allowed_special={"<|endoftext|>"})
        # Document too short for a full (input, target) window: skip this step.
        if len(tokens) < BLOCK_SIZE + 1:
            continue

        # Sample BATCH_SIZE random windows from this single document;
        # targets are the inputs shifted right by one token.
        ix = torch.randint(len(tokens) - BLOCK_SIZE, (BATCH_SIZE,))
        x = torch.stack([torch.tensor(tokens[i:i+BLOCK_SIZE], dtype=torch.long) for i in ix]).to(device, non_blocking=True)
        y = torch.stack([torch.tensor(tokens[i+1:i+1+BLOCK_SIZE], dtype=torch.long) for i in ix]).to(device, non_blocking=True)

        # Forward pass under bfloat16 autocast
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(x, y)

        # Backward pass with global gradient clipping
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # BUG FIX: torch.cuda.synchronize() raises on CPU-only hosts; it is
        # only needed on GPU so the step timing below is accurate.
        if device == 'cuda':
            torch.cuda.synchronize()

        # Per-step timing / throughput
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        tokens_processed = BATCH_SIZE * BLOCK_SIZE
        tokens_per_sec = tokens_processed / dt
        if step % 10 == 0:
            print(f"step {step:4d} | loss: {loss.item():.4f} | lr {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")

        # Periodic checkpoint; unwrap the torch.compile wrapper so the saved
        # state_dict keys match the plain model.
        if step > 0 and step % SAVE_INTERVAL == 0:
            raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
            ckpt_path = os.path.join(CHECKPOINT_DIR, f"model_{step:05d}.pt")
            checkpoint = {
                'model': raw_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'step': step,
                'config': config,
            }
            print(f"Saving checkpoint to {ckpt_path}")
            torch.save(checkpoint, ckpt_path)
# Script entry point: run the continuous training loop when executed directly.
if __name__ == "__main__":
    main()