""" |
|
|
OmniCoreX Trainer Module |
|
|
|
|
|
Provides the most super advanced, highest level training routines for OmniCoreX including: |
|
|
- Efficient training loops with mixed precision support |
|
|
- Advanced optimizer and scheduler setup |
|
|
- Checkpoint saving/restoring with state dict management |
|
|
- Gradient accumulation and clipping for large batch training |
|
|
- Multi-device and distributed training ready |
|
|
- Extensive logging and real-time progress tracking |
|
|
""" |
|
|
|
|
|
import os
import time
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader


class Trainer:
    def __init__(self,
                 model: nn.Module,
                 train_loader: DataLoader,
                 valid_loader: Optional[DataLoader],
                 save_dir: str,
                 lr: float = 5e-5,
                 weight_decay: float = 0.01,
                 max_grad_norm: float = 1.0,
                 accumulation_steps: int = 1,
                 total_steps: int = 100000,
                 warmup_steps: int = 1000,
                 device: Optional[torch.device] = None,
                 mixed_precision: bool = True):
        """
        Initialize the training module.

        Args:
            model: OmniCoreX neural network model.
            train_loader: DataLoader for training data.
            valid_loader: Optional DataLoader for validation data.
            save_dir: Directory path in which to save checkpoints.
            lr: Learning rate for the optimizer.
            weight_decay: Weight decay coefficient.
            max_grad_norm: Maximum gradient norm for clipping.
            accumulation_steps: Number of batches over which to accumulate gradients before each optimizer step.
            total_steps: Total number of training steps used by the scheduler.
            warmup_steps: Number of learning-rate warm-up steps.
            device: Device to train on; defaults to CUDA if available, otherwise CPU.
            mixed_precision: Enable automatic mixed precision (AMP) for faster training and lower memory use.
        """
|
|
self.model = model |
|
|
self.train_loader = train_loader |
|
|
self.valid_loader = valid_loader |
|
|
self.save_dir = save_dir |
|
|
self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) |
|
|
self.lr = lr |
|
|
self.weight_decay = weight_decay |
|
|
self.max_grad_norm = max_grad_norm |
|
|
self.accumulation_steps = accumulation_steps |
|
|
self.total_steps = total_steps |
|
|
self.warmup_steps = warmup_steps |
|
|
self.mixed_precision = mixed_precision |
|
|
|
|
|
self.model.to(self.device) |
|
|
self.optimizer = AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay) |
|
|
|
|
|
def lr_lambda(current_step): |
|
|
if current_step < self.warmup_steps: |
|
|
return float(current_step) / float(max(1, self.warmup_steps)) |
|
|
return max( |
|
|
0.0, float(self.total_steps - current_step) / float(max(1, self.total_steps - self.warmup_steps)) |
|
|
) |
|
|
self.scheduler = LambdaLR(self.optimizer, lr_lambda) |
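        # The lambda above gives linear warmup to the base lr, then linear decay to
        # zero. Illustrative multipliers (a sketch using the defaults warmup_steps=1000
        # and total_steps=100000; each multiplier scales `lr`):
        #   step      0 -> 0.0   (start of warmup)
        #   step    500 -> 0.5   (halfway through warmup)
        #   step   1000 -> 1.0   (warmup complete)
        #   step  50500 -> 0.5   ((100000 - 50500) / 99000)
        #   step 100000 -> 0.0   (end of schedule)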
        self.scaler = GradScaler(enabled=mixed_precision)

        os.makedirs(self.save_dir, exist_ok=True)

    def save_checkpoint(self, step: int) -> None:
        """
        Saves model, optimizer, scheduler, and AMP scaler state dictionaries.

        Args:
            step: Current training step, used to tag the checkpoint file name.
        """
        checkpoint_path = os.path.join(self.save_dir, f"checkpoint_step_{step}.pt")
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "step": step,
        }, checkpoint_path)
        print(f"[Trainer] Checkpoint saved at step {step} to {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: str) -> int:
        """
        Loads model, optimizer, scheduler, and AMP scaler state from a checkpoint file.

        Args:
            checkpoint_path: Path to the checkpoint file.

        Returns:
            The training step to resume from.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        # Older checkpoints may lack scaler state; GradScaler.load_state_dict raises on
        # an empty dict, so only restore it when present.
        if "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
        step = checkpoint.get("step", 0)
        print(f"[Trainer] Loaded checkpoint from {checkpoint_path} at step {step}")
        return step
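
    # Example resume workflow (an illustrative sketch; the checkpoint path and the
    # model/dataloader objects are placeholders, not provided by this module):
    #
    #     trainer = Trainer(model, train_loader, valid_loader, save_dir="checkpoints")
    #     start_step = trainer.load_checkpoint("checkpoints/checkpoint_step_1000.pt")
    #     trainer.fit(epochs=1, start_step=start_step)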

    def train_epoch(self, start_step: int = 0) -> int:
        """
        Runs one full epoch of training with gradient accumulation and mixed precision.

        Args:
            start_step: Initial global step count.

        Returns:
            Updated global step count after the epoch.
        """
        self.model.train()
        step = start_step
        optimizer = self.optimizer
        scheduler = self.scheduler
        scaler = self.scaler
        acc_steps = self.accumulation_steps
        max_grad_norm = self.max_grad_norm
        loss_fn = nn.CrossEntropyLoss()

        running_loss = 0.0
        start_time = time.time()

        optimizer.zero_grad()

        for batch_idx, batch in enumerate(self.train_loader):
            inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            with autocast(enabled=self.mixed_precision):
                outputs = self.model(**inputs)

                if 'labels' in inputs:
                    # Token-level cross-entropy over the flattened logits and labels.
                    loss = loss_fn(outputs.view(-1, outputs.size(-1)), inputs['labels'].view(-1))
                else:
                    # Without labels, fall back to using the mean model output as the loss.
                    loss = outputs.mean()

            # Scale the loss so gradients average over the accumulation window.
            loss = loss / acc_steps
            scaler.scale(loss).backward()

            if (batch_idx + 1) % acc_steps == 0 or (batch_idx + 1) == len(self.train_loader):
                # Unscale before clipping so the norm is computed on the true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()
                step += 1

            running_loss += loss.item() * acc_steps
            elapsed = time.time() - start_time
            avg_loss = running_loss / (batch_idx + 1)
            print(f"Step {step:6d} | Loss: {avg_loss:.6f} | LR: {scheduler.get_last_lr()[0]:.8f} | Time: {elapsed:.2f}s")

        return step
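
    # Note on the assumed batch format (inferred from train_epoch/evaluate, not a formal
    # contract): each batch is a dict of tensors passed to the model as keyword
    # arguments, e.g. {"input_ids": ..., "labels": ...}; "labels" is optional, and
    # without it the mean model output is used as the loss.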

    def evaluate(self) -> Dict[str, float]:
        """
        Runs evaluation on the validation loader, if one was provided.

        Returns:
            Dictionary of evaluation metrics.
        """
        if self.valid_loader is None:
            print("[Trainer] No validation data provided for evaluation.")
            return {}

        self.model.eval()
        total_loss = 0.0
        count = 0
        loss_fn = nn.CrossEntropyLoss()

        with torch.no_grad():
            for batch in self.valid_loader:
                inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                outputs = self.model(**inputs)

                if 'labels' in inputs:
                    loss = loss_fn(outputs.view(-1, outputs.size(-1)), inputs['labels'].view(-1))
                    total_loss += loss.item()
                    count += 1

        avg_loss = total_loss / count if count > 0 else 0.0
        print(f"[Trainer] Validation Loss: {avg_loss:.6f}")
        return {"validation_loss": avg_loss}

    def fit(self,
            epochs: int,
            start_step: int = 0,
            checkpoint_interval: int = 1000,
            validate_interval: int = 1000):
        """
        Runs the full training process, including periodic validation and checkpointing.

        Args:
            epochs: Number of epochs to train.
            start_step: Step number to resume from.
            checkpoint_interval: Save a checkpoint once at least this many steps have
                elapsed since the last save (checked at epoch boundaries).
            validate_interval: Run validation once at least this many steps have
                elapsed since the last validation (checked at epoch boundaries).
        """
        global_step = start_step
        last_checkpoint_step = start_step
        last_validation_step = start_step
        for epoch in range(epochs):
            print(f"[Trainer] Starting epoch {epoch + 1}/{epochs}")
            global_step = self.train_epoch(global_step)

            if self.valid_loader is not None and global_step - last_validation_step >= validate_interval:
                self.evaluate()
                last_validation_step = global_step

            if global_step - last_checkpoint_step >= checkpoint_interval:
                self.save_checkpoint(global_step)
                last_checkpoint_step = global_step


if __name__ == "__main__":
    print("Trainer module loaded. Instantiate with model and dataloaders for training.")
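    # Minimal usage sketch (illustrative only; `OmniCoreXModel`, the dataset objects,
    # and the hyperparameter values are placeholder assumptions, not part of this module):
    #
    #     model = OmniCoreXModel(...)
    #     train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    #     valid_loader = DataLoader(valid_dataset, batch_size=8)
    #     trainer = Trainer(model, train_loader, valid_loader, save_dir="checkpoints",
    #                       lr=5e-5, accumulation_steps=4, total_steps=100000)
    #     trainer.fit(epochs=3, checkpoint_interval=1000, validate_interval=1000)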