"""Continuous pre-training loop for a small GPT on a streamed HF dataset.

Streams documents from FineWeb-Edu, tokenizes them with the GPT-2 BPE,
samples random fixed-length windows per step, and trains with AdamW under
a warmup + cosine-decay LR schedule. Periodically checkpoints model,
optimizer, and config to CHECKPOINT_DIR.
"""

import os
import math
import time

import torch
from datasets import load_dataset
import tiktoken

from model import GPT, GPTConfig

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
BATCH_SIZE = 64        # High batch size since we have an A100
BLOCK_SIZE = 256       # Context window length in tokens
MAX_STEPS = 5000
LEARNING_RATE = 6e-4   # Peak LR; cosine-decays down to 10% of this value
WARMUP_STEPS = 100
DATASET_NAME = "HuggingFaceFW/fineweb-edu"
CHECKPOINT_DIR = "./checkpoints_continuous"
EVAL_INTERVAL = 250
SAVE_INTERVAL = 500

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# -----------------------------------------------------------------------------
# Optimization Settings for A100/H200
# -----------------------------------------------------------------------------
# Allow TF32 matmuls on Tensor Cores (no-op on CPU).
torch.set_float32_matmul_precision('high')

device = 'cuda' if torch.cuda.is_available() else 'cpu'


# -----------------------------------------------------------------------------
# Cosine Learning Rate Scheduler (Karpathy's exact implementation)
# -----------------------------------------------------------------------------
def get_lr(it):
    """Return the learning rate for iteration *it*.

    Schedule: linear warmup from LEARNING_RATE/WARMUP_STEPS up to
    LEARNING_RATE over the first WARMUP_STEPS iterations, then cosine
    decay down to a floor of 10% of LEARNING_RATE at MAX_STEPS, and the
    floor thereafter.
    """
    min_lr = LEARNING_RATE * 0.1
    # 1) linear warmup for the first WARMUP_STEPS iterations
    if it < WARMUP_STEPS:
        return LEARNING_RATE * (it + 1) / WARMUP_STEPS
    # 2) past the schedule horizon, hold at the minimum learning rate
    if it > MAX_STEPS:
        return min_lr
    # 3) in between: cosine decay from peak to min_lr
    decay_ratio = (it - WARMUP_STEPS) / (MAX_STEPS - WARMUP_STEPS)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1 -> 0
    return min_lr + coeff * (LEARNING_RATE - min_lr)


def _fetch_tokens(ds, ds_iter, enc):
    """Pull the next document from the stream and tokenize it.

    Restarts the streaming iterator when exhausted so training can loop
    over the dataset indefinitely. Empty/missing text falls back to a
    single space in BOTH branches (the original only guarded the happy
    path). Returns (tokens, ds_iter) — the iterator is returned because
    it may have been replaced on restart.
    """
    try:
        row = next(ds_iter)
    except StopIteration:
        ds_iter = iter(ds)  # loop the dataset
        row = next(ds_iter)
    text = row.get("text", " ") or " "
    tokens = enc.encode(text, allowed_special={"<|endoftext|>"})
    return tokens, ds_iter


def _make_batch(tokens):
    """Sample BATCH_SIZE random windows of BLOCK_SIZE tokens from one document.

    Caller guarantees len(tokens) >= BLOCK_SIZE + 1, so every x window has
    a full next-token target y. Returns (x, y) long tensors on `device`.
    """
    ix = torch.randint(len(tokens) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack(
        [torch.tensor(tokens[i:i + BLOCK_SIZE], dtype=torch.long) for i in ix]
    ).to(device, non_blocking=True)
    y = torch.stack(
        [torch.tensor(tokens[i + 1:i + 1 + BLOCK_SIZE], dtype=torch.long) for i in ix]
    ).to(device, non_blocking=True)
    return x, y


def _save_checkpoint(model, optimizer, config, step):
    """Write model/optimizer/config state to CHECKPOINT_DIR for `step`.

    Unwraps torch.compile's wrapper (`_orig_mod`) so the saved state_dict
    keys match the uncompiled module.
    """
    raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
    ckpt_path = os.path.join(CHECKPOINT_DIR, f"model_{step:05d}.pt")
    checkpoint = {
        'model': raw_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': step,
        'config': config,
    }
    print(f"Saving checkpoint to {ckpt_path}")
    torch.save(checkpoint, ckpt_path)


# -----------------------------------------------------------------------------
# Main Training Loop
# -----------------------------------------------------------------------------
def main():
    print(f"Initializing NanoGPT on {device}...")

    # 1. Initialize Model
    config = GPTConfig(block_size=BLOCK_SIZE, vocab_size=50304,
                       n_layer=4, n_head=4, n_embd=256)
    model = GPT(config)
    model.to(device)

    # 2. Compile model for massive speedup
    if hasattr(torch, 'compile'):
        print("Compiling model (this takes a minute)...")
        model = torch.compile(model)

    # 3. Setup Optimizer
    # FIX: the fused AdamW kernel is CUDA-only; passing fused=True
    # unconditionally crashed the CPU fallback path.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,
                                  weight_decay=0.1, betas=(0.9, 0.95),
                                  eps=1e-8, fused=(device == 'cuda'))

    # 4. Load Dataset (streaming: no full download)
    print(f"Streaming dataset: {DATASET_NAME}...")
    ds = load_dataset(DATASET_NAME, name="sample-10BT",
                      split="train", streaming=True)
    ds_iter = iter(ds)
    enc = tiktoken.get_encoding("gpt2")

    # 5. Training Loop
    print("Starting continuous training loop...")
    tokens_per_step = BATCH_SIZE * BLOCK_SIZE  # loop invariant, hoisted
    t0 = time.time()

    for step in range(MAX_STEPS):
        # Set learning rate for this step
        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Fetch data; skip documents too short for one (x, y) window pair
        tokens, ds_iter = _fetch_tokens(ds, ds_iter, enc)
        if len(tokens) < BLOCK_SIZE + 1:
            continue
        x, y = _make_batch(tokens)

        # Forward pass (bfloat16 mixed precision)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(x, y)

        # Backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Global gradient clipping
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Optimizer step
        optimizer.step()

        # Wait for the GPU so step timing is accurate.
        # FIX: original called torch.cuda.synchronize() unconditionally,
        # which raises on CPU.
        if device == 'cuda':
            torch.cuda.synchronize()

        # Timing
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        tokens_per_sec = tokens_per_step / dt

        if step % 10 == 0:
            print(f"step {step:4d} | loss: {loss.item():.4f} | lr {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")

        if step > 0 and step % SAVE_INTERVAL == 0:
            _save_checkpoint(model, optimizer, config, step)

    # FIX: the loop's last step (MAX_STEPS - 1) never hits SAVE_INTERVAL,
    # so the original silently discarded the final partial interval of
    # training. Save one last checkpoint at the end.
    _save_checkpoint(model, optimizer, config, MAX_STEPS)


if __name__ == "__main__":
    main()