3d_model / ylff /utils /training_utils.py
Azan
Clean deployment build (Squashed)
7a87926
"""
Advanced training utilities: gradient clipping, learning-rate range test (LR finder), automatic batch-size finder, and BF16 mixed-precision helpers.
"""
import copy
import logging
import math
from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# Module-level logger; the training helpers below report progress and warnings through it.
logger = logging.getLogger(__name__)
def clip_gradients(
    model: nn.Module,
    max_norm: float = 1.0,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
) -> float:
    """
    Clip the model's gradients in place so their total norm is at most ``max_norm``.

    Thin wrapper around ``torch.nn.utils.clip_grad_norm_`` that returns the
    pre-clipping norm as a Python float and logs (at DEBUG level) when clipping
    actually happened.

    Args:
        model: Model whose parameters' ``.grad`` tensors are clipped in place.
        max_norm: Maximum allowed total gradient norm.
        norm_type: Type of norm (2.0 for L2, float('inf') for max norm).
        error_if_nonfinite: Raise an error if the total gradient norm is NaN/inf.

    Returns:
        Total gradient norm (all gradients treated as one vector) before clipping.
    """
    total_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=max_norm,
        norm_type=norm_type,
        error_if_nonfinite=error_if_nonfinite,
    )
    # Convert to a Python float exactly once: every tensor -> host conversion
    # (comparison, formatting, .item()) forces a device sync on GPU, and this
    # runs every training step.
    norm_value = total_norm.item()
    if norm_value > max_norm:
        # Lazy %-args: no string formatting unless DEBUG logging is enabled.
        logger.debug("Gradients clipped: %.4f -> %.4f", norm_value, max_norm)
    return norm_value
def _move_to_device(obj, device):
    """Return ``obj`` with every tensor (also inside dicts, lists and plain tuples) moved to ``device``."""
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: _move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_move_to_device(value, device) for value in obj]
    # Namedtuples are left untouched: rebuilding them needs their field names.
    if isinstance(obj, tuple) and not hasattr(obj, "_fields"):
        return tuple(_move_to_device(value, device) for value in obj)
    return obj


def find_learning_rate(
    model: nn.Module,
    train_loader: DataLoader,
    loss_fn: Callable,
    optimizer_class: type = torch.optim.AdamW,
    min_lr: float = 1e-8,
    max_lr: float = 1.0,
    num_steps: int = 100,
    smooth: float = 0.05,
) -> dict:
    """
    Run a learning-rate range test and suggest a learning rate.

    The learning rate grows exponentially from ``min_lr`` (first step) to
    ``max_lr`` (last step) while the model is trained and the exponentially
    smoothed loss is recorded. The test stops early if the loss becomes
    non-finite, or if the smoothed loss exceeds 10x the lowest smoothed loss
    recorded before the last 10 steps. Afterwards the model's state_dict and
    its train/eval mode are restored, so the test leaves ``model`` unchanged.
    A deep copy of the state_dict is held for the duration, which temporarily
    doubles the memory used by the weights.

    Based on: https://arxiv.org/abs/1506.01186

    Args:
        model: Model to probe. Batch tensors are moved to the device of its
            first parameter.
        train_loader: DataLoader for training; restarted if it runs out of
            batches before ``num_steps`` steps.
        loss_fn: Called as ``loss_fn(output, targets)``; must return a scalar tensor.
        optimizer_class: Optimizer class, built as
            ``optimizer_class(model.parameters(), lr=min_lr)``.
        min_lr: First learning rate tested (must be > 0).
        max_lr: Last learning rate tested (must be > 0).
        num_steps: Number of steps / learning rates to test (must be >= 1).
        smooth: Weight of the newest loss in the moving average
            (``smoothed = smooth * loss + (1 - smooth) * previous_smoothed``).

    Returns:
        Dict with:
            - lrs: List of learning rates tested
            - losses: Smoothed loss at each tested LR
            - best_lr: LR at which the smoothed loss was lowest
            - min_loss: That lowest smoothed loss (``inf`` if no step completed)

    Raises:
        ValueError: If ``min_lr``/``max_lr`` are not positive, ``num_steps`` < 1,
            or ``train_loader`` yields no batches.
    """
    if min_lr <= 0 or max_lr <= 0:
        raise ValueError(f"min_lr and max_lr must be positive, got {min_lr} and {max_lr}")
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    # Exponential schedule: min_lr at step 0, exactly max_lr at step num_steps - 1.
    lr_mult = (max_lr / min_lr) ** (1.0 / (num_steps - 1)) if num_steps > 1 else 1.0

    data_iter = iter(train_loader)
    try:
        batch = next(data_iter)
    except StopIteration:
        raise ValueError("train_loader yielded no batches") from None

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # The range test really trains the model; snapshot weights and mode so the
    # caller gets the model back unchanged.
    saved_state = copy.deepcopy(model.state_dict())
    was_training = model.training

    lrs = []
    losses = []
    best_lr = min_lr
    min_loss = float("inf")

    logger.info("Starting learning rate finder...")
    model.train()
    try:
        optimizer = optimizer_class(model.parameters(), lr=min_lr)
        for step in range(num_steps):
            current_lr = min_lr * (lr_mult**step)
            for param_group in optimizer.param_groups:
                param_group["lr"] = current_lr

            if isinstance(batch, dict):
                images = batch.get("images", batch.get("image"))
                targets = batch.get("poses_target", batch.get("target"))
            else:
                images, targets = batch[0], batch[1]
            if device is not None:
                images = _move_to_device(images, device)
                targets = _move_to_device(targets, device)

            optimizer.zero_grad()
            if hasattr(model, "inference"):
                output = model.inference(images)
            else:
                output = model(images)
            loss = loss_fn(output, targets)
            loss_value = loss.item()
            # NaN never compares greater than anything, so the explosion check
            # below could not catch it; stop explicitly.
            if not math.isfinite(loss_value):
                logger.warning(f"Loss is non-finite at LR={current_lr:.2e}, stopping")
                break

            loss.backward()
            optimizer.step()

            if losses:
                smoothed = smooth * loss_value + (1 - smooth) * losses[-1]
            else:
                smoothed = loss_value
            lrs.append(current_lr)
            losses.append(smoothed)

            if smoothed < min_loss:
                min_loss = smoothed
                best_lr = current_lr

            # Stop once the smoothed loss exceeds 10x the best value seen
            # before the last 10 steps.
            if step > 10 and smoothed > 10 * min(losses[: step - 10]):
                logger.warning(f"Loss exploded at LR={current_lr:.2e}, stopping")
                break

            # Fetch the next batch only if another step follows; cycle the loader.
            if step + 1 < num_steps:
                try:
                    batch = next(data_iter)
                except StopIteration:
                    data_iter = iter(train_loader)
                    batch = next(data_iter)
    finally:
        model.load_state_dict(saved_state)
        model.train(was_training)

    logger.info(f"LR finder complete. Recommended LR: {best_lr:.2e}")
    return {
        "lrs": lrs,
        "losses": losses,
        "best_lr": best_lr,
        "min_loss": min_loss,
    }
