| | """ |
| | train_large.py - Trains larger model for the Killer Test. |
| | |
| | Usage: |
| | python validation/memory/train_large.py --config small # 7M params |
| | python validation/memory/train_large.py --config medium # 25M params |
| | python validation/memory/train_large.py --config large # 50M params |
| | """ |
| |
|
| | import os |
| | import sys |
| | import time |
| | import pickle |
| | import argparse |
| | import numpy as np |
| | import torch |
| |
|
| | |
| | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) |
| |
|
| | from src.model import RippleGPT |
| | from src.config import RippleConfig |
| | from validation.memory.model_configs import get_config, print_configs, ModelConfig |
| |
|
| | |
# Locations of the tokenized dataset and saved checkpoints, relative to this file.
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
CKPT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')

# Device preference: CUDA if available, then Apple MPS, then CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
| |
|
| |
|
def get_batch(split: str, block_size: int, batch_size: int):
    """Sample a random batch of (input, shifted-target) token windows from a split.

    Re-opens the memmap on every call; the targets are the inputs shifted
    by one token, as usual for next-token-prediction training.
    """
    fname = 'train.bin' if split == 'train' else 'val.bin'
    data = np.memmap(os.path.join(DATA_DIR, fname), dtype=np.uint16, mode='r')

    starts = torch.randint(len(data) - block_size, (batch_size,))
    xs, ys = [], []
    for s in starts:
        xs.append(torch.from_numpy(data[s:s + block_size].astype(np.int64)))
        ys.append(torch.from_numpy(data[s + 1:s + 1 + block_size].astype(np.int64)))
    x = torch.stack(xs)
    y = torch.stack(ys)

    if DEVICE == 'cuda':
        # Pinned memory enables asynchronous host-to-device transfer.
        x = x.pin_memory().to(DEVICE, non_blocking=True)
        y = y.pin_memory().to(DEVICE, non_blocking=True)
    else:
        x = x.to(DEVICE)
        y = y.to(DEVICE)

    return x, y
| |
|
| |
|
@torch.no_grad()
def estimate_loss(model, ctx, block_size: int, batch_size: int, eval_iters: int = 50):
    """Average the model loss over `eval_iters` random batches for each split.

    Puts the model in eval mode for the measurement and restores train
    mode before returning. Returns a dict {'train': ..., 'val': ...}.
    """
    model.eval()
    results = {}

    for split in ['train', 'val']:
        accumulated = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = get_batch(split, block_size, batch_size)
            with ctx:
                _, loss = model(inputs, targets)
            accumulated[step] = loss.item()
        results[split] = accumulated.mean()

    model.train()
    return results
| |
|
| |
|
def get_lr(it: int, warmup_iters: int, max_iters: int, max_lr: float, min_lr: float) -> float:
    """Learning-rate schedule: linear warmup, cosine decay, then a flat floor."""
    # Linear ramp from 0 up to max_lr over the warmup window.
    if it < warmup_iters:
        return max_lr * it / warmup_iters

    # Beyond the schedule horizon, hold the minimum rate.
    if it > max_iters:
        return min_lr

    # Cosine interpolation between max_lr (progress=0) and min_lr (progress=1).
    progress = (it - warmup_iters) / (max_iters - warmup_iters)
    scale = 0.5 * (1.0 + np.cos(np.pi * progress))
    return min_lr + scale * (max_lr - min_lr)
