|
|
""" |
|
|
Advanced training utilities: gradient clipping, LR finder, batch size finder, etc. |
|
|
""" |
|
|
|
|
|
import logging
import math
from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def clip_gradients(
    model: nn.Module,
    max_norm: float = 1.0,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
) -> float:
    """
    Clip the model's gradients in place to at most ``max_norm``.

    Thin wrapper around ``torch.nn.utils.clip_grad_norm_`` that also logs
    when clipping actually occurred.

    Args:
        model: Model whose parameters carry gradients
        max_norm: Maximum allowed gradient norm
        norm_type: Norm order (2.0 for L2, float('inf') for max norm)
        error_if_nonfinite: Raise if the total norm is NaN/inf

    Returns:
        Total gradient norm measured before clipping
    """
    pre_clip_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=max_norm,
        norm_type=norm_type,
        error_if_nonfinite=error_if_nonfinite,
    ).item()

    # clip_grad_norm_ only rescales when the norm exceeded the threshold,
    # so this is exactly the "clipping happened" condition.
    if pre_clip_norm > max_norm:
        logger.debug(f"Gradients clipped: {pre_clip_norm:.4f} -> {max_norm:.4f}")

    return pre_clip_norm
|
|
|
|
|
|
|
|
def find_learning_rate(
    model: nn.Module,
    train_loader: DataLoader,
    loss_fn: Callable,
    optimizer_class: type = torch.optim.AdamW,
    min_lr: float = 1e-8,
    max_lr: float = 1.0,
    num_steps: int = 100,
    smooth: float = 0.05,
) -> dict:
    """
    Find a good learning rate using the LR range test.

    Runs up to ``num_steps`` optimization steps while exponentially
    increasing the LR from ``min_lr`` toward ``max_lr``, recording an
    exponentially-smoothed loss at each step. The recommended LR is the
    one with the lowest smoothed loss.

    Based on: https://arxiv.org/abs/1506.01186

    NOTE: this performs real optimizer steps, so model weights are mutated;
    reset/reload the model before actual training.

    Args:
        model: Model to probe (assumed to already be on the same device
            as the batches the loader yields — not moved here)
        train_loader: DataLoader for training batches
        loss_fn: Loss function taking (output, targets)
        optimizer_class: Optimizer class to instantiate
        min_lr: Minimum learning rate to test
        max_lr: Maximum learning rate to test
        num_steps: Number of steps to run
        smooth: Smoothing factor (weight of the newest loss)

    Returns:
        Dict with:
            - lrs: List of learning rates tested
            - losses: List of smoothed losses at each LR
            - best_lr: LR with the lowest smoothed loss
            - min_loss: The lowest smoothed loss observed
    """
    model.train()

    lrs = []
    losses = []

    # Multiplicative step so min_lr * lr_mult**num_steps == max_lr.
    lr_mult = (max_lr / min_lr) ** (1.0 / num_steps)

    optimizer = optimizer_class(model.parameters(), lr=min_lr)

    data_iter = iter(train_loader)
    batch = next(data_iter)

    best_lr = min_lr
    min_loss = float("inf")

    logger.info("Starting learning rate finder...")

    for step in range(num_steps):
        # Exponentially increase the LR each step.
        current_lr = min_lr * (lr_mult**step)
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr

        optimizer.zero_grad()

        # Support dict-style batches as well as (images, targets) tuples.
        if isinstance(batch, dict):
            images = batch.get("images", batch.get("image"))
            targets = batch.get("poses_target", batch.get("target"))
        else:
            images, targets = batch[0], batch[1]

        output = model.inference(images) if hasattr(model, "inference") else model(images)
        loss = loss_fn(output, targets)

        loss.backward()
        optimizer.step()

        lrs.append(current_lr)
        losses.append(loss.item())

        # Exponential smoothing to de-noise the loss curve.
        if step > 0:
            losses[-1] = smooth * losses[-1] + (1 - smooth) * losses[-2]

        # FIX: a NaN/inf loss compares False in both checks below, so the
        # original silently kept stepping on garbage — stop immediately.
        if not math.isfinite(losses[-1]):
            logger.warning(f"Non-finite loss at LR={current_lr:.2e}, stopping")
            break

        if losses[-1] < min_loss:
            min_loss = losses[-1]
            best_lr = current_lr

        # FIX: compare against the running minimum. The original used
        # min(losses[: step - 10]), which at step 11 only saw the first
        # entry and re-scanned the list every iteration.
        if step > 10 and losses[-1] > 10 * min_loss:
            logger.warning(f"Loss exploded at LR={current_lr:.2e}, stopping")
            break

        # Advance to the next batch, cycling the loader when exhausted.
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            batch = next(data_iter)

    logger.info(f"LR finder complete. Recommended LR: {best_lr:.2e}")

    return {
        "lrs": lrs,
        "losses": losses,
        "best_lr": best_lr,
        "min_loss": min_loss,
    }