def _try_batch_size(model: nn.Module, dataset, batch_size: int, device) -> bool:
    """Run one forward/backward pass at ``batch_size``; return False on OOM, re-raise other errors."""
    batch = images = output = loss = None
    try:
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,  # Single process for testing
        )
        batch = next(iter(loader))
        if isinstance(batch, dict):
            images = batch.get("images", batch.get("image"))
        else:
            images = batch[0]
        images = images.to(device)
        if hasattr(model, "inference"):
            output = model.inference(images)
        else:
            output = model(images)
        # Dummy loss: only memory usage matters here, not the value.
        if isinstance(output, dict):
            loss = sum(v.mean() for v in output.values() if isinstance(v, torch.Tensor))
        else:
            loss = output.mean()
        loss.backward()
        fits = True
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        fits = False
    finally:
        # Also clears partially accumulated grads when the backward pass OOMs.
        model.zero_grad(set_to_none=True)
    # Drop references to activations before releasing cached blocks; doing this
    # outside the ``except`` clause also lets the exception's traceback be freed.
    del batch, images, output, loss
    if torch.device(device).type == "cuda":
        torch.cuda.empty_cache()
    return fits


def find_optimal_batch_size(
    model: nn.Module,
    dataset,
    loss_fn: Callable,
    device: str = "cuda",
    initial_batch_size: int = 1,
    max_batch_size: int = 64,
    factor: int = 2,
    tolerance: int = 3,
) -> dict:
    """
    Find the largest batch size whose forward + backward pass fits in memory.

    Starting at ``initial_batch_size``, each size is run ``tolerance`` times.
    If every run fits, the size is multiplied by ``factor`` (capped at
    ``max_batch_size``). This is a geometric search, not a binary search.
    The search stops:
      - at the first out-of-memory (OOM) error once some size has fit, or
      - when no larger size is left to try.
    If nothing has fit yet, an OOM shrinks the size by ``factor`` and never
    retries that size or anything larger.

    Each trial uses the first batch of ``dataset`` and a dummy loss (mean of
    the model output).

    Args:
        model: Model to test (moved to ``device`` and put in train mode).
        dataset: Dataset to draw the trial batch from.
        loss_fn: Currently unused; kept for API compatibility.
        device: Device to run on, e.g. "cuda", "cuda:1" or "cpu".
        initial_batch_size: Starting batch size.
        max_batch_size: Maximum batch size to try.
        factor: Multiplicative factor for increases (and decreases after an
            initial OOM).
        tolerance: Number of successful runs at a size before increasing it.

    Returns:
        Dict with:
            - optimal_batch_size: Largest batch size that fit (``initial_batch_size``
              if it already exceeds ``max_batch_size`` and nothing was tried)
            - max_tested: Largest batch size attempted
            - initial_batch_size: The starting batch size

    Raises:
        RuntimeError: If even batch size 1 runs out of memory, or if the model
            raises a RuntimeError that is not an OOM.
    """
    model = model.to(device)
    model.train()

    current_batch_size = initial_batch_size
    upper_bound = max_batch_size  # sizes above this are known not to fit / not allowed
    successful_runs = 0
    max_successful_batch = initial_batch_size
    found_fit = False
    largest_tried = initial_batch_size

    logger.info("Starting automatic batch size finder...")
    while current_batch_size <= max_batch_size:
        largest_tried = max(largest_tried, current_batch_size)

        if _try_batch_size(model, dataset, current_batch_size, device):
            found_fit = True
            successful_runs += 1
            max_successful_batch = current_batch_size
            logger.info(f"✓ Batch size {current_batch_size} works")
            if successful_runs < tolerance:
                continue
            next_size = min(current_batch_size * factor, upper_bound)
            if next_size <= current_batch_size:
                # Nothing larger left to try (reached max_batch_size or a size known to OOM).
                break
            logger.info(f"Increasing batch size: {current_batch_size} -> {next_size}")
            current_batch_size = next_size
            successful_runs = 0
            continue

        logger.warning(f"✗ Batch size {current_batch_size} failed (OOM)")
        if found_fit:
            # A smaller size already fit, so this is the limit.
            break
        if current_batch_size <= 1:
            raise RuntimeError(
                "Batch size 1 runs out of memory; no working batch size found"
            )
        # Nothing has fit yet: never retry this size or larger, and shrink.
        upper_bound = current_batch_size - 1
        current_batch_size = max(1, min(upper_bound, current_batch_size // max(factor, 2)))
        successful_runs = 0

    logger.info(f"Optimal batch size: {max_successful_batch}")
    return {
        "optimal_batch_size": max_successful_batch,
        "max_tested": largest_tried,
        "initial_batch_size": initial_batch_size,
    }
def get_bf16_autocast_context(enable: bool = True):
    """
    Build a CUDA autocast context manager for BF16 mixed precision.

    BF16 keeps FP32's exponent range, which makes it more stable than FP16
    for training while maintaining speed.

    Args:
        enable: Whether to enable mixed precision at all.

    Returns:
        A ``torch.autocast(device_type="cuda", ...)`` context manager. It is:
          - disabled when ``enable`` is False or CUDA is not available;
          - BF16 when the current GPU supports it;
          - otherwise FP16 (FP16 training usually also needs a GradScaler).
    """
    if not enable:
        return torch.autocast(device_type="cuda", enabled=False)
    # Check CUDA availability first: depending on the PyTorch version,
    # is_bf16_supported() may try to initialise CUDA and raise without a GPU.
    if not torch.cuda.is_available():
        logger.warning("CUDA not available, mixed-precision autocast disabled")
        return torch.autocast(device_type="cuda", enabled=False)
    if not torch.cuda.is_bf16_supported():
        logger.warning("BF16 not supported on this GPU, falling back to FP16")
        return torch.autocast(device_type="cuda", dtype=torch.float16)
    return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def enable_bf16_training(model: nn.Module) -> nn.Module:
    """
    Cast a model's parameters and buffers to BF16 for low-precision training.

    The cast modifies ``model`` in place (``nn.Module.to`` / ``half``), and the
    same module is returned. Behaviour by hardware:
      - CUDA GPU without BF16 support: falls back to FP16 (``model.half()``).
      - No CUDA: casts to BF16, since PyTorch also supports bfloat16 on CPU.
        FP16 is the less suitable CPU fallback.

    Args:
        model: Model to convert.

    Returns:
        The converted model (BF16, or FP16 on CUDA GPUs lacking BF16).
    """
    # Only query BF16 support when CUDA exists: is_bf16_supported() either
    # raises or reports False without a GPU, which previously pushed CPU-only
    # runs into FP16.
    if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
        logger.warning("BF16 not supported, using FP16 instead")
        return model.half()
    model = model.to(torch.bfloat16)
    logger.info("Model converted to BF16")
    return model