"""
Training script for baseline NanoGPT model on enwik8 dataset.
Ensures proper bpc calculation and comparable evaluation with DTAT.
"""

import os
import time
import math
import wandb
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from model_baseline import BaselineTransformer
from config.baseline_config import get_config
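
# Note on bits-per-character: the logging below assumes model_baseline already returns
# its cross-entropy in bits, so loss and bpc are reported as the same value. If the
# model instead returned the usual nat-valued cross-entropy, the conversion would be a
# division by ln(2). Illustrative helper (not called anywhere in this script):
def nats_to_bpc(loss_nats: float) -> float:
    """Convert a cross-entropy loss measured in nats to bits per character."""
    return loss_nats / math.log(2)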

# -----------------------------------------------------------------------------
# I/O
def get_batch(data, block_size, batch_size, device):
    """Generate a small batch of data of inputs x and targets y."""
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
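
# Usage sketch for get_batch (shapes only; block_size=1024 and batch_size=12 are
# illustrative values, not necessarily the actual config):
#   x, y = get_batch(train_data, 1024, 12, 'cuda')
#   -> x and y are int64 tensors of shape (12, 1024), with y shifted one byte ahead of x.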

def estimate_loss(model, data, config):
    """Estimate loss on data split, ensuring proper bpc calculation."""
    model.eval()
    total_loss = 0.0
    total_steps = config.eval_iters
    
    with torch.no_grad():
        for _ in range(total_steps):
            X, Y = get_batch(data, config.block_size, config.batch_size, config.device)
            with torch.amp.autocast('cuda', enabled=config.mixed_precision):
                logits, loss = model(X, Y)
            total_loss += loss.item()
    
    model.train()
    return total_loss / total_steps

def get_lr(it, config):
    """
    Learning rate scheduler with linear warmup and cosine decay.
    Matches DTAT's scheduler exactly.
    """
    # Linear warmup
    if it < config.warmup_iters:
        return config.learning_rate * it / config.warmup_iters
    
    # Cosine decay
    if config.decay_lr:
        decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
        decay_ratio = min(decay_ratio, 1.0)  # Cap at 1.0
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return config.min_lr + coeff * (config.learning_rate - config.min_lr)
    
    return config.learning_rate

def main():
    # Get config and setup
    config = get_config()
    device = config.device
    
    # Initialize wandb and log the full config
    wandb.init(project="enwik8-baseline", name="baseline-run")
    wandb.config.update(config.__dict__)
    
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = config.cudnn_benchmark
    
    # Data loading
    print("Loading data...")
    data_dir = os.path.join('data')
    train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint8, mode='r')
    val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint8, mode='r')
    
    # Model init
    print("Initializing model...")
    model = BaselineTransformer(config).to(device)
    print(f"number of parameters: {model.get_num_params()/1e6:.2f}M")
    
    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        weight_decay=config.weight_decay
    )
    
    # Mixed precision setup
    scaler = torch.amp.GradScaler('cuda', enabled=config.mixed_precision)
    
    # Memory optimizations
    if config.gradient_checkpointing:
        model.gradient_checkpointing_enable()
    
    # Calculate total steps and epochs
    total_steps = config.max_iters
    batch_size = config.batch_size
    block_size = config.block_size
    total_epochs = (total_steps * batch_size * block_size) // len(train_data)
    
    # Create progress bar
    pbar = tqdm(range(config.max_iters), desc=f"Training (0/{total_epochs} epochs)")
    
    best_val_loss = float('inf')
    no_improvement = 0
    t0 = time.time()
    
    for iter_num in pbar:
        # Early stopping check
        if no_improvement >= config.patience:
            print(f"\nEarly stopping triggered after {iter_num} iterations")
            print(f"Best validation loss: {best_val_loss:.4f}")
            break
        
        # Update learning rate
        lr = get_lr(iter_num, config)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        
        # Sample a batch of data
        X, Y = get_batch(train_data, config.block_size, config.batch_size, device)
        
        # Mixed precision training
        with torch.amp.autocast('cuda', enabled=config.mixed_precision):
            logits, loss = model(X, Y)
        
        # Backward pass with gradient scaling
        optimizer.zero_grad(set_to_none=True)  # Slightly faster than zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        
        # Logging
        if iter_num % config.log_interval == 0:
            # Calculate current epoch and average throughput since the start of training
            current_tokens = (iter_num + 1) * batch_size * block_size
            current_epoch = current_tokens / len(train_data)
            tokens_per_sec = current_tokens / (time.time() - t0)
            
            # Update progress bar (the model reports its loss directly in bits per character)
            pbar.set_description(
                f"Training ({current_epoch:.1f}/{total_epochs} epochs) | "
                f"loss: {loss.item():.4f} | "
                f"bpc: {loss.item():.2f} | "
                f"lr: {lr:.1e} | "
                f"tokens/sec: {tokens_per_sec:.1f}"
            )
            
            # Log training metrics to wandb
            wandb.log({
                "iter": iter_num,
                "epoch": current_epoch,
                "train/loss": loss.item(),
                "train/bpc": loss.item(),  # Same value: the model's loss is already in BPC
                "lr": lr,
                "tokens_per_sec": tokens_per_sec,
            })
        
        # Check validation and save every 100 iterations
        if iter_num > 0 and iter_num % 100 == 0:
            val_loss = estimate_loss(model, val_data, config)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                no_improvement = 0
                print(f"\nSaving best model with val_loss: {best_val_loss:.4f}")
                torch.save(model.state_dict(), os.path.join(os.path.dirname(__file__), 'best_baseline.pt'))
            else:
                no_improvement += 1
            
            # Log validation loss to wandb
            wandb.log({
                "val/loss": val_loss,
                "val/bpc": val_loss,
                "lr": lr,

            })
    
    wandb.finish()
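    
    # Illustrative follow-up (not executed here): the best checkpoint saved above can be
    # reloaded for a final evaluation, assuming the same config and validation data:
    #   model.load_state_dict(torch.load(os.path.join(os.path.dirname(__file__), 'best_baseline.pt')))
    #   final_bpc = estimate_loss(model, val_data, config)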

if __name__ == '__main__':
    main()