#!/usr/bin/env python3
"""
Mixed Precision Training Utilities
This module provides utilities for mixed precision training using PyTorch's
automatic mixed precision (AMP) to improve training speed and reduce memory usage.
Author: Louis Chua Bean Chong
License: GPLv3
"""
import torch
import torch.nn as nn
# torch.amp supersedes the deprecated torch.cuda.amp namespace (PyTorch >= 2.3)
from torch.amp import autocast, GradScaler
from typing import Optional, Callable
class MixedPrecisionTrainer:
"""
Mixed precision training wrapper for improved performance.
This class provides automatic mixed precision training capabilities
that can significantly improve training speed and reduce memory usage
on compatible hardware (especially NVIDIA GPUs with Tensor Cores).
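    Example (illustrative sketch; ``model``, ``optimizer``, ``batch`` and
    ``targets`` stand in for an existing nn.Module, torch.optim.Optimizer,
    and data tensors):

        trainer = MixedPrecisionTrainer(model, optimizer)
        metrics = trainer.train_step(batch, targets)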
"""
def __init__(self,
model: nn.Module,
optimizer: torch.optim.Optimizer,
device: str = "auto",
dtype: torch.dtype = torch.float16,
enabled: bool = True):
"""
Initialize mixed precision trainer.
Args:
model: The model to train
optimizer: The optimizer to use
device: Device to use ("auto", "cpu", "cuda")
dtype: Precision dtype (float16, bfloat16)
enabled: Whether to enable mixed precision
"""
self.model = model
self.optimizer = optimizer
self.device = self._get_device(device)
self.dtype = dtype
self.enabled = enabled and self.device.type == "cuda"
        # Gradient scaler prevents float16 underflow by scaling the loss
        # (bfloat16 rarely needs it, but the scaler is harmless there)
        self.scaler = GradScaler("cuda") if self.enabled else None
# Move model to device
self.model.to(self.device)
print(f"Mixed Precision Training: {'Enabled' if self.enabled else 'Disabled'}")
print(f"Device: {self.device}")
print(f"Precision: {self.dtype}")
def _get_device(self, device: str) -> torch.device:
"""Get the appropriate device."""
if device == "auto":
if torch.cuda.is_available():
return torch.device("cuda")
else:
return torch.device("cpu")
else:
return torch.device(device)
def train_step(self,
batch: torch.Tensor,
targets: torch.Tensor,
loss_fn: Optional[Callable] = None) -> dict:
"""
Perform a single training step with mixed precision.
Args:
batch: Input batch
targets: Target batch
loss_fn: Optional custom loss function
Returns:
dict: Training metrics
"""
self.model.train()
self.optimizer.zero_grad()
# Move data to device
batch = batch.to(self.device)
targets = targets.to(self.device)
if self.enabled:
# Mixed precision forward pass
            with autocast("cuda", dtype=self.dtype):
if loss_fn is None:
# Use model's built-in loss computation
logits, loss = self.model(batch, targets)
else:
# Use custom loss function
logits = self.model(batch)
loss = loss_fn(logits, targets)
# Scaled backward pass
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# Standard precision training
if loss_fn is None:
logits, loss = self.model(batch, targets)
else:
logits = self.model(batch)
loss = loss_fn(logits, targets)
loss.backward()
self.optimizer.step()
        return {
            "loss": loss.item(),
            # Detach so callers do not keep the autograd graph alive
            "logits": logits.detach(),
            "scaler_scale": self.scaler.get_scale() if self.scaler else 1.0,
        }
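    # Illustrative variant, not part of the original API: gradient clipping
    # under AMP must unscale gradients before clip_grad_norm_, otherwise the
    # norm is computed on scaled gradients. "max_norm" is a hypothetical
    # parameter introduced for this sketch; the model is assumed to return
    # (logits, loss) as in train_step above.
    def train_step_clipped(self,
                           batch: torch.Tensor,
                           targets: torch.Tensor,
                           max_norm: float = 1.0) -> dict:
        """Training step with gradient clipping, sketched for AMP correctness."""
        self.model.train()
        self.optimizer.zero_grad()
        batch = batch.to(self.device)
        targets = targets.to(self.device)
        if self.enabled:
            with autocast("cuda", dtype=self.dtype):
                logits, loss = self.model(batch, targets)
            self.scaler.scale(loss).backward()
            # Unscale in place so clipping sees true gradient magnitudes;
            # scaler.step() will skip the redundant unscale afterwards
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            logits, loss = self.model(batch, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)
            self.optimizer.step()
        return {"loss": loss.item(), "logits": logits.detach()}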
def eval_step(self,
batch: torch.Tensor,
targets: torch.Tensor,
loss_fn: Optional[Callable] = None) -> dict:
"""
Perform a single evaluation step.
Args:
batch: Input batch
targets: Target batch
loss_fn: Optional custom loss function
Returns:
dict: Evaluation metrics
"""
self.model.eval()
# Move data to device
batch = batch.to(self.device)
targets = targets.to(self.device)
with torch.no_grad():
if self.enabled:
                with autocast("cuda", dtype=self.dtype):
if loss_fn is None:
logits, loss = self.model(batch, targets)
else:
logits = self.model(batch)
loss = loss_fn(logits, targets)
else:
if loss_fn is None:
logits, loss = self.model(batch, targets)
else:
logits = self.model(batch)
loss = loss_fn(logits, targets)
return {
"loss": loss.item(),
"logits": logits
}
def save_checkpoint(self, path: str, **kwargs):
"""Save model checkpoint with mixed precision state."""
checkpoint = {
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scaler_state_dict": self.scaler.state_dict() if self.scaler else None,
"dtype": self.dtype,
"enabled": self.enabled,
**kwargs
}
torch.save(checkpoint, path)
def load_checkpoint(self, path: str):
"""Load model checkpoint with mixed precision state."""
        # weights_only=False because the checkpoint carries non-tensor metadata
        # (dtype, enabled flag, arbitrary **kwargs); needed on PyTorch >= 2.6,
        # where weights_only defaults to True
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if self.scaler and checkpoint.get("scaler_state_dict"):
self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
return checkpoint
def enable_mixed_precision(model: nn.Module,
optimizer: torch.optim.Optimizer,
**kwargs) -> MixedPrecisionTrainer:
"""
Convenience function to enable mixed precision training.
Args:
model: The model to train
optimizer: The optimizer to use
**kwargs: Additional arguments for MixedPrecisionTrainer
Returns:
MixedPrecisionTrainer: Configured trainer
"""
return MixedPrecisionTrainer(model, optimizer, **kwargs)
def get_optimal_dtype() -> torch.dtype:
"""
Get the optimal dtype for mixed precision training.
Returns:
torch.dtype: Optimal dtype (bfloat16 for newer GPUs, float16 for older)
"""
if torch.cuda.is_available():
        # bfloat16 needs hardware support (Ampere / compute capability 8.0+)
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        else:
            return torch.float16
else:
return torch.float32
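if __name__ == "__main__":
    # Minimal smoke test, illustrative only: TinyModel is a hypothetical
    # stand-in whose forward returns (logits, loss) when targets are given,
    # matching the interface train_step/eval_step expect.
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(16, 4)

        def forward(self, x, targets=None):
            logits = self.fc(x)
            if targets is None:
                return logits
            loss = nn.functional.cross_entropy(logits, targets)
            return logits, loss

    model = TinyModel()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    trainer = enable_mixed_precision(model, optimizer, dtype=get_optimal_dtype())

    batch = torch.randn(8, 16)
    targets = torch.randint(0, 4, (8,))
    train_metrics = trainer.train_step(batch, targets)
    eval_metrics = trainer.eval_step(batch, targets)
    print(f"train loss: {train_metrics['loss']:.4f}")
    print(f"eval loss: {eval_metrics['loss']:.4f}")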