# NanoGPT-WebReaper-ZeroGPU — train_continuous.py
# Uploaded to the Hugging Face Hub via huggingface_hub (commit 8439f3c, verified).
import os
import math
import time
import torch
from datasets import load_dataset
from model import GPT, GPTConfig
import tiktoken
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
BATCH_SIZE = 64 # High batch size since we have an A100
BLOCK_SIZE = 256  # context length: each training sample is a window of this many tokens
MAX_STEPS = 5000  # total optimizer steps; also the horizon of the cosine LR schedule
LEARNING_RATE = 6e-4  # peak LR; the schedule decays to 10% of this value
WARMUP_STEPS = 100  # linear LR warmup before cosine decay kicks in
DATASET_NAME = "HuggingFaceFW/fineweb-edu"  # streamed below, "sample-10BT" subset
CHECKPOINT_DIR = "./checkpoints_continuous"
EVAL_INTERVAL = 250  # NOTE(review): not referenced anywhere in this file
SAVE_INTERVAL = 500  # checkpoint every N steps (see save loop in main)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # idempotent: safe on re-runs
# -----------------------------------------------------------------------------
# Optimization Settings for A100/H200
# -----------------------------------------------------------------------------
# Allow TF32 matmuls on Ampere+ GPUs (trades a little fp32 precision for speed).
torch.set_float32_matmul_precision('high')
# Prefer CUDA when present; everything below keys off this string.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# -----------------------------------------------------------------------------
# Cosine Learning Rate Scheduler (Karpathy's exact implementation)
# -----------------------------------------------------------------------------
def get_lr(it):
    """Return the learning rate for optimizer step `it`.

    Schedule (mirrors nanoGPT): linear warmup to LEARNING_RATE over
    WARMUP_STEPS, cosine decay down to 10% of LEARNING_RATE by MAX_STEPS,
    then a constant floor afterwards.
    """
    min_lr = LEARNING_RATE * 0.1
    # Phase 1: linear ramp from ~0 up to the peak over the warmup window.
    if it < WARMUP_STEPS:
        return LEARNING_RATE * (it + 1) / WARMUP_STEPS
    # Phase 3: beyond the schedule horizon, hold the floor.
    if it > MAX_STEPS:
        return min_lr
    # Phase 2: cosine decay. `progress` runs 0 -> 1 across the decay window,
    # so `cosine` runs 1 -> 0 and the LR glides from the peak to the floor.
    progress = (it - WARMUP_STEPS) / (MAX_STEPS - WARMUP_STEPS)
    assert 0 <= progress <= 1
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine * (LEARNING_RATE - min_lr)
# -----------------------------------------------------------------------------
# Main Training Loop
# -----------------------------------------------------------------------------
def main():
    """Continuously train a small GPT on text streamed from FineWeb-Edu.

    Builds a 4-layer GPT, compiles it when supported, then loops: pull one
    document from the streaming dataset, tokenize it with GPT-2 BPE, sample
    BATCH_SIZE random BLOCK_SIZE windows from it, and take one AdamW step.
    Checkpoints model + optimizer state every SAVE_INTERVAL steps.
    """
    print(f"Initializing NanoGPT on {device}...")
    use_cuda = device == 'cuda'

    # 1. Initialize model. vocab_size=50304 pads GPT-2's 50257 up to a
    # multiple of 64 for friendlier tensor-core shapes.
    config = GPTConfig(block_size=BLOCK_SIZE, vocab_size=50304, n_layer=4, n_head=4, n_embd=256)
    model = GPT(config)
    model.to(device)

    # 2. Compile the model for a large speedup (torch >= 2.0 only).
    if hasattr(torch, 'compile'):
        print("Compiling model (this takes a minute)...")
        model = torch.compile(model)

    # 3. Optimizer. BUG FIX: the fused AdamW kernel is CUDA-only; the original
    # hard-coded fused=True, which raises on CPU-only machines even though
    # `device` deliberately falls back to 'cpu' above.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LEARNING_RATE, weight_decay=0.1,
        betas=(0.9, 0.95), eps=1e-8, fused=use_cuda,
    )

    # 4. Stream the dataset (no local download) and set up the tokenizer.
    print(f"Streaming dataset: {DATASET_NAME}...")
    ds = load_dataset(DATASET_NAME, name="sample-10BT", split="train", streaming=True)
    ds_iter = iter(ds)
    enc = tiktoken.get_encoding("gpt2")

    # 5. Training loop.
    print("Starting continuous training loop...")
    t0 = time.time()
    for step in range(MAX_STEPS):
        # Apply the warmup + cosine schedule for this step.
        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Fetch the next document, restarting the stream once exhausted.
        try:
            row = next(ds_iter)
            text = row.get("text", " ")
            if not text:
                text = " "
        except StopIteration:
            # Loop the dataset from the beginning.
            ds_iter = iter(ds)
            row = next(ds_iter)
            text = row.get("text", " ")

        tokens = enc.encode(text, allowed_special={"<|endoftext|>"})
        # A document shorter than one window + 1 target token cannot yield an
        # (x, y) pair. NOTE(review): skipping still consumes a scheduler step.
        if len(tokens) < BLOCK_SIZE + 1:
            continue

        # Sample BATCH_SIZE random windows from this single document; y is x
        # shifted right by one token (next-token prediction targets).
        ix = torch.randint(len(tokens) - BLOCK_SIZE, (BATCH_SIZE,))
        x = torch.stack([torch.tensor(tokens[i:i+BLOCK_SIZE], dtype=torch.long) for i in ix]).to(device, non_blocking=True)
        y = torch.stack([torch.tensor(tokens[i+1:i+1+BLOCK_SIZE], dtype=torch.long) for i in ix]).to(device, non_blocking=True)

        # Forward pass under bfloat16 autocast (A100/H200-friendly).
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(x, y)

        # Backward pass with global gradient-norm clipping at 1.0.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # BUG FIX: torch.cuda.synchronize() raises on CPU-only machines; only
        # sync on CUDA so the timing below measures completed GPU work.
        if use_cuda:
            torch.cuda.synchronize()

        # Per-step timing / throughput.
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        tokens_processed = BATCH_SIZE * BLOCK_SIZE
        tokens_per_sec = tokens_processed / dt
        if step % 10 == 0:
            print(f"step {step:4d} | loss: {loss.item():.4f} | lr {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")

        # Periodic checkpoint: save the raw (uncompiled) module's weights so
        # the state_dict keys are not prefixed by the torch.compile wrapper.
        if step > 0 and step % SAVE_INTERVAL == 0:
            raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
            ckpt_path = os.path.join(CHECKPOINT_DIR, f"model_{step:05d}.pt")
            checkpoint = {
                'model': raw_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'step': step,
                'config': config,
            }
            print(f"Saving checkpoint to {ckpt_path}")
            torch.save(checkpoint, ckpt_path)
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()