""" train_code.py - Trains RippleGPT on Python code for validation. This script uses the prepared dataset to train the model in code completion. The focus is to validate if the architecture can learn code structures. Usage: python validation/train_code.py """ import os import sys import time import pickle import math import numpy as np import torch # Add root directory to path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) from src.model import RippleGPT from src.config import RippleConfig # ----------------------------------------------------------------------------- # Configuration # ----------------------------------------------------------------------------- # Directories DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') OUT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') # Training Hyperparameters BATCH_SIZE = 32 BLOCK_SIZE = 256 MAX_ITERS = 15000 # Optimized to prevent saturation EVAL_INTERVAL = 500 EVAL_ITERS = 200 LOG_INTERVAL = 100 # Model Hyperparameters (The Sweet Spot) N_LAYER = 6 N_HEAD = 8 N_EMBD = 384 DROPOUT = 0.1 # Optimization LEARNING_RATE = 1e-3 # Restores aggressive LR to learn fast WARMUP_ITERS = 200 # Device DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' # ----------------------------------------------------------------------------- # Helper Functions # ----------------------------------------------------------------------------- def get_batch(split: str, data_dir: str = DATA_DIR): """Loads a data batch.""" if split == 'train': data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') else: data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,)) x = torch.stack([torch.from_numpy((data[i:i+BLOCK_SIZE].astype(np.int64))) for i in ix]) y = torch.stack([torch.from_numpy((data[i+1:i+1+BLOCK_SIZE].astype(np.int64))) for i in ix]) if DEVICE == 'cuda': x, y = x.pin_memory().to(DEVICE, non_blocking=True), y.pin_memory().to(DEVICE, non_blocking=True) else: x, y = x.to(DEVICE), y.to(DEVICE) return x, y @torch.no_grad() def estimate_loss(model, ctx): """Estimates loss on train and validation splits.""" out = {} model.eval() for split in ['train', 'val']: losses = torch.zeros(EVAL_ITERS) for k in range(EVAL_ITERS): X, Y = get_batch(split) with ctx: logits, loss = model(X, Y) losses[k] = loss.item() out[split] = losses.mean() model.train() return out def get_lr(it: int) -> float: """Learning rate with linear warmup and cosine decay.""" # 1) Linear Warmup if it < WARMUP_ITERS: return LEARNING_RATE * it / WARMUP_ITERS # 2) If past the end, maintain minimum if it > MAX_ITERS: return LEARNING_RATE * 0.1 # 3) Cosine Decay decay_ratio = (it - WARMUP_ITERS) / (MAX_ITERS - WARMUP_ITERS) assert 0 <= decay_ratio <= 1 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) return LEARNING_RATE * (0.1 + 0.9 * coeff) # Decays to 10% of original def train(): """Main training loop.""" print("=" * 60) print("šŸš€ RIPPLEGPT TRAINING FOR CODE COMPLETION") print("=" * 60) # Check if data exists if not os.path.exists(os.path.join(DATA_DIR, 'train.bin')): print("āŒ Data not found!") print(" Run first: python validation/code/prepare_code_data.py") return # Create checkpoints directory os.makedirs(OUT_DIR, exist_ok=True) # Load vocabulary meta_path = os.path.join(DATA_DIR, 'meta.pkl') with open(meta_path, 'rb') as f: meta = pickle.load(f) vocab_size = meta['vocab_size'] print(f"\nšŸ“š Vocab 
def train():
    """Main training loop."""
    print("=" * 60)
    print("šŸš€ RIPPLEGPT TRAINING FOR CODE COMPLETION")
    print("=" * 60)

    # Check that the data exists
    if not os.path.exists(os.path.join(DATA_DIR, 'train.bin')):
        print("āŒ Data not found!")
        print("   Run first: python validation/code/prepare_code_data.py")
        return

    # Create the checkpoints directory
    os.makedirs(OUT_DIR, exist_ok=True)

    # Load the vocabulary
    meta_path = os.path.join(DATA_DIR, 'meta.pkl')
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    vocab_size = meta['vocab_size']
    print(f"\nšŸ“š Vocab size: {vocab_size}")

    # Seed for reproducibility
    torch.manual_seed(1337)

    # Initialize the model
    print("\nšŸ”§ Initializing model...")
    config = RippleConfig(
        vocab_size=vocab_size,
        block_size=BLOCK_SIZE,
        n_layer=N_LAYER,
        n_head=N_HEAD,
        n_embd=N_EMBD,
        dropout=DROPOUT,
        use_absolute_pos_emb=False  # Use the Ripple Field!
    )
    model = RippleGPT(config)
    model.to(DEVICE)

    num_params = model.get_num_params()
    print(f"   Parameters: {num_params / 1e6:.2f}M")
    print(f"   Device: {DEVICE}")
    print(f"   Block size: {BLOCK_SIZE}")
    print(f"   Batch size: {BATCH_SIZE}")

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Autocast context (bfloat16 on CUDA; no-op on CPU/MPS)
    ctx = nullcontext() if DEVICE in ['cpu', 'mps'] else torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16)

    # Training loop
    print(f"\nšŸ“ˆ Starting training ({MAX_ITERS} iterations)...")
    print("-" * 60)

    X, Y = get_batch('train')
    t0 = time.time()
    best_val_loss = float('inf')

    for iter_num in range(MAX_ITERS):
        # Learning rate scheduling
        lr = get_lr(iter_num)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Periodic evaluation
        if iter_num % EVAL_INTERVAL == 0 and iter_num > 0:
            losses = estimate_loss(model, ctx)
            print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

            # Save the best model
            if losses['val'] < best_val_loss:
                best_val_loss = losses['val']
                checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'config': config,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                }
                torch.save(checkpoint, os.path.join(OUT_DIR, 'ckpt_best.pt'))
                print(f"   šŸ’¾ Best model saved! (val_loss: {best_val_loss:.4f})")

        # Forward/backward
        with ctx:
            logits, loss = model(X, Y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Logging
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        if iter_num % LOG_INTERVAL == 0:
            decay_stats = model.get_decay_stats()
            print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt*1000:.2f}ms, lr {lr:.6f}")
            print(f"   Ripple Field stats -> mean decay: {decay_stats['mean']:.4f}, range: [{decay_stats['min']:.4f}, {decay_stats['max']:.4f}]")

        # Next batch
        X, Y = get_batch('train')

    # Save the final checkpoint
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'config': config,
        'iter_num': MAX_ITERS,
        'best_val_loss': best_val_loss,
    }
    torch.save(checkpoint, os.path.join(OUT_DIR, 'ckpt_final.pt'))

    print("-" * 60)
    print("āœ… Training complete!")
    print(f"   Best val loss: {best_val_loss:.4f}")
    print(f"   Checkpoints saved to: {OUT_DIR}")
    print("\nNext step: python validation/code/validate_code.py")


if __name__ == '__main__':
    train()
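
# A minimal sketch of how a downstream script (e.g. validate_code.py) might load
# the saved checkpoint, assuming the RippleGPT / RippleConfig API used above
# (illustrative only, not executed here):
#
#   ckpt = torch.load(os.path.join(OUT_DIR, 'ckpt_best.pt'), map_location=DEVICE)
#   model = RippleGPT(ckpt['config'])
#   model.load_state_dict(ckpt['model'])
#   model.to(DEVICE).eval()
#
# Note: on PyTorch >= 2.6 pass weights_only=False to torch.load, since the
# checkpoint stores a RippleConfig object rather than tensors only.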