import os
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast

# --- Sovereign Training Utilities ---


def get_batch(data, block_size, batch_size, device):
    """Sample a random batch of input/target windows from a token sequence.

    Args:
        data: Indexable token sequence (e.g. a 1-D tensor) supporting ``len``
            and slicing; the slices are combined with ``torch.stack``.
        block_size: Length of each sampled window.
        batch_size: Number of windows to sample.
        device: Device the returned tensors are moved to.

    Returns:
        ``(x, y)`` where each row of ``y`` is the matching row of ``x``
        shifted forward by one position.

    Raises:
        ValueError: If ``data`` is too short to yield even one window of
            ``block_size`` inputs plus one target token.
    """
    if len(data) <= block_size:
        raise ValueError(
            f"data has {len(data)} tokens; need at least block_size + 1 "
            f"= {block_size + 1} to form an (x, y) pair"
        )
    # A start index i needs i + block_size + 1 <= len(data); torch.randint's
    # upper bound is exclusive, so len(data) - block_size is exactly right.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)


class SovereignTrainer:
    """Pre-training loop with optional mixed precision, clipping and checkpoints.

    The wrapped model is called as ``model(x, y)`` and must return a
    ``(logits, loss)`` pair.
    """

    def __init__(self, model, optimizer, config, device, *,
                 grad_clip=1.0, ckpt_dir="data/processed"):
        """Move the model to ``device`` and set up AMP state.

        Args:
            model: Module returning ``(logits, loss)`` from ``model(x, y)``.
            optimizer: Optimizer over ``model``'s parameters.
            config: Mapping with ``config['model_params']['n_positions']``,
                which is used as the training block size.
            device: Target device (string or ``torch.device``).
            grad_clip: Max global gradient norm used for clipping.
            ckpt_dir: Directory checkpoints are written to (created on demand).
        """
        self.model = model.to(device)
        self.optimizer = optimizer
        self.config = config
        self.device = device
        self.grad_clip = grad_clip
        self.ckpt_dir = ckpt_dir
        # autocast/GradScaler from torch.cuda.amp only do anything on CUDA;
        # disable them explicitly elsewhere instead of warning and falling back.
        self.use_amp = torch.device(device).type == "cuda"
        self.scaler = GradScaler(enabled=self.use_amp)  # For Mixed-Precision Training
        self.block_size = config['model_params']['n_positions']

    def train_step(self, x, y):
        """Run one optimization step on ``(x, y)`` and return the loss as a float."""
        self.optimizer.zero_grad(set_to_none=True)

        # 1. Forward pass; reduced precision is used only when AMP is enabled.
        with autocast(enabled=self.use_amp):
            _, loss = self.model(x, y)

        # 2. Backpropagate the (scaled, when AMP is enabled) loss.
        self.scaler.scale(loss).backward()

        # 3. Unscale first so clipping sees the true gradient magnitudes.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_clip)

        # 4. Optimizer step through the scaler (with AMP enabled it skips steps
        #    whose gradients contain inf/NaN), then adjust the loss scale.
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return loss.item()

    @torch.no_grad()
    def estimate_loss(self, data, eval_iters=20, batch_size=32):
        """Return the mean loss over ``eval_iters`` random batches of ``data``.

        The model is switched to eval mode for the measurement and always
        restored to train mode afterwards.

        Raises:
            ValueError: If ``eval_iters`` is less than 1.
        """
        if eval_iters < 1:
            raise ValueError("eval_iters must be >= 1")
        self.model.eval()
        try:
            total = 0.0
            for _ in range(eval_iters):
                xb, yb = get_batch(data, self.block_size, batch_size, self.device)
                with autocast(enabled=self.use_amp):
                    _, loss = self.model(xb, yb)
                total += loss.item()
            return total / eval_iters
        finally:
            self.model.train()

    def run_pretraining(self, train_data, val_data, max_iters=10000, *,
                        batch_size=32, log_interval=100, eval_iters=20):
        """The core pre-training loop for ARAVALLI-1.

        Every ``log_interval`` iterations (and on the last one) this prints
        the training loss, the validation loss when ``val_data`` is given and
        ``eval_iters > 0``, and the wall time since the previous log. It then
        writes a checkpoint.

        Args:
            train_data: Token sequence sampled for training batches.
            val_data: Token sequence for validation, or ``None`` to skip it.
            max_iters: Number of optimization steps.
            batch_size: Sequences per batch (training and validation).
            log_interval: Positive number of steps between log/checkpoint events.
            eval_iters: Batches averaged per validation estimate (0 disables).
        """
        print(f"Sovereign Pre-training Initiated on {self.device}...")
        self.model.train()
        start_time = time.time()

        for it in range(max_iters):
            # Fetch batch
            xb, yb = get_batch(train_data, self.block_size, batch_size, self.device)

            # Execute step
            loss = self.train_step(xb, yb)

            # Logging and Checkpointing
            if it % log_interval == 0 or it == max_iters - 1:
                dt = time.time() - start_time
                msg = f"Iter {it}: Loss {loss:.4f}"
                if val_data is not None and eval_iters > 0:
                    val_loss = self.estimate_loss(val_data, eval_iters, batch_size)
                    msg += f" | Val {val_loss:.4f}"
                print(f"{msg} | Time: {dt:.2f}s")

                # Trigger Sovereign Checkpoint (to be signed by pyHanko)
                self.save_checkpoint(it)
                # Reset after eval/checkpoint so their cost is not counted
                # in the next window's training time.
                start_time = time.time()

    def save_checkpoint(self, iter):
        """Write model, optimizer, scaler, config and iteration to disk.

        The file is ``<ckpt_dir>/ckpt_iter_<iter>.pt``. ``ckpt_dir`` is
        created if it does not exist yet.
        """
        # The parameter name ``iter`` shadows the builtin but is kept so
        # keyword callers (save_checkpoint(iter=...)) keep working.
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            # Loss-scale state is needed to resume mixed-precision runs faithfully.
            'scaler': self.scaler.state_dict(),
            'config': self.config,
            'iter': iter,
        }
        os.makedirs(self.ckpt_dir, exist_ok=True)
        torch.save(checkpoint, os.path.join(self.ckpt_dir, f"ckpt_iter_{iter}.pt"))