""" Lion Optimizer for BitTransformerLM Extensions ============================================== Implementation of the Lion optimizer (EvoLved Sign Momentum). Based on "Symbolic Discovery of Optimization Algorithms" research. Key features: - Sign-based momentum updates - Extremely memory efficient (only stores momentum) - Often outperforms Adam/AdamW with larger learning rates - Compatible with BitTransformerLM's training infrastructure """ import torch from torch.optim.optimizer import Optimizer from typing import Any, Dict, List, Optional, Tuple, Union class Lion(Optimizer): """ Lion optimizer implementation. Lion uses the sign of the interpolated momentum for parameter updates, making it very memory efficient while maintaining competitive performance. Args: params: Iterable of parameters to optimize lr: Learning rate (default: 1e-4, typically needs to be smaller than Adam) betas: Coefficients for computing momentum (default: (0.9, 0.99)) weight_decay: Weight decay coefficient (default: 0.0) eps: Small constant for numerical stability (default: 1e-8) maximize: Whether to maximize the objective (default: False) """ def __init__( self, params, lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.0, eps: float = 1e-8, maximize: bool = False, ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= betas[0] < 1.0: raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") if not 0.0 <= betas[1] < 1.0: raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") defaults = dict( lr=lr, betas=betas, weight_decay=weight_decay, eps=eps, maximize=maximize, ) super().__init__(params, defaults) @torch.no_grad() def step(self, closure=None): """Perform a single optimization step.""" loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad if group["maximize"]: grad = -grad if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() state = self.state[p] # State initialization if len(state) == 0: state["momentum"] = torch.zeros_like(p, memory_format=torch.preserve_format) momentum = state["momentum"] beta1, beta2 = group["betas"] # Weight decay (applied to parameters, not gradients) if group["weight_decay"] != 0: p.mul_(1 - group["lr"] * group["weight_decay"]) # Interpolate between momentum and gradient # c_t = beta1 * m_{t-1} + (1 - beta1) * g_t interpolated = momentum.mul(beta1).add_(grad, alpha=1 - beta1) # Update parameters using sign of interpolated momentum # theta_t = theta_{t-1} - lr * sign(c_t) p.add_(torch.sign(interpolated), alpha=-group["lr"]) # Update momentum # m_t = beta2 * m_{t-1} + (1 - beta2) * g_t momentum.mul_(beta2).add_(grad, alpha=1 - beta2) return loss def configure_lion_optimizer( model: torch.nn.Module, lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.01, total_steps: Optional[int] = None, warmup_ratio: float = 0.1, **lion_kwargs ) -> Tuple[Lion, Optional[torch.optim.lr_scheduler._LRScheduler]]: """ Configure Lion optimizer with OneCycle learning rate schedule. This function provides a drop-in replacement for BitTransformerLM's configure_optimizer function, using Lion instead of AdamW. Note: Lion typically works well with learning rates about 3-10x smaller than Adam/AdamW, but higher weight decay (0.01-0.1). Args: model: PyTorch model to optimize lr: Peak learning rate (typically smaller than Adam) betas: Beta coefficients for momentum computation weight_decay: Weight decay coefficient (can be higher than Adam) total_steps: Total training steps for OneCycle schedule warmup_ratio: Fraction of steps for warmup **lion_kwargs: Additional arguments for Lion optimizer Returns: Tuple of (optimizer, scheduler) """ # Filter parameters that need weight decay decay_params = [] no_decay_params = [] for name, param in model.named_parameters(): if not param.requires_grad: continue # Apply weight decay to weights but not biases/norms if param.dim() >= 2: decay_params.append(param) else: no_decay_params.append(param) param_groups = [ {"params": decay_params, "weight_decay": weight_decay}, {"params": no_decay_params, "weight_decay": 0.0}, ] optimizer = Lion( param_groups, lr=lr, betas=betas, **lion_kwargs ) scheduler = None if total_steps is not None and total_steps > 0: scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=lr, total_steps=total_steps, pct_start=warmup_ratio, anneal_strategy='cos', cycle_momentum=False, # Lion doesn't use cycling momentum div_factor=25.0, final_div_factor=1e4, ) return optimizer, scheduler def create_lion_training_config( lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.01, **kwargs ) -> Dict[str, Any]: """ Create a training configuration dictionary for Lion optimizer. This can be used with BitTransformerLM's training scripts by passing the config to the training loop. Args: lr: Learning rate betas: Beta coefficients for momentum weight_decay: Weight decay coefficient **kwargs: Additional configuration options Returns: Dictionary containing training configuration """ config = { "optimizer_type": "lion", "optimizer_config": { "lr": lr, "betas": betas, "weight_decay": weight_decay, **kwargs }, "scheduler_type": "onecycle", } return config class AdaptiveLion(Lion): """ Enhanced Lion optimizer with adaptive learning rate scaling. This variant automatically adjusts the learning rate based on the magnitude of gradients and momentum, potentially improving stability. """ def __init__( self, params, lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.0, eps: float = 1e-8, maximize: bool = False, adaptive_scale: float = 0.1, min_scale: float = 0.01, max_scale: float = 10.0, ): """ Args: adaptive_scale: Scaling factor for adaptive adjustment min_scale: Minimum learning rate scale max_scale: Maximum learning rate scale """ self.adaptive_scale = adaptive_scale self.min_scale = min_scale self.max_scale = max_scale super().__init__(params, lr, betas, weight_decay, eps, maximize) @torch.no_grad() def step(self, closure=None): """Perform optimization step with adaptive scaling.""" loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad if group["maximize"]: grad = -grad if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() state = self.state[p] if len(state) == 0: state["momentum"] = torch.zeros_like(p, memory_format=torch.preserve_format) state["step"] = 0 momentum = state["momentum"] state["step"] += 1 beta1, beta2 = group["betas"] # Adaptive learning rate based on gradient magnitude grad_norm = grad.norm().item() momentum_norm = momentum.norm().item() # Scale learning rate based on gradient/momentum ratio if momentum_norm > 1e-8: scale = 1.0 + self.adaptive_scale * (grad_norm / momentum_norm - 1.0) scale = torch.clamp(torch.tensor(scale), self.min_scale, self.max_scale).item() else: scale = 1.0 adaptive_lr = group["lr"] * scale # Weight decay if group["weight_decay"] != 0: p.mul_(1 - adaptive_lr * group["weight_decay"]) # Lion update with adaptive learning rate interpolated = momentum.mul(beta1).add_(grad, alpha=1 - beta1) p.add_(torch.sign(interpolated), alpha=-adaptive_lr) momentum.mul_(beta2).add_(grad, alpha=1 - beta2) return loss def configure_adaptive_lion_optimizer( model: torch.nn.Module, lr: float = 1e-4, adaptive_scale: float = 0.1, **kwargs ) -> Tuple[AdaptiveLion, Optional[torch.optim.lr_scheduler._LRScheduler]]: """Configure AdaptiveLion optimizer with learning rate scheduling.""" # Similar to configure_lion_optimizer but with AdaptiveLion decay_params = [] no_decay_params = [] for name, param in model.named_parameters(): if not param.requires_grad: continue if param.dim() >= 2: decay_params.append(param) else: no_decay_params.append(param) param_groups = [ {"params": decay_params, "weight_decay": kwargs.get("weight_decay", 0.01)}, {"params": no_decay_params, "weight_decay": 0.0}, ] optimizer = AdaptiveLion( param_groups, lr=lr, adaptive_scale=adaptive_scale, **{k: v for k, v in kwargs.items() if k != "weight_decay"} ) scheduler = None total_steps = kwargs.get("total_steps") if total_steps is not None and total_steps > 0: scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=lr, total_steps=total_steps, pct_start=kwargs.get("warmup_ratio", 0.1), anneal_strategy='cos', cycle_momentum=False, div_factor=25.0, final_div_factor=1e4, ) return optimizer, scheduler # Example usage and integration helpers def integrate_with_bittransformerlm(): """ Example of how to integrate Lion optimizer with BitTransformerLM training. Usage: from BTLM_Extensions.lion_optimizer import configure_lion_optimizer # Replace the standard optimizer configuration # Note: Lion typically needs smaller learning rates than Adam optimizer, scheduler = configure_lion_optimizer( model, lr=1e-4, weight_decay=0.01, total_steps=1000 ) # Use in training loop train_loop(model, data, optimizer=optimizer, scheduler=scheduler) # For adaptive version: from BTLM_Extensions.lion_optimizer import configure_adaptive_lion_optimizer optimizer, scheduler = configure_adaptive_lion_optimizer( model, lr=1e-4, adaptive_scale=0.1, total_steps=1000 ) """ pass if __name__ == "__main__": # Simple test of the optimizer import torch.nn as nn model = nn.Sequential( nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1) ) print("Testing standard Lion optimizer...") optimizer, scheduler = configure_lion_optimizer(model, lr=1e-4, total_steps=100) # Simple training step x = torch.randn(32, 10) y = torch.randn(32, 1) pred = model(x) loss = nn.functional.mse_loss(pred, y) initial_loss = loss.item() loss.backward() optimizer.step() if scheduler: scheduler.step() print(f"Initial loss: {initial_loss:.4f}") # Test adaptive version print("Testing Adaptive Lion optimizer...") model2 = nn.Sequential( nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1) ) optimizer2, scheduler2 = configure_adaptive_lion_optimizer( model2, lr=1e-4, adaptive_scale=0.1, total_steps=100 ) pred2 = model2(x) loss2 = nn.functional.mse_loss(pred2, y) loss2.backward() optimizer2.step() if scheduler2: scheduler2.step() print("Lion optimizers test completed successfully!") print(f"Standard Lion loss: {initial_loss:.4f}") print(f"Adaptive Lion loss: {loss2.item():.4f}")