""" Adafactor Optimizer for BitTransformerLM Extensions =================================================== Implementation of the Adafactor optimizer with memory-efficient factorization. Based on "Adafactor: Adaptive Learning Rates with Sublinear Memory Cost" research. Key features: - Factorized second moment estimates for memory efficiency - Automatic scaling of learning rates - Relative step size and clip threshold - Compatible with BitTransformerLM's training infrastructure """ import math import torch from torch.optim.optimizer import Optimizer from typing import Any, Dict, List, Optional, Tuple, Union class Adafactor(Optimizer): """ Adafactor optimizer implementation. Adafactor reduces memory usage by factorizing the second moment estimates for parameters with 2 or more dimensions, making it highly memory efficient for large transformer models. Args: params: Iterable of parameters to optimize lr: External learning rate (default: None, uses automatic scaling) eps2: Regularization constant for second moment (default: 1e-30) cliping_threshold: Threshold for adaptive clipping (default: 1.0) decay_rate: Coefficient used for computing running averages (default: -0.8) beta1: Coefficient used for computing running averages of gradient (default: None) weight_decay: Weight decay coefficient (default: 0.0) scale_parameter: If True, learning rate is scaled by root mean square of parameter (default: True) relative_step_size: If True, use relative step size (default: True) warmup_init: If True, warmup learning rate (default: False) """ def __init__( self, params, lr: Optional[float] = None, eps2: float = 1e-30, cliping_threshold: float = 1.0, decay_rate: float = -0.8, beta1: Optional[float] = None, weight_decay: float = 0.0, scale_parameter: bool = True, relative_step_size: bool = True, warmup_init: bool = False, ): if lr is not None and lr <= 0.0: raise ValueError(f"Invalid learning rate: {lr}") if weight_decay < 0.0: raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( lr=lr, eps2=eps2, cliping_threshold=cliping_threshold, decay_rate=decay_rate, beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, relative_step_size=relative_step_size, warmup_init=warmup_init, ) super().__init__(params, defaults) def _get_lr(self, param_group, param_state): """Compute learning rate for parameter group.""" min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) param_scale = 1.0 if param_group["scale_parameter"]: param_scale = max(param_group["eps2"], param_state["RMS"]) return param_scale * rel_step_sz def _get_options(self, param_group, param_shape): """Get optimization options for parameter.""" factored = len(param_shape) >= 2 use_first_moment = param_group["beta1"] is not None return factored, use_first_moment def _rms(self, tensor): """Root mean square.""" return tensor.norm(2) / (tensor.numel() ** 0.5) def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): """Approximation of exponential moving average of square of gradient.""" r_factor = ((exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)) .rsqrt_()) c_factor = ((exp_avg_sq_col).rsqrt()) return torch.mul(r_factor.unsqueeze(-1), c_factor.unsqueeze(0)) @torch.no_grad() def step(self, closure=None): """Perform a single optimization step.""" loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() state = self.state[p] grad_shape = grad.shape factored, use_first_moment = self._get_options(group, grad_shape) # State Initialization if len(state) == 0: state["step"] = 0 if use_first_moment: # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(grad).float() if factored: state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).float() state["exp_avg_sq_col"] = torch.zeros( grad_shape[:-2] + grad_shape[-1:]).float() else: state["exp_avg_sq"] = torch.zeros_like(grad).float() state["RMS"] = 0 p_data_fp32 = p.data if p.data.dtype in {torch.float16, torch.bfloat16}: p_data_fp32 = p_data_fp32.float() state["step"] += 1 state["RMS"] = self._rms(p_data_fp32) lr = group["lr"] if group["lr"] is None: lr = self._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) update = grad**2 + group["eps2"] if factored: exp_avg_sq_row = state["exp_avg_sq_row"] exp_avg_sq_col = state["exp_avg_sq_col"] exp_avg_sq_row.mul_(beta2t).add_( update.mean(dim=-1), alpha=1.0 - beta2t) exp_avg_sq_col.mul_(beta2t).add_( update.mean(dim=-2), alpha=1.0 - beta2t) update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update.mul_(grad) else: exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) update = exp_avg_sq.rsqrt().mul_(grad) update.div_(max(1.0, self._rms(update) / group["cliping_threshold"])) if use_first_moment: exp_avg = state["exp_avg"] exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"]) update = exp_avg if group["weight_decay"] != 0: p_data_fp32.mul_(1 - group["weight_decay"] * lr) p_data_fp32.add_(update, alpha=-lr) if p.data.dtype in {torch.float16, torch.bfloat16}: p.data.copy_(p_data_fp32) return loss def configure_adafactor_optimizer( model: torch.nn.Module, lr: Optional[float] = None, weight_decay: float = 0.0, total_steps: Optional[int] = None, warmup_ratio: float = 0.1, scale_parameter: bool = True, relative_step_size: bool = True, warmup_init: bool = False, cliping_threshold: float = 1.0, decay_rate: float = -0.8, beta1: Optional[float] = None, eps2: float = 1e-30, **adafactor_kwargs ) -> Tuple[Adafactor, Optional[torch.optim.lr_scheduler._LRScheduler]]: """ Configure Adafactor optimizer with optional learning rate scheduling. This function provides a drop-in replacement for BitTransformerLM's configure_optimizer function, using Adafactor instead of AdamW. Args: model: PyTorch model to optimize lr: External learning rate (None for automatic scaling) weight_decay: Weight decay coefficient total_steps: Total training steps for scheduling warmup_ratio: Fraction of steps for warmup scale_parameter: Whether to scale learning rate by parameter RMS relative_step_size: Whether to use relative step size warmup_init: Whether to use warmup initialization cliping_threshold: Threshold for adaptive clipping decay_rate: Decay rate for second moment estimates beta1: Coefficient for first moment (None to disable) eps2: Regularization constant **adafactor_kwargs: Additional arguments for Adafactor Returns: Tuple of (optimizer, scheduler) """ # Adafactor can handle all parameters in one group efficiently params = [p for p in model.parameters() if p.requires_grad] optimizer = Adafactor( params, lr=lr, weight_decay=weight_decay, scale_parameter=scale_parameter, relative_step_size=relative_step_size, warmup_init=warmup_init, cliping_threshold=cliping_threshold, decay_rate=decay_rate, beta1=beta1, eps2=eps2, **adafactor_kwargs ) scheduler = None # Adafactor has built-in learning rate scaling, but we can still use OneCycle if total_steps is not None and total_steps > 0 and lr is not None: scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=lr, total_steps=total_steps, pct_start=warmup_ratio, anneal_strategy='cos', cycle_momentum=False, # Adafactor doesn't use momentum cycling div_factor=25.0, final_div_factor=1e4, ) return optimizer, scheduler class AdafactorScheduler(torch.optim.lr_scheduler._LRScheduler): """ Custom scheduler for Adafactor with warmup and polynomial decay. This scheduler is specifically designed to work with Adafactor's relative step size feature. """ def __init__( self, optimizer: Adafactor, warmup_steps: int = 1000, total_steps: Optional[int] = None, min_lr_ratio: float = 0.1, polynomial_power: float = 1.0, last_epoch: int = -1, ): self.warmup_steps = warmup_steps self.total_steps = total_steps self.min_lr_ratio = min_lr_ratio self.polynomial_power = polynomial_power super().__init__(optimizer, last_epoch) def get_lr(self): step = self.last_epoch + 1 if step < self.warmup_steps: # Linear warmup return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs] if self.total_steps is None: # No decay after warmup return self.base_lrs # Polynomial decay progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps) progress = min(progress, 1.0) decay_factor = (1 - progress) ** self.polynomial_power decay_factor = max(decay_factor, self.min_lr_ratio) return [base_lr * decay_factor for base_lr in self.base_lrs] def configure_adafactor_with_scheduler( model: torch.nn.Module, lr: float = 1e-3, warmup_steps: int = 1000, total_steps: Optional[int] = None, weight_decay: float = 0.0, **kwargs ) -> Tuple[Adafactor, AdafactorScheduler]: """ Configure Adafactor optimizer with custom Adafactor scheduler. Args: model: PyTorch model to optimize lr: Base learning rate warmup_steps: Number of warmup steps total_steps: Total training steps weight_decay: Weight decay coefficient **kwargs: Additional arguments for Adafactor Returns: Tuple of (optimizer, scheduler) """ params = [p for p in model.parameters() if p.requires_grad] optimizer = Adafactor( params, lr=lr, weight_decay=weight_decay, relative_step_size=False, # We'll use external scheduler **kwargs ) scheduler = AdafactorScheduler( optimizer, warmup_steps=warmup_steps, total_steps=total_steps, ) return optimizer, scheduler def create_adafactor_training_config( lr: Optional[float] = None, weight_decay: float = 0.0, scale_parameter: bool = True, relative_step_size: bool = True, warmup_init: bool = False, **kwargs ) -> Dict[str, Any]: """ Create a training configuration dictionary for Adafactor optimizer. Args: lr: External learning rate (None for automatic) weight_decay: Weight decay coefficient scale_parameter: Whether to scale by parameter RMS relative_step_size: Whether to use relative step size warmup_init: Whether to use warmup initialization **kwargs: Additional configuration options Returns: Dictionary containing training configuration """ config = { "optimizer_type": "adafactor", "optimizer_config": { "lr": lr, "weight_decay": weight_decay, "scale_parameter": scale_parameter, "relative_step_size": relative_step_size, "warmup_init": warmup_init, **kwargs }, "scheduler_type": "adafactor_custom" if lr is None else "onecycle", } return config # Example usage and integration helpers def integrate_with_bittransformerlm(): """ Example of how to integrate Adafactor optimizer with BitTransformerLM training. Usage: from BTLM_Extensions.adafactor_optimizer import configure_adafactor_optimizer # Option 1: Use Adafactor with automatic learning rate scaling optimizer, scheduler = configure_adafactor_optimizer( model, lr=None, total_steps=1000 # lr=None enables auto-scaling ) # Option 2: Use Adafactor with fixed learning rate optimizer, scheduler = configure_adafactor_optimizer( model, lr=1e-3, total_steps=1000 ) # Option 3: Use Adafactor with custom scheduler from BTLM_Extensions.adafactor_optimizer import configure_adafactor_with_scheduler optimizer, scheduler = configure_adafactor_with_scheduler( model, lr=1e-3, warmup_steps=100, total_steps=1000 ) # Use in training loop train_loop(model, data, optimizer=optimizer, scheduler=scheduler) """ pass def analyze_memory_usage(model: torch.nn.Module) -> Dict[str, float]: """ Analyze memory usage comparison between optimizers. Args: model: PyTorch model to analyze Returns: Dictionary with memory usage estimates in MB """ param_count = sum(p.numel() for p in model.parameters() if p.requires_grad) param_bytes = param_count * 4 # Assume float32 # AdamW memory: parameters + gradients + 2 momentum states adamw_memory = param_bytes * 4 # Adafactor memory estimation adafactor_memory = param_bytes # parameters adafactor_memory += param_bytes # gradients # For factored parameters (2D), Adafactor stores row and column means factored_params = 0 unfactored_params = 0 for p in model.parameters(): if p.requires_grad: if len(p.shape) >= 2: factored_params += p.shape[0] + p.shape[1] # row + col means else: unfactored_params += p.numel() adafactor_memory += (factored_params + unfactored_params) * 4 # second moments return { "adamw_mb": adamw_memory / (1024 * 1024), "adafactor_mb": adafactor_memory / (1024 * 1024), "savings_mb": (adamw_memory - adafactor_memory) / (1024 * 1024), "savings_percent": ((adamw_memory - adafactor_memory) / adamw_memory) * 100, } if __name__ == "__main__": # Simple test of the optimizer import torch.nn as nn model = nn.Sequential( nn.Linear(100, 200), nn.ReLU(), nn.Linear(200, 50), nn.ReLU(), nn.Linear(50, 1) ) print("Testing Adafactor optimizer...") # Test with automatic learning rate optimizer, scheduler = configure_adafactor_optimizer( model, lr=None, total_steps=100 ) # Simple training step x = torch.randn(32, 100) y = torch.randn(32, 1) pred = model(x) loss = nn.functional.mse_loss(pred, y) initial_loss = loss.item() loss.backward() optimizer.step() if scheduler: scheduler.step() # Test with fixed learning rate optimizer2, scheduler2 = configure_adafactor_optimizer( model, lr=1e-3, total_steps=100 ) pred = model(x) loss = nn.functional.mse_loss(pred, y) loss.backward() optimizer2.step() if scheduler2: scheduler2.step() # Analyze memory usage memory_analysis = analyze_memory_usage(model) print("Adafactor optimizer test completed successfully!") print(f"Initial loss: {initial_loss:.4f}") print(f"Final loss: {loss.item():.4f}") print(f"Memory analysis: {memory_analysis}")