| |
|
| |
|
| | def train(config_name: str = "medium", max_iters: int = 10000): |
| | """Main training loop.""" |
| | |
| | model_cfg = get_config(config_name) |
| | |
| | print("=" * 70) |
| | print(f"๐ง KILLER TEST TRAINING: {model_cfg.name.upper()} MODEL") |
| | print("=" * 70) |
| | |
| | |
| | if not os.path.exists(os.path.join(DATA_DIR, 'train.bin')): |
| | print("โ Data not found!") |
| | print(" Run first: python validation/memory/prepare_large_data.py --size 50") |
| | return |
| | |
| | os.makedirs(CKPT_DIR, exist_ok=True) |
| | |
| | |
| | with open(os.path.join(DATA_DIR, 'meta.pkl'), 'rb') as f: |
| | meta = pickle.load(f) |
| | vocab_size = meta['vocab_size'] |
| | |
| | |
| | with open(os.path.join(DATA_DIR, 'stats.pkl'), 'rb') as f: |
| | data_stats = pickle.load(f) |
| | |
| | print(f"\n๐ Dataset: {data_stats.get('actual_mb', 'N/A'):.1f}MB") |
| | print(f"๐ Vocab size: {vocab_size}") |
| | |
| | |
| | batch_size = 32 if model_cfg.name in ["small", "medium"] else 16 |
| | |
| | |
| | max_lr = { |
| | "small": 1e-3, |
| | "medium": 6e-4, |
| | "large": 3e-4, |
| | "xlarge": 1e-4 |
| | }.get(model_cfg.name, 6e-4) |
| | |
| | min_lr = max_lr / 10 |
| | warmup_iters = 200 |
| | eval_interval = 500 |
| | log_interval = 50 |
| | |
| | torch.manual_seed(1337) |
| | |
| | |
| | print(f"\n๐ง Initializing model {model_cfg.name}...") |
| | |
| | config = RippleConfig( |
| | vocab_size=vocab_size, |
| | block_size=model_cfg.block_size, |
| | n_layer=model_cfg.n_layer, |
| | n_head=model_cfg.n_head, |
| | n_embd=model_cfg.n_embd, |
| | dropout=model_cfg.dropout, |
| | use_absolute_pos_emb=False |
| | ) |
| | |
| | model = RippleGPT(config) |
| | model.to(DEVICE) |
| | |
| | num_params = model.get_num_params() |
| | print(f" Parameters: {num_params / 1e6:.2f}M") |
| | print(f" Device: {DEVICE}") |
| | print(f" Block size: {model_cfg.block_size}") |
| | print(f" Batch size: {batch_size}") |
| | print(f" Max LR: {max_lr}") |
| | print(f" Max iters: {max_iters}") |
| | |
| | |
| | optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.99)) |
| | |
| | |
| | from contextlib import nullcontext |
| | ctx = nullcontext() if DEVICE in ['cpu', 'mps'] else torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16) |
| | |
| | |
| | print(f"\n๐ Starting training ({max_iters} iterations)...") |
| | print("-" * 70) |
| | |
| | X, Y = get_batch('train', model_cfg.block_size, batch_size) |
| | t0 = time.time() |
| | best_val_loss = float('inf') |
| | |
| | for iter_num in range(max_iters): |
| | |
| | lr = get_lr(iter_num, warmup_iters, max_iters, max_lr, min_lr) |
| | for param_group in optimizer.param_groups: |
| | param_group['lr'] = lr |
| | |
| | |
| | if iter_num % eval_interval == 0 and iter_num > 0: |
| | losses = estimate_loss(model, ctx, model_cfg.block_size, batch_size) |
| | print(f"step {iter_num}: train {losses['train']:.4f}, val {losses['val']:.4f}, lr {lr:.2e}") |
| | |
| | if losses['val'] < best_val_loss: |
| | best_val_loss = losses['val'] |
| | checkpoint = { |
| | 'model': model.state_dict(), |
| | 'config': config, |
| | 'model_config_name': model_cfg.name, |
| | 'iter_num': iter_num, |
| | 'best_val_loss': best_val_loss, |
| | } |
| | ckpt_path = os.path.join(CKPT_DIR, f'ckpt_{model_cfg.name}_best.pt') |
| | torch.save(checkpoint, ckpt_path) |
| | print(f" ๐พ Best model saved! (val_loss: {best_val_loss:.4f})") |
| | |
| | |
| | with ctx: |
| | logits, loss = model(X, Y) |
| | |
| | optimizer.zero_grad(set_to_none=True) |
| | loss.backward() |
| | |
| | |
| | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
| | |
| | optimizer.step() |
| | |
| | |
| | t1 = time.time() |
| | dt = t1 - t0 |
| | t0 = t1 |
| | |
| | if iter_num % log_interval == 0: |
| | print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt*1000:.0f}ms, lr {lr:.2e}") |
| | |
| | X, Y = get_batch('train', model_cfg.block_size, batch_size) |
| | |
| | |
| | checkpoint = { |
| | 'model': model.state_dict(), |
| | 'config': config, |
| | 'model_config_name': model_cfg.name, |
| | 'iter_num': max_iters, |
| | 'best_val_loss': best_val_loss, |
| | } |
| | torch.save(checkpoint, os.path.join(CKPT_DIR, f'ckpt_{model_cfg.name}_final.pt')) |
| | |
| | print("-" * 70) |
| | print(f"โ
Training complete!") |
| | print(f" Best val loss: {best_val_loss:.4f}") |
| | print(f" Checkpoints at: {CKPT_DIR}") |
| | print(f"\nNext step: python validation/memory/needle_test.py --config {model_cfg.name}") |
| |
|
| |
|
if __name__ == '__main__':
    # Command-line entry point: pick a model preset and iteration budget.
    cli = argparse.ArgumentParser(description='Trains model for Killer Test')
    cli.add_argument(
        '--config', type=str, default='medium',
        choices=['small', 'medium', 'large', 'xlarge'],
        help='Model configuration')
    cli.add_argument('--iters', type=int, default=10000, help='Number of iterations')
    opts = cli.parse_args()

    print_configs()
    train(opts.config, opts.iters)
| |
|