Transformers · English · Hindi · Sanskrit · sovereign-ai · ecological-intelligence · indian-llm · environmental-protection
iamkoder001 committed
Commit fc5998a · verified · 1 Parent(s): cd37343

Create src/training/trainer.py

Files changed (1)
  1. src/training/trainer.py +75 -0
src/training/trainer.py ADDED
@@ -0,0 +1,75 @@
+ import os
+ import time
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ from torch.nn import functional as F
+ from torch.cuda.amp import GradScaler, autocast
+
+ # --- Sovereign Training Utilities ---
+
+ def get_batch(data, block_size, batch_size, device):
+     """Generates a small batch of data of inputs x and targets y."""
+     ix = torch.randint(len(data) - block_size, (batch_size,))
+     x = torch.stack([data[i:i+block_size] for i in ix])
+     y = torch.stack([data[i+1:i+block_size+1] for i in ix])
+     x, y = x.to(device), y.to(device)
+     return x, y
+
+ class SovereignTrainer:
+     def __init__(self, model, optimizer, config, device):
+         self.model = model.to(device)
+         self.optimizer = optimizer
+         self.config = config
+         self.device = device
+         self.scaler = GradScaler()  # For mixed-precision training
+         self.block_size = config['model_params']['n_positions']
+
+     def train_step(self, x, y):
+         self.optimizer.zero_grad(set_to_none=True)
+
+         # 1. Mixed-precision forward pass (speeds up training on modern GPUs)
+         with autocast():
+             logits, loss = self.model(x, y)
+
+         # 2. Backpropagation with loss scaling
+         self.scaler.scale(loss).backward()
+
+         # 3. Gradient clipping (prevents exploding gradients in scratch builds).
+         #    Gradients must be unscaled first so the norm is measured in true units.
+         self.scaler.unscale_(self.optimizer)
+         torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+         # 4. Optimizer step (the scaler skips it if it detects inf/NaN gradients)
+         self.scaler.step(self.optimizer)
+         self.scaler.update()
+
+         return loss.item()
+
+     def run_pretraining(self, train_data, val_data, max_iters=10000):
+         """The core pre-training loop for ARAVALLI-1."""
+         print(f"Sovereign Pre-training Initiated on {self.device}...")
+         self.model.train()
+
+         start_time = time.time()
+         for step in range(max_iters):  # 'step' avoids shadowing the builtin iter()
+             # Fetch batch
+             xb, yb = get_batch(train_data, self.block_size, 32, self.device)
+
+             # Execute step
+             loss = self.train_step(xb, yb)
+
+             # Logging and checkpointing
+             if step % 100 == 0 or step == max_iters - 1:
+                 dt = time.time() - start_time
+                 print(f"Iter {step}: Loss {loss:.4f} | Time: {dt:.2f}s")
+                 # Trigger sovereign checkpoint (to be signed by pyHanko)
+                 self.save_checkpoint(step)
+                 start_time = time.time()
+
+     def save_checkpoint(self, step):
+         checkpoint = {
+             'model': self.model.state_dict(),
+             'optimizer': self.optimizer.state_dict(),
+             'config': self.config,
+             'iter': step,
+         }
+         os.makedirs("data/processed", exist_ok=True)  # Ensure target directory exists
+         torch.save(checkpoint, f"data/processed/ckpt_iter_{step}.pt")
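
For context, a minimal usage sketch of the trainer follows (hypothetical, not part of this commit). TinyBigramLM, the AdamW settings, the config shape, and the synthetic token streams are illustrative assumptions; the only contract SovereignTrainer relies on is a model whose forward(x, y) returns a (logits, loss) pair, as train_step expects.

import torch
from src.training.trainer import SovereignTrainer

class TinyBigramLM(torch.nn.Module):
    """Stand-in model satisfying the (logits, loss) contract; not ARAVALLI-1."""
    def __init__(self, vocab_size):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, vocab_size)

    def forward(self, x, y=None):
        logits = self.emb(x)  # (batch, time, vocab)
        loss = None
        if y is not None:
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), y.view(-1))
        return logits, loss

device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {'model_params': {'n_positions': 64}}  # assumed config shape
model = TinyBigramLM(vocab_size=256)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Synthetic token streams standing in for the real pre-training corpus.
train_data = torch.randint(0, 256, (10_000,))
val_data = torch.randint(0, 256, (1_000,))

trainer = SovereignTrainer(model, optimizer, config, device)
trainer.run_pretraining(train_data, val_data, max_iters=500)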