"""
train_large.py - Trains larger model for the Killer Test.
Usage:
python validation/memory/train_large.py --config small # 7M params
python validation/memory/train_large.py --config medium # 25M params
python validation/memory/train_large.py --config large # 50M params
"""
import os
import sys
import time
import pickle
import argparse
from contextlib import nullcontext
import numpy as np
import torch
# Add root directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from src.model import RippleGPT
from src.config import RippleConfig
from validation.memory.model_configs import get_config, print_configs, ModelConfig
# Directories
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
CKPT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')
# Device
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
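# Data loading follows the nanoGPT pattern: the np.memmap is recreated on
# every call rather than held open (which avoids the memory growth that
# long-lived memmaps can cause), and each batch is a set of random windows
# over the token stream.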
def get_batch(split: str, block_size: int, batch_size: int):
    """Loads a data batch."""
    if split == 'train':
        data = np.memmap(os.path.join(DATA_DIR, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(DATA_DIR, 'val.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size].astype(np.int64))) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size].astype(np.int64))) for i in ix])
    if DEVICE == 'cuda':
        x, y = x.pin_memory().to(DEVICE, non_blocking=True), y.pin_memory().to(DEVICE, non_blocking=True)
    else:
        x, y = x.to(DEVICE), y.to(DEVICE)
    return x, y
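# Loss estimation runs under torch.no_grad() and averages eval_iters random
# batches per split, trading a little compute for a lower-variance estimate
# than a single batch would give.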
@torch.no_grad()
def estimate_loss(model, ctx, block_size: int, batch_size: int, eval_iters: int = 50):
    """Estimates loss on train and validation splits."""
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, block_size, batch_size)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
def get_lr(it: int, warmup_iters: int, max_iters: int, max_lr: float, min_lr: float) -> float:
    """Cosine decay with warmup."""
    if it < warmup_iters:
        return max_lr * it / warmup_iters
    if it > max_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + np.cos(np.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)
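# Example with the values used in train() below (warmup_iters=200,
# max_iters=10000, "medium" max_lr=6e-4, min_lr=6e-5):
#   get_lr(100, ...)   -> 3e-4     (halfway through the linear warmup)
#   get_lr(200, ...)   -> 6e-4     (peak, end of warmup)
#   get_lr(5100, ...)  -> ~3.3e-4  (mid-decay, cosine term crosses zero)
#   get_lr(10000, ...) -> 6e-5     (floor)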
def train(config_name: str = "medium", max_iters: int = 10000):
    """Main training loop."""
    model_cfg = get_config(config_name)
    print("=" * 70)
    print(f"🧠 KILLER TEST TRAINING: {model_cfg.name.upper()} MODEL")
    print("=" * 70)
    # Check data
    if not os.path.exists(os.path.join(DATA_DIR, 'train.bin')):
        print("❌ Data not found!")
        print(" Run first: python validation/memory/prepare_large_data.py --size 50")
        return
    os.makedirs(CKPT_DIR, exist_ok=True)
    # Load vocabulary
    with open(os.path.join(DATA_DIR, 'meta.pkl'), 'rb') as f:
        meta = pickle.load(f)
    vocab_size = meta['vocab_size']
    # Load dataset stats
    with open(os.path.join(DATA_DIR, 'stats.pkl'), 'rb') as f:
        data_stats = pickle.load(f)
    actual_mb = data_stats.get('actual_mb')
    size_str = f"{actual_mb:.1f}MB" if actual_mb is not None else "N/A"
    print(f"\n📚 Dataset: {size_str}")
    print(f"📚 Vocab size: {vocab_size}")
    # Training configuration based on model size
    batch_size = 32 if model_cfg.name in ["small", "medium"] else 16
    # Smaller learning rate for larger models
    max_lr = {
        "small": 1e-3,
        "medium": 6e-4,
        "large": 3e-4,
        "xlarge": 1e-4
    }.get(model_cfg.name, 6e-4)
    min_lr = max_lr / 10
    warmup_iters = 200
    eval_interval = 500
    log_interval = 50
    torch.manual_seed(1337)
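    # The config below disables learned absolute position embeddings, so
    # position handling falls entirely to RippleGPT's Ripple Field mechanism,
    # presumably the property the downstream needle test is meant to stress.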
    # Initialize model
    print(f"\n🔧 Initializing model {model_cfg.name}...")
    config = RippleConfig(
        vocab_size=vocab_size,
        block_size=model_cfg.block_size,
        n_layer=model_cfg.n_layer,
        n_head=model_cfg.n_head,
        n_embd=model_cfg.n_embd,
        dropout=model_cfg.dropout,
        use_absolute_pos_emb=False  # Ripple Field!
    )
    model = RippleGPT(config)
    model.to(DEVICE)
    num_params = model.get_num_params()
    print(f" Parameters: {num_params / 1e6:.2f}M")
    print(f" Device: {DEVICE}")
    print(f" Block size: {model_cfg.block_size}")
    print(f" Batch size: {batch_size}")
    print(f" Max LR: {max_lr}")
    print(f" Max iters: {max_iters}")
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.99))
    # Mixed precision: bfloat16 autocast on CUDA (assumes hardware bf16
    # support), plain fp32 on CPU/MPS via nullcontext.
    ctx = nullcontext() if DEVICE in ['cpu', 'mps'] else torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16)
    # Training loop
    print(f"\n📈 Starting training ({max_iters} iterations)...")
    print("-" * 70)
    X, Y = get_batch('train', model_cfg.block_size, batch_size)
    t0 = time.time()
    best_val_loss = float('inf')
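    # Main loop: set the LR for this step, periodically evaluate and
    # checkpoint, run forward/backward on the current batch, then fetch the
    # batch for the next step at the end of the iteration.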
    for iter_num in range(max_iters):
        # LR scheduling
        lr = get_lr(iter_num, warmup_iters, max_iters, max_lr, min_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        # Evaluation
        if iter_num % eval_interval == 0 and iter_num > 0:
            losses = estimate_loss(model, ctx, model_cfg.block_size, batch_size)
            print(f"step {iter_num}: train {losses['train']:.4f}, val {losses['val']:.4f}, lr {lr:.2e}")
            if losses['val'] < best_val_loss:
                best_val_loss = losses['val']
                checkpoint = {
                    'model': model.state_dict(),
                    'config': config,
                    'model_config_name': model_cfg.name,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                }
                ckpt_path = os.path.join(CKPT_DIR, f'ckpt_{model_cfg.name}_best.pt')
                torch.save(checkpoint, ckpt_path)
                print(f" 💾 Best model saved! (val_loss: {best_val_loss:.4f})")
        # Forward/backward
        with ctx:
            logits, loss = model(X, Y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # Logging
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        if iter_num % log_interval == 0:
            print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt*1000:.0f}ms, lr {lr:.2e}")
        X, Y = get_batch('train', model_cfg.block_size, batch_size)
    # Final checkpoint
    checkpoint = {
        'model': model.state_dict(),
        'config': config,
        'model_config_name': model_cfg.name,
        'iter_num': max_iters,
        'best_val_loss': best_val_loss,
    }
    torch.save(checkpoint, os.path.join(CKPT_DIR, f'ckpt_{model_cfg.name}_final.pt'))
    print("-" * 70)
    print("✅ Training complete!")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(f" Checkpoints at: {CKPT_DIR}")
    print(f"\nNext step: python validation/memory/needle_test.py --config {model_cfg.name}")
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Trains model for Killer Test')
    parser.add_argument('--config', type=str, default='medium',
                        choices=['small', 'medium', 'large', 'xlarge'],
                        help='Model configuration')
    parser.add_argument('--iters', type=int, default=10000, help='Number of iterations')
    args = parser.parse_args()
    print_configs()
    train(args.config, args.iters)