"""
Advanced training utilities: gradient clipping, LR finder, batch size finder, etc.
"""

import copy
import logging
import math
from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)


def clip_gradients(
    model: nn.Module,
    max_norm: float = 1.0,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
) -> float:
    """
    Clip gradients to prevent explosion.

    Thin wrapper around ``torch.nn.utils.clip_grad_norm_``: the gradients of
    ``model.parameters()`` are rescaled in place so their combined norm is at
    most ``max_norm``.

    Args:
        model: Model with gradients
        max_norm: Maximum gradient norm
        norm_type: Type of norm (2.0 for L2, float('inf') for max norm)
        error_if_nonfinite: Raise error if gradients are non-finite

    Returns:
        Total gradient norm before clipping, as a Python float
    """
    total_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=max_norm,
        norm_type=norm_type,
        error_if_nonfinite=error_if_nonfinite,
    )
    norm_value = total_norm.item()
    if norm_value > max_norm:
        logger.debug(f"Gradients clipped: {norm_value:.4f} -> {max_norm:.4f}")
    return norm_value


def _extract_images(batch):
    """Return the model input: dict key "images"/"image", else the first item of a sequence batch."""
    if isinstance(batch, dict):
        return batch.get("images", batch.get("image"))
    return batch[0]


def _extract_targets(batch):
    """Return the target: dict key "poses_target"/"target", else the second item of a sequence batch."""
    if isinstance(batch, dict):
        return batch.get("poses_target", batch.get("target"))
    return batch[1]


def _forward(model: nn.Module, images):
    """Call ``model.inference(images)`` when the model defines it, else ``model(images)``."""
    return model.inference(images) if hasattr(model, "inference") else model(images)


def _cuda_bf16_supported() -> bool:
    """True only when CUDA is available AND the current device reports BF16 support.

    Checking availability first avoids ``torch.cuda.is_bf16_supported`` touching
    the CUDA runtime on machines without a GPU.
    """
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


