""" train_large.py - Trains larger model for the Killer Test. Usage: python validation/memory/train_large.py --config small # 7M params python validation/memory/train_large.py --config medium # 25M params python validation/memory/train_large.py --config large # 50M params """ import os import sys import time import pickle import argparse import numpy as np import torch # Add root directory to path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) from src.model import RippleGPT from src.config import RippleConfig from validation.memory.model_configs import get_config, print_configs, ModelConfig # Directories DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') CKPT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') # Device DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' def get_batch(split: str, block_size: int, batch_size: int): """Loads a data batch.""" if split == 'train': data = np.memmap(os.path.join(DATA_DIR, 'train.bin'), dtype=np.uint16, mode='r') else: data = np.memmap(os.path.join(DATA_DIR, 'val.bin'), dtype=np.uint16, mode='r') ix = torch.randint(len(data) - block_size, (batch_size,)) x = torch.stack([torch.from_numpy((data[i:i+block_size].astype(np.int64))) for i in ix]) y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size].astype(np.int64))) for i in ix]) if DEVICE == 'cuda': x, y = x.pin_memory().to(DEVICE, non_blocking=True), y.pin_memory().to(DEVICE, non_blocking=True) else: x, y = x.to(DEVICE), y.to(DEVICE) return x, y @torch.no_grad() def estimate_loss(model, ctx, block_size: int, batch_size: int, eval_iters: int = 50): """Estimates loss on train and validation splits.""" out = {} model.eval() for split in ['train', 'val']: losses = torch.zeros(eval_iters) for k in range(eval_iters): X, Y = get_batch(split, block_size, batch_size) with ctx: logits, loss = model(X, Y) losses[k] = loss.item() out[split] = losses.mean() model.train() return out def get_lr(it: int, warmup_iters: int, max_iters: int, max_lr: float, min_lr: float) -> float: """Cosine decay with warmup.""" if it < warmup_iters: return max_lr * it / warmup_iters if it > max_iters: return min_lr decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters) coeff = 0.5 * (1.0 + np.cos(np.pi * decay_ratio)) return min_lr + coeff * (max_lr - min_lr) def train(config_name: str = "medium", max_iters: int = 10000): """Main training loop.""" model_cfg = get_config(config_name) print("=" * 70) print(f"🧠 KILLER TEST TRAINING: {model_cfg.name.upper()} MODEL") print("=" * 70) # Check data if not os.path.exists(os.path.join(DATA_DIR, 'train.bin')): print("āŒ Data not found!") print(" Run first: python validation/memory/prepare_large_data.py --size 50") return os.makedirs(CKPT_DIR, exist_ok=True) # Load vocabulary with open(os.path.join(DATA_DIR, 'meta.pkl'), 'rb') as f: meta = pickle.load(f) vocab_size = meta['vocab_size'] # Load dataset stats with open(os.path.join(DATA_DIR, 'stats.pkl'), 'rb') as f: data_stats = pickle.load(f) print(f"\nšŸ“š Dataset: {data_stats.get('actual_mb', 'N/A'):.1f}MB") print(f"šŸ“š Vocab size: {vocab_size}") # Training configuration based on model size batch_size = 32 if model_cfg.name in ["small", "medium"] else 16 # Smaller learning rate for larger models max_lr = { "small": 1e-3, "medium": 6e-4, "large": 3e-4, "xlarge": 1e-4 }.get(model_cfg.name, 6e-4) min_lr = max_lr / 10 warmup_iters = 200 eval_interval = 500 log_interval = 50 torch.manual_seed(1337) # Initialize model print(f"\nšŸ”§ Initializing model {model_cfg.name}...") config = RippleConfig( vocab_size=vocab_size, block_size=model_cfg.block_size, n_layer=model_cfg.n_layer, n_head=model_cfg.n_head, n_embd=model_cfg.n_embd, dropout=model_cfg.dropout, use_absolute_pos_emb=False # Ripple Field! ) model = RippleGPT(config) model.to(DEVICE) num_params = model.get_num_params() print(f" Parameters: {num_params / 1e6:.2f}M") print(f" Device: {DEVICE}") print(f" Block size: {model_cfg.block_size}") print(f" Batch size: {batch_size}") print(f" Max LR: {max_lr}") print(f" Max iters: {max_iters}") # Optimizer optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.99)) # Context from contextlib import nullcontext ctx = nullcontext() if DEVICE in ['cpu', 'mps'] else torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16) # Training loop print(f"\nšŸ“ˆ Starting training ({max_iters} iterations)...") print("-" * 70) X, Y = get_batch('train', model_cfg.block_size, batch_size) t0 = time.time() best_val_loss = float('inf') for iter_num in range(max_iters): # LR scheduling lr = get_lr(iter_num, warmup_iters, max_iters, max_lr, min_lr) for param_group in optimizer.param_groups: param_group['lr'] = lr # Evaluation if iter_num % eval_interval == 0 and iter_num > 0: losses = estimate_loss(model, ctx, model_cfg.block_size, batch_size) print(f"step {iter_num}: train {losses['train']:.4f}, val {losses['val']:.4f}, lr {lr:.2e}") if losses['val'] < best_val_loss: best_val_loss = losses['val'] checkpoint = { 'model': model.state_dict(), 'config': config, 'model_config_name': model_cfg.name, 'iter_num': iter_num, 'best_val_loss': best_val_loss, } ckpt_path = os.path.join(CKPT_DIR, f'ckpt_{model_cfg.name}_best.pt') torch.save(checkpoint, ckpt_path) print(f" šŸ’¾ Best model saved! (val_loss: {best_val_loss:.4f})") # Forward/backward with ctx: logits, loss = model(X, Y) optimizer.zero_grad(set_to_none=True) loss.backward() # Gradient clipping torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() # Logging t1 = time.time() dt = t1 - t0 t0 = t1 if iter_num % log_interval == 0: print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt*1000:.0f}ms, lr {lr:.2e}") X, Y = get_batch('train', model_cfg.block_size, batch_size) # Final checkpoint checkpoint = { 'model': model.state_dict(), 'config': config, 'model_config_name': model_cfg.name, 'iter_num': max_iters, 'best_val_loss': best_val_loss, } torch.save(checkpoint, os.path.join(CKPT_DIR, f'ckpt_{model_cfg.name}_final.pt')) print("-" * 70) print(f"āœ… Training complete!") print(f" Best val loss: {best_val_loss:.4f}") print(f" Checkpoints at: {CKPT_DIR}") print(f"\nNext step: python validation/memory/needle_test.py --config {model_cfg.name}") if __name__ == '__main__': parser = argparse.ArgumentParser(description='Trains model for Killer Test') parser.add_argument('--config', type=str, default='medium', choices=['small', 'medium', 'large', 'xlarge'], help='Model configuration') parser.add_argument('--iters', type=int, default=10000, help='Number of iterations') args = parser.parse_args() print_configs() train(args.config, args.iters)