ARAVALLI-1 / src/training/trainer.py
import os
import time

import torch
from torch.cuda.amp import GradScaler, autocast
# --- Sovereign Training Utilities ---

def get_batch(data, block_size, batch_size, device):
    """Samples a random batch: inputs x and next-token targets y,
    each of shape (batch_size, block_size), moved to `device`."""
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)
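
# Quick sanity check (hypothetical values, not part of the training run):
# with block_size=8 and batch_size=4, x and y are both (4, 8) LongTensors,
# with y equal to x shifted one token to the right:
#
#   data = torch.randint(0, 50_000, (1_000,))
#   x, y = get_batch(data, block_size=8, batch_size=4, device="cpu")
#   assert x.shape == y.shape == (4, 8)
#   assert torch.equal(x[:, 1:], y[:, :-1])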
class SovereignTrainer:
    def __init__(self, model, optimizer, config, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.config = config
        self.device = device
        self.scaler = GradScaler()  # Loss scaling for mixed-precision training
        self.block_size = config['model_params']['n_positions']
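
    # Assumed config layout, inferred from the lookup above (the full schema
    # is defined elsewhere in the repo):
    #
    #   config = {
    #       'model_params': {'n_positions': 1024, ...},
    #       ...
    #   }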
    def train_step(self, x, y):
        self.optimizer.zero_grad(set_to_none=True)

        # 1. Mixed-precision forward pass (speeds up training on modern GPUs)
        with autocast():
            logits, loss = self.model(x, y)

        # 2. Backpropagation on the scaled loss
        self.scaler.scale(loss).backward()

        # 3. Gradient clipping (prevents 'exploding gradients' in scratch builds).
        #    Gradients are unscaled first so the max_norm threshold applies to
        #    true gradient magnitudes, not scaled ones.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # 4. Optimizer step and scale-factor update
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return loss.item()
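
    # Single-step smoke test (illustrative; assumes, as above, that the model's
    # forward returns (logits, loss) when targets are provided):
    #
    #   xb, yb = get_batch(train_ids, trainer.block_size, 4, trainer.device)
    #   print(trainer.train_step(xb, yb))  # a finite float if wiring is correct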
    def run_pretraining(self, train_data, val_data, max_iters=10000):
        """The core pre-training loop for ARAVALLI-1."""
        print(f"Sovereign Pre-training Initiated on {self.device}...")
        self.model.train()
        start_time = time.time()
        for step in range(max_iters):  # 'step', not 'iter', to avoid shadowing the builtin
            # Fetch batch
            xb, yb = get_batch(train_data, self.block_size, 32, self.device)
            # Execute step
            loss = self.train_step(xb, yb)
            # Logging, validation, and checkpointing
            if step % 100 == 0 or step == max_iters - 1:
                # Held-out loss on one validation batch
                self.model.eval()
                with torch.no_grad(), autocast():
                    xv, yv = get_batch(val_data, self.block_size, 32, self.device)
                    _, val_loss = self.model(xv, yv)
                self.model.train()
                dt = time.time() - start_time
                print(f"Step {step}: Loss {loss:.4f} | Val Loss {val_loss.item():.4f} | Time: {dt:.2f}s")
                # Trigger Sovereign Checkpoint (to be signed by pyHanko)
                self.save_checkpoint(step)
                start_time = time.time()
    def save_checkpoint(self, step):
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'config': self.config,
            'iter': step,
        }
        os.makedirs("data/processed", exist_ok=True)  # torch.save won't create the directory
        torch.save(checkpoint, f"data/processed/ckpt_iter_{step}.pt")
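
# --- Usage sketch (illustrative; every name below is an assumption, since the
# model class, config file, and token files live elsewhere in ARAVALLI-1) ---
#
#   config = {...}                                          # loaded from the repo's config
#   model = AravalliGPT(config['model_params'])             # hypothetical model class
#   optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
#   train_ids = torch.load("data/processed/train_ids.pt")   # hypothetical token file
#   val_ids = torch.load("data/processed/val_ids.pt")       # hypothetical token file
#
#   trainer = SovereignTrainer(model, optimizer, config, device="cuda")
#   trainer.run_pretraining(train_ids, val_ids, max_iters=10000)
#
# Reloading follows the dict layout written by save_checkpoint:
#
#   ckpt = torch.load("data/processed/ckpt_iter_9999.pt")
#   model.load_state_dict(ckpt['model'])
#   optimizer.load_state_dict(ckpt['optimizer'])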