import os
import time

import torch
# Note: torch.cuda.amp is deprecated in favour of torch.amp in recent PyTorch
# releases; the older import from the original is kept here for compatibility.
from torch.cuda.amp import GradScaler, autocast


def get_batch(data, block_size, batch_size, device):
    """Sample a random batch of input/target sequences from a 1-D tensor of token ids."""
    # Pick batch_size random starting offsets, then slice out contiguous blocks.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    # Targets are the same blocks shifted one position to the right.
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
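
# Illustrative usage of get_batch (`token_ids` is a hypothetical stand-in for
# the encoded corpus; it is not defined in this file):
#
#   data = torch.tensor(token_ids, dtype=torch.long)
#   x, y = get_batch(data, block_size=1024, batch_size=32, device="cuda")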

class SovereignTrainer:
    def __init__(self, model, optimizer, config, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.config = config
        self.device = device
        # GradScaler handles loss scaling for mixed-precision training.
        self.scaler = GradScaler()
        self.block_size = config['model_params']['n_positions']

    def train_step(self, x, y):
        self.optimizer.zero_grad(set_to_none=True)

        # Forward pass under mixed precision; the model is expected to
        # return (logits, loss) when targets are provided.
        with autocast():
            logits, loss = self.model(x, y)

        # Backward pass on the scaled loss.
        self.scaler.scale(loss).backward()

        # Unscale gradients in place so clipping sees true gradient norms.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # Optimizer step via the scaler, then update the scale factor.
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return loss.item()

    def run_pretraining(self, train_data, val_data, max_iters=10000):
        """The core pre-training loop for ARAVALLI-1."""
        print(f"Sovereign Pre-training Initiated on {self.device}...")
        self.model.train()

        start_time = time.time()
        for step in range(max_iters):
            # Batch size is hard-coded to 32 here; val_data is accepted but
            # not consumed by this loop (see the estimate_val_loss sketch below).
            xb, yb = get_batch(train_data, self.block_size, 32, self.device)

            loss = self.train_step(xb, yb)

            # Log, checkpoint, and reset the interval timer every 100 steps.
            if step % 100 == 0 or step == max_iters - 1:
                dt = time.time() - start_time
                print(f"Iter {step}: Loss {loss:.4f} | Time: {dt:.2f}s")
                self.save_checkpoint(step)
                start_time = time.time()
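
    # The loop above accepts val_data but never evaluates on it. The method
    # below is a minimal validation-loss sketch (an addition, not part of the
    # original trainer); it assumes the model returns (logits, loss) when
    # targets are provided, as in train_step.
    @torch.no_grad()
    def estimate_val_loss(self, val_data, eval_iters=20, batch_size=32):
        self.model.eval()
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(val_data, self.block_size, batch_size, self.device)
            with autocast():
                _, loss = self.model(xb, yb)
            losses[k] = loss.item()
        self.model.train()
        return losses.mean().item()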

    def save_checkpoint(self, step):
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'config': self.config,
            'iter': step,
        }
        # Ensure the target directory exists before writing the checkpoint.
        os.makedirs("data/processed", exist_ok=True)
        torch.save(checkpoint, f"data/processed/ckpt_iter_{step}.pt")
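

# A hedged end-to-end usage sketch. `GPTModel` and the optimizer settings are
# hypothetical placeholders; the real model class and full config live
# elsewhere in the project. The only assumptions carried over from this file
# are that the model's forward returns (logits, loss) and that the config
# exposes config['model_params']['n_positions'].
#
#   config = {'model_params': {'n_positions': 1024}}
#   model = GPTModel(config)  # hypothetical model class
#   optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
#   trainer = SovereignTrainer(model, optimizer, config, device="cuda")
#   trainer.run_pretraining(train_data, val_data, max_iters=10000)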