|
|
|
|
|
|
|
|
def find_optimal_batch_size(
    model: nn.Module,
    dataset,
    loss_fn: Callable,
    device: str = "cuda",
    initial_batch_size: int = 1,
    max_batch_size: int = 64,
    factor: int = 2,
    tolerance: int = 3,
) -> dict:
    """
    Find the largest batch size that fits in GPU memory.

    Grows the batch size geometrically (by ``factor``) after ``tolerance``
    successful forward+backward passes, stopping at the first OOM or when
    ``max_batch_size`` is reached.

    Args:
        model: Model to test (moved to ``device`` and put in train mode)
        dataset: Dataset to draw test batches from
        loss_fn: Loss function (unused here; kept for interface stability)
        device: Device to test on (e.g. "cuda", "cuda:0", "cpu")
        initial_batch_size: Starting batch size
        max_batch_size: Maximum batch size to try
        factor: Multiplicative factor for increases
        tolerance: Successful runs required before increasing

    Returns:
        Dict with:
            - optimal_batch_size: largest batch size that worked
            - max_tested: last batch size attempted
            - initial_batch_size: the starting value
    """
    model = model.to(device)
    model.train()

    # FIX: the original compared device == "cuda", which missed "cuda:0" etc.
    on_cuda = str(device).startswith("cuda")

    current_batch_size = initial_batch_size
    successful_runs = 0
    max_successful_batch = initial_batch_size

    logger.info("Starting automatic batch size finder...")

    while current_batch_size <= max_batch_size:
        try:
            dataloader = DataLoader(
                dataset,
                batch_size=current_batch_size,
                shuffle=False,
                num_workers=0,
            )

            batch = next(iter(dataloader))

            # Support dict-style batches as well as tuple batches.
            if isinstance(batch, dict):
                images = batch.get("images", batch.get("image"))
            else:
                images = batch[0]

            images = images.to(device)

            output = model.inference(images) if hasattr(model, "inference") else model(images)

            # Reduce to a scalar so backward() works regardless of output shape.
            if isinstance(output, dict):
                loss = sum(v.mean() for v in output.values() if isinstance(v, torch.Tensor))
            else:
                loss = output.mean()

            # Backward pass included: gradients are where most OOMs happen.
            loss.backward()

            model.zero_grad()

            if on_cuda:
                torch.cuda.empty_cache()

            successful_runs += 1
            max_successful_batch = current_batch_size

            logger.info(f"✓ Batch size {current_batch_size} works")

            if successful_runs >= tolerance:
                # FIX: the original looped forever once max_batch_size kept
                # succeeding — min() left the size unchanged and the while
                # condition stayed true. Stop when we can no longer grow.
                if current_batch_size >= max_batch_size:
                    break
                old_size = current_batch_size
                current_batch_size = min(current_batch_size * factor, max_batch_size)
                successful_runs = 0
                logger.info(f"Increasing batch size: {old_size} -> {current_batch_size}")

        except RuntimeError as e:
            if "out of memory" in str(e):
                logger.warning(f"✗ Batch size {current_batch_size} failed (OOM)")

                if on_cuda:
                    torch.cuda.empty_cache()

                if current_batch_size > initial_batch_size:
                    # We already have a known-good smaller size.
                    break
                else:
                    # Even the initial size failed; report the fallback.
                    current_batch_size = max(1, current_batch_size // factor)
                    break
            else:
                raise

    logger.info(f"Optimal batch size: {max_successful_batch}")

    return {
        "optimal_batch_size": max_successful_batch,
        "max_tested": current_batch_size,
        "initial_batch_size": initial_batch_size,
    }
|
|
|
|
|
|
|
|
def get_bf16_autocast_context(enable: bool = True):
    """
    Get an autocast context for BF16 (bfloat16) mixed precision.

    BF16 keeps the FP32 exponent range, so it is more numerically stable
    for training than FP16 while retaining the speed benefit.

    Args:
        enable: Whether to enable mixed precision at all

    Returns:
        An autocast context manager (disabled, BF16, or FP16 fallback)
    """
    # torch.amp.autocast supersedes the deprecated torch.cuda.amp.autocast.
    if not enable:
        return torch.amp.autocast("cuda", enabled=False)

    # FIX: the original called torch.cuda.is_bf16_supported() without
    # checking availability, which can raise on machines without CUDA.
    if not torch.cuda.is_available():
        logger.warning("CUDA not available, autocast disabled")
        return torch.amp.autocast("cuda", enabled=False)

    if not torch.cuda.is_bf16_supported():
        logger.warning("BF16 not supported on this GPU, falling back to FP16")
        return torch.amp.autocast("cuda", enabled=True, dtype=torch.float16)

    return torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
|
|
|
|
|
|
|
|
def enable_bf16_training(model: nn.Module) -> nn.Module:
    """
    Convert a model's parameters to BF16 for low-precision training.

    Falls back to FP16 (``model.half()``) when BF16 is unavailable.

    Args:
        model: Model to convert (converted in place)

    Returns:
        The converted model (same object, returned for chaining)
    """
    # FIX: the original called torch.cuda.is_bf16_supported() without an
    # availability check, which can raise on machines without CUDA.
    if not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
        logger.warning("BF16 not supported, using FP16 instead")
        return model.half()

    model = model.to(torch.bfloat16)
    logger.info("Model converted to BF16")
    return model
|
|
|