Spaces:
Sleeping
Sleeping
File size: 5,694 Bytes
1df0e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import contextlib
import os
import time

import torch

from aetheris.utils import save_checkpoint, load_latest_checkpoint, calculate_model_stats
class Trainer:
def __init__(self, model, optimizer, scaler, config, device, checkpoint_dir, logger=None):
self.model = model
self.optimizer = optimizer
self.scaler = scaler
self.config = config
self.device = device
self.checkpoint_dir = checkpoint_dir
self.logger = logger
self.model.to(self.device)
def validate(self, val_loader, global_step):
self.model.eval()
total_loss = 0
total_items = 0
num_batches = 100 # Validate on 100 batches to save time
print(f"\n[Validation] Starting validation at step {global_step}...")
with torch.no_grad():
for i, batch in enumerate(val_loader):
if i >= num_batches:
break
input_ids, labels = batch
input_ids = input_ids.to(self.device, non_blocking=True)
labels = labels.to(self.device, non_blocking=True)
# Auto-cast context
if self.device.type == 'cuda':
autocast_dtype = torch.float16
else:
autocast_dtype = torch.bfloat16
use_autocast = True if self.config.torch_dtype != torch.float32 else False
if use_autocast:
with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=autocast_dtype):
output = self.model(input_ids, labels)
else:
output = self.model(input_ids, labels)
total_loss += output["loss"].item()
total_items += 1
avg_loss = total_loss / total_items if total_items > 0 else 0
perplexity = torch.exp(torch.tensor(avg_loss)).item()
print(f"[Validation] Step {global_step} | Loss: {avg_loss:.4f} | PPL: {perplexity:.4f}")
self.model.train()
return avg_loss
def train_epoch(self, train_loader, total_steps, start_step=0, stage_name="Training", val_loader=None, eval_every=500):
print(f"\n{'='*70}\nStarting {stage_name}: Target Steps={total_steps}\n{'='*70}")
self.model.train()
global_step = start_step
running_loss = 0
print("Initializing data iterator...")
train_iter = iter(train_loader)
print("Fetching first batch...")
while global_step < total_steps:
step_start = time.time()
# Removed periodic cache clearing for performance
self.optimizer.zero_grad(set_to_none=True)
try:
batch = next(train_iter)
if global_step == start_step:
print(f"✓ First batch loaded! Starting training loop...")
except StopIteration:
train_iter = iter(train_loader)
batch = next(train_iter)
input_ids, labels = batch
input_ids = input_ids.to(self.device, non_blocking=True)
labels = labels.to(self.device, non_blocking=True)
# Determine autocast dtype
if self.device.type == 'cuda':
autocast_dtype = torch.float16
else:
autocast_dtype = torch.bfloat16
# Check if we should use autocast (skip if model uses float32)
use_autocast = True
if self.config.torch_dtype == torch.float32:
use_autocast = False
if use_autocast:
with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=autocast_dtype):
output = self.model(input_ids, labels)
loss = output["loss"]
else:
output = self.model(input_ids, labels)
loss = output["loss"]
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
# Gradient clipping
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
if torch.isnan(grad_norm) or torch.isinf(grad_norm):
print(f"WARNING: NaN/Inf gradient at step {global_step}, skipping update")
else:
self.scaler.step(self.optimizer)
self.scaler.update()
global_step += 1
running_loss += loss.item()
if global_step % 10 == 0:
avg_loss = running_loss / 10
t_diff = time.time() - step_start
if self.device.type == 'cuda':
mem = torch.cuda.memory_allocated() / 1e9
max_mem = torch.cuda.max_memory_allocated() / 1e9
mem_str = f"VRAM: {mem:.1f}GB (peak: {max_mem:.1f}GB)"
else:
mem_str = "CPU Mode"
tokens_per_sec = (self.config.max_seq_len * input_ids.size(0)) / t_diff
print(f" Step {global_step}/{total_steps} | Loss: {avg_loss:.4f} | "
f"{mem_str} | {tokens_per_sec:.0f} tok/s")
running_loss = 0
if global_step % 500 == 0:
save_checkpoint(self.model, self.optimizer, self.scaler, global_step, stage_name, self.checkpoint_dir)
if val_loader is not None and global_step % eval_every == 0 and global_step > start_step:
self.validate(val_loader, global_step)
return global_step
|