File size: 5,306 Bytes
69d01e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8439f3c
69d01e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8439f3c
69d01e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8439f3c
69d01e8
 
 
 
 
8439f3c
69d01e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import math
import time
import torch
from datasets import load_dataset
from model import GPT, GPTConfig
import tiktoken

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
BATCH_SIZE = 64 # High batch size since we have an A100
BLOCK_SIZE = 256  # context length in tokens; sequences are BLOCK_SIZE long
MAX_STEPS = 5000  # total optimizer steps; also the end of the cosine decay
LEARNING_RATE = 6e-4  # peak LR reached at the end of warmup
WARMUP_STEPS = 100  # linear warmup duration (steps)
DATASET_NAME = "HuggingFaceFW/fineweb-edu"  # streamed from the HF Hub
CHECKPOINT_DIR = "./checkpoints_continuous"  # checkpoints written every SAVE_INTERVAL steps
EVAL_INTERVAL = 250  # NOTE(review): appears unused in the visible code — confirm or remove
SAVE_INTERVAL = 500  # checkpoint frequency (steps)

# Ensure the checkpoint directory exists before training starts.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# -----------------------------------------------------------------------------
# Optimization Settings for A100/H200
# -----------------------------------------------------------------------------
# Enable Tensor Cores: allow TF32 for float32 matmuls (faster, slightly
# lower precision than strict fp32).
torch.set_float32_matmul_precision('high')
# Single-device script: pick CUDA when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# -----------------------------------------------------------------------------
# Cosine Learning Rate Scheduler (Karpathy's exact implementation)
# -----------------------------------------------------------------------------
def get_lr(it):
    """Learning-rate schedule: linear warmup, then cosine decay.

    Returns the LR for optimizer step ``it``: ramps linearly up to
    LEARNING_RATE over the first WARMUP_STEPS steps, then follows a
    half-cosine down to 10% of LEARNING_RATE at MAX_STEPS, and holds
    that floor for any step beyond MAX_STEPS.
    """
    min_lr = LEARNING_RATE * 0.1
    # Phase 1: linear warmup (starts at LEARNING_RATE/WARMUP_STEPS, not 0).
    if it < WARMUP_STEPS:
        return LEARNING_RATE * (it + 1) / WARMUP_STEPS
    # Phase 3: past the schedule end, hold at the minimum learning rate.
    if it > MAX_STEPS:
        return min_lr
    # Phase 2: cosine decay. progress goes 0 -> 1 across the decay window,
    # so the cosine coefficient goes 1 -> 0.
    progress = (it - WARMUP_STEPS) / (MAX_STEPS - WARMUP_STEPS)
    assert 0.0 <= progress <= 1.0
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine * (LEARNING_RATE - min_lr)

# -----------------------------------------------------------------------------
# Main Training Loop
# -----------------------------------------------------------------------------
def main():
    """Train a small GPT on a streamed FineWeb-Edu sample.

    Builds a 4-layer GPT, streams documents from the Hugging Face dataset,
    and runs MAX_STEPS optimizer steps with bf16 autocast, gradient
    clipping, and periodic checkpointing into CHECKPOINT_DIR.
    """
    print(f"Initializing NanoGPT on {device}...")

    # 1. Initialize model. vocab_size=50304 pads GPT-2's 50257 up to a
    # multiple of 128, which is friendlier to Tensor Core kernels.
    config = GPTConfig(block_size=BLOCK_SIZE, vocab_size=50304, n_layer=4, n_head=4, n_embd=256)
    model = GPT(config)
    model.to(device)

    # 2. Compile model for a large speedup (hasattr guards older torch).
    if hasattr(torch, 'compile'):
        print("Compiling model (this takes a minute)...")
        model = torch.compile(model)

    # 3. Setup optimizer. The fused AdamW kernel is CUDA-only, so enable
    # it only when actually running on a GPU (fused=True raises on CPU).
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=0.1,
        betas=(0.9, 0.95),
        eps=1e-8,
        fused=(device == 'cuda'),
    )

    # 4. Stream the dataset (no full download) and build the tokenizer.
    print(f"Streaming dataset: {DATASET_NAME}...")
    ds = load_dataset(DATASET_NAME, name="sample-10BT", split="train", streaming=True)
    ds_iter = iter(ds)
    enc = tiktoken.get_encoding("gpt2")

    # 5. Training loop.
    print("Starting continuous training loop...")
    t0 = time.time()

    for step in range(MAX_STEPS):

        # Set the scheduled learning rate on every param group.
        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Fetch the next document, restarting the stream when exhausted.
        try:
            row = next(ds_iter)
        except StopIteration:
            ds_iter = iter(ds)
            row = next(ds_iter)
        # Guard against a missing or empty "text" field on BOTH paths
        # (the original only sanitized the non-restart path).
        text = row.get("text") or " "

        tokens = enc.encode(text, allowed_special={"<|endoftext|>"})
        # Need BLOCK_SIZE+1 tokens to form shifted (x, y) pairs; skip
        # shorter documents. NOTE: this still consumes a scheduler step.
        if len(tokens) < BLOCK_SIZE + 1:
            continue

        # Sample BATCH_SIZE random windows from this single document.
        ix = torch.randint(len(tokens) - BLOCK_SIZE, (BATCH_SIZE,))
        x = torch.stack([torch.tensor(tokens[i:i+BLOCK_SIZE], dtype=torch.long) for i in ix]).to(device, non_blocking=True)
        y = torch.stack([torch.tensor(tokens[i+1:i+1+BLOCK_SIZE], dtype=torch.long) for i in ix]).to(device, non_blocking=True)

        # Forward pass under bfloat16 autocast.
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(x, y)

        # Backward pass.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Global gradient clipping (returns the pre-clip norm for logging).
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Optimizer step.
        optimizer.step()

        # Synchronize so the timing below measures real GPU work; calling
        # torch.cuda.synchronize() on a CUDA-less build raises, so guard it.
        if device == 'cuda':
            torch.cuda.synchronize()

        # Per-step timing / throughput.
        t1 = time.time()
        dt = t1 - t0
        t0 = t1

        tokens_processed = BATCH_SIZE * BLOCK_SIZE
        tokens_per_sec = tokens_processed / dt

        if step % 10 == 0:
            print(f"step {step:4d} | loss: {loss.item():.4f} | lr {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")

        if step > 0 and step % SAVE_INTERVAL == 0:
            # Unwrap torch.compile's wrapper so state_dict keys stay clean.
            raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
            ckpt_path = os.path.join(CHECKPOINT_DIR, f"model_{step:05d}.pt")
            checkpoint = {
                'model': raw_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'step': step,
                'config': config,
            }
            print(f"Saving checkpoint to {ckpt_path}")
            torch.save(checkpoint, ckpt_path)

# Standard entry-point guard: importing this module does not start training.
if __name__ == "__main__":
    main()