def find_learning_rate(
    model: nn.Module,
    train_loader: DataLoader,
    loss_fn: Callable,
    optimizer_class: type = torch.optim.AdamW,
    min_lr: float = 1e-8,
    max_lr: float = 1.0,
    num_steps: int = 100,
    smooth: float = 0.05,
    restore_state: bool = True,
) -> dict:
    """
    Find a good learning rate using the learning rate range test.

    Based on: https://arxiv.org/abs/1506.01186

    The LR grows exponentially from ``min_lr`` (first step) to ``max_lr``
    (last step). One optimizer step is taken per LR on successive batches,
    cycling the loader if it runs out. The test stops early on a non-finite
    loss, or once the smoothed loss exceeds 10x the minimum over all but the
    last 10 recorded steps.

    Args:
        model: Model to train (put in train mode; batches are NOT moved to a device)
        train_loader: Iterable of batches; either dicts with "images"/"image" and
            "poses_target"/"target" keys, or sequences ``(inputs, targets, ...)``
        loss_fn: Called as ``loss_fn(output, targets)``; must return a scalar tensor
        optimizer_class: Optimizer class, constructed as ``optimizer_class(params, lr=min_lr)``
        min_lr: Minimum learning rate to test (must be > 0)
        max_lr: Maximum learning rate to test (used on the last step)
        num_steps: Number of steps to run (must be >= 1)
        smooth: Weight of the newest raw loss in the exponential moving average
        restore_state: If True (default), reload the model's ``state_dict`` from
            before the test, so the diverged weights from large LRs are discarded

    Returns:
        Dict with:
            - lrs: List of learning rates tested
            - losses: Smoothed loss recorded at each LR
            - best_lr: LR at which the smoothed loss was lowest
            - min_loss: That lowest smoothed loss (inf if no step completed)

    Raises:
        ValueError: If ``num_steps < 1``, ``min_lr <= 0`` or ``train_loader`` is empty
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if min_lr <= 0:
        raise ValueError(f"min_lr must be > 0, got {min_lr}")

    model.train()
    # Snapshot weights and buffers so the test does not leave the model in the
    # state produced by the largest (typically divergent) learning rates.
    saved_state = copy.deepcopy(model.state_dict()) if restore_state else None

    lrs = []
    losses = []

    # Exponential schedule: step i uses min_lr * lr_mult**i, so the final step
    # (i = num_steps - 1) lands exactly on max_lr.
    lr_mult = (max_lr / min_lr) ** (1.0 / max(num_steps - 1, 1))

    optimizer = optimizer_class(model.parameters(), lr=min_lr)

    data_iter = iter(train_loader)
    try:
        batch = next(data_iter)
    except StopIteration:
        raise ValueError("train_loader yielded no batches") from None

    best_lr = min_lr
    min_loss = float("inf")

    logger.info("Starting learning rate finder...")

    try:
        for step in range(num_steps):
            current_lr = min_lr * (lr_mult**step)
            for param_group in optimizer.param_groups:
                param_group["lr"] = current_lr

            optimizer.zero_grad()
            output = _forward(model, _extract_images(batch))
            loss = loss_fn(output, _extract_targets(batch))
            loss_value = loss.item()

            # A NaN/inf loss would never trip the explosion check below (NaN
            # comparisons are False), so stop explicitly before stepping.
            if not math.isfinite(loss_value):
                logger.warning(f"Non-finite loss at LR={current_lr:.2e}, stopping")
                break

            loss.backward()
            optimizer.step()

            lrs.append(current_lr)
            if losses:
                loss_value = smooth * loss_value + (1 - smooth) * losses[-1]
            losses.append(loss_value)

            if loss_value < min_loss:
                min_loss = loss_value
                best_lr = current_lr

            # Stop if loss explodes
            if step > 10 and loss_value > 10 * min(losses[: step - 10]):
                logger.warning(f"Loss exploded at LR={current_lr:.2e}, stopping")
                break

            if step + 1 < num_steps:
                try:
                    batch = next(data_iter)
                except StopIteration:
                    data_iter = iter(train_loader)
                    batch = next(data_iter)
    finally:
        if saved_state is not None:
            model.load_state_dict(saved_state)

    logger.info(f"LR finder complete. Recommended LR: {best_lr:.2e}")

    return {
        "lrs": lrs,
        "losses": losses,
        "best_lr": best_lr,
        "min_loss": min_loss,
    }


def _try_batch_size(model: nn.Module, dataset, batch_size: int, device) -> bool:
    """Run one forward/backward pass on the first batch of ``batch_size``; False on OOM.

    Kept in its own function so the activations and autograd graph of a failed
    attempt are released when it returns, before the caller empties the CUDA cache.
    Non-OOM RuntimeErrors are re-raised.
    """
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,  # Single process for testing
    )
    try:
        images = _extract_images(next(iter(dataloader))).to(device)
        output = _forward(model, images)
        # Dummy loss: only memory use matters here, not the value.
        if isinstance(output, dict):
            loss = sum(v.mean() for v in output.values() if isinstance(v, torch.Tensor))
        else:
            loss = output.mean()
        loss.backward()
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        return False
    finally:
        model.zero_grad(set_to_none=True)
    return True


def find_optimal_batch_size(
    model: nn.Module,
    dataset,
    loss_fn: Callable,
    device: str = "cuda",
    initial_batch_size: int = 1,
    max_batch_size: int = 64,
    factor: int = 2,
    tolerance: int = 3,
) -> dict:
    """
    Automatically find the largest batch size that fits in GPU memory.

    Geometric search: starting at ``initial_batch_size``, each size is probed
    ``tolerance`` times (forward + dummy backward on the dataset's first batch),
    then multiplied by ``factor`` (capped at ``max_batch_size``). The search stops
    at the first OOM after a success, or after ``max_batch_size`` passes. If
    nothing has fit yet, the size is divided by ``factor`` and retried, down to 1.

    Args:
        model: Model to test (moved to ``device`` and put in train mode)
        dataset: Dataset to use
        loss_fn: Currently unused; the probe backpropagates the mean of the output
        device: Device to use (e.g. "cuda", "cuda:0", "cpu" or a ``torch.device``)
        initial_batch_size: Starting batch size (>= 1, <= max_batch_size)
        max_batch_size: Maximum batch size to try
        factor: Multiplicative factor for increases/decreases (>= 2)
        tolerance: Number of successful runs at a size before increasing it

    Returns:
        Dict with:
            - optimal_batch_size: Largest size that succeeded (0 if none did)
            - max_tested: Batch size probed last
            - initial_batch_size: The starting size passed in

    Raises:
        ValueError: On invalid ``initial_batch_size``/``max_batch_size``/``factor``
    """
    if initial_batch_size < 1:
        raise ValueError(f"initial_batch_size must be >= 1, got {initial_batch_size}")
    if initial_batch_size > max_batch_size:
        raise ValueError(
            f"initial_batch_size ({initial_batch_size}) exceeds max_batch_size ({max_batch_size})"
        )
    if factor < 2:
        raise ValueError(f"factor must be >= 2, got {factor}")

    model = model.to(device)
    model.train()
    use_cuda = torch.device(device).type == "cuda"

    current_batch_size = initial_batch_size
    successful_runs = 0
    max_successful_batch = 0
    # Smallest size known to OOM; growth never retries it or anything larger.
    failed_batch_size = None

    logger.info("Starting automatic batch size finder...")

    while True:
        fits = _try_batch_size(model, dataset, current_batch_size, device)
        if use_cuda:
            torch.cuda.empty_cache()

        if not fits:
            logger.warning(f"✗ Batch size {current_batch_size} failed (OOM)")
            failed_batch_size = current_batch_size
            if max_successful_batch > 0 or current_batch_size == 1:
                # We found the limit (or even batch size 1 does not fit).
                break
            # Nothing has fit yet: shrink and retry.
            current_batch_size = max(1, current_batch_size // factor)
            successful_runs = 0
            continue

        successful_runs += 1
        max_successful_batch = max(max_successful_batch, current_batch_size)
        logger.info(f"✓ Batch size {current_batch_size} works")

        if successful_runs < tolerance:
            continue

        next_batch_size = min(current_batch_size * factor, max_batch_size)
        # Stop once the cap is reached (the old code looped forever here) or when
        # the next size is already known to OOM.
        if next_batch_size == current_batch_size or (
            failed_batch_size is not None and next_batch_size >= failed_batch_size
        ):
            break
        logger.info(f"Increasing batch size: {current_batch_size} -> {next_batch_size}")
        current_batch_size = next_batch_size
        successful_runs = 0

    if max_successful_batch == 0:
        logger.warning("No batch size fits in memory, not even 1")
    logger.info(f"Optimal batch size: {max_successful_batch}")

    return {
        "optimal_batch_size": max_successful_batch,
        "max_tested": current_batch_size,
        "initial_batch_size": initial_batch_size,
    }


def get_bf16_autocast_context(enable: bool = True):
    """
    Get a CUDA autocast context for BF16 (bfloat16) mixed precision.

    BF16 is better than FP16 for training stability while maintaining speed.
    Falls back to FP16 autocast when the GPU lacks BF16 support. When CUDA is
    not available at all, returns a disabled context.

    Args:
        enable: Whether to enable BF16

    Returns:
        ``torch.autocast`` context manager for device type "cuda"
    """
    if not enable:
        return torch.autocast(device_type="cuda", enabled=False)

    if not torch.cuda.is_available():
        logger.warning("CUDA not available, autocast disabled")
        return torch.autocast(device_type="cuda", enabled=False)

    # Check if BF16 is supported
    if not _cuda_bf16_supported():
        logger.warning("BF16 not supported on this GPU, falling back to FP16")
        return torch.autocast(device_type="cuda", enabled=True, dtype=torch.float16)

    return torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16)


def enable_bf16_training(model: nn.Module) -> nn.Module:
    """
    Convert model to use BF16 for training.

    All floating-point parameters and buffers are cast in place. Without CUDA
    BF16 support (including no CUDA at all), the model is cast to FP16 via
    ``model.half()`` instead.

    Args:
        model: Model to convert

    Returns:
        The same model, cast to BF16 (or FP16 on fallback)
    """
    if not _cuda_bf16_supported():
        logger.warning("BF16 not supported, using FP16 instead")
        return model.half()

    # Convert model parameters to BF16
    model = model.to(torch.bfloat16)
    logger.info("Model converted to BF16")
    return model