""" Muon Optimizer for BitTransformerLM Extensions ============================================== Implementation of the Muon optimizer with orthogonal momentum updates. Based on "Muon: Momentum Orthogonalized by Newton's method" research. Key features: - Orthogonal momentum updates - Better convergence properties than Adam/AdamW - Memory efficient implementation - Compatible with BitTransformerLM's training infrastructure """ import math import torch from torch.optim.optimizer import Optimizer from typing import Any, Dict, List, Optional, Tuple, Union import warnings class Muon(Optimizer): """ Muon optimizer with orthogonal momentum updates. This implementation provides momentum updates that are orthogonalized using Newton's method, leading to more stable training dynamics. Args: params: Iterable of parameters to optimize lr: Learning rate (default: 1e-3) momentum: Momentum factor (default: 0.95) nesterov: Enable Nesterov momentum (default: False) backend: Backend for orthogonalization ('newtonschulz' or 'svd') update_period: Period for updating orthogonalization (default: 1) rank_deficiency_threshold: Threshold for rank deficiency detection eps: Small constant for numerical stability (default: 1e-8) weight_decay: Weight decay coefficient (default: 0.0) """ def __init__( self, params, lr: float = 1e-3, momentum: float = 0.95, nesterov: bool = False, backend: str = "newtonschulz", update_period: int = 1, rank_deficiency_threshold: float = 1e-6, eps: float = 1e-8, weight_decay: float = 0.0, ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= momentum <= 1.0: raise ValueError(f"Invalid momentum value: {momentum}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") if backend not in ["newtonschulz", "svd"]: raise ValueError(f"Invalid backend: {backend}") defaults = dict( lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, update_period=update_period, rank_deficiency_threshold=rank_deficiency_threshold, eps=eps, weight_decay=weight_decay, ) super().__init__(params, defaults) def _orthogonalize_newtonschulz(self, matrix: torch.Tensor, num_iterations: int = 5) -> torch.Tensor: """Orthogonalize matrix using Newton-Schulz iteration.""" # Handle different shapes original_shape = matrix.shape if matrix.dim() > 2: matrix = matrix.view(-1, matrix.shape[-1]) if matrix.shape[0] >= matrix.shape[1]: # Tall matrix - orthogonalize columns X = matrix.clone() for _ in range(num_iterations): A = X.T @ X X = X @ (1.5 * torch.eye(A.shape[0], device=A.device, dtype=A.dtype) - 0.5 * A) else: # Wide matrix - orthogonalize rows X = matrix.clone() for _ in range(num_iterations): A = X @ X.T X = (1.5 * torch.eye(A.shape[0], device=A.device, dtype=A.dtype) - 0.5 * A) @ X return X.view(original_shape) def _orthogonalize_svd(self, matrix: torch.Tensor) -> torch.Tensor: """Orthogonalize matrix using SVD decomposition.""" original_shape = matrix.shape if matrix.dim() > 2: matrix = matrix.view(-1, matrix.shape[-1]) try: U, _, Vt = torch.linalg.svd(matrix, full_matrices=False) orthogonal = U @ Vt return orthogonal.view(original_shape) except torch._C._LinAlgError: # Fallback to Newton-Schulz if SVD fails warnings.warn("SVD failed, falling back to Newton-Schulz") return self._orthogonalize_newtonschulz(matrix) @torch.no_grad() def step(self, closure=None): """Perform a single optimization step.""" loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() state = self.state[p] # State initialization if len(state) == 0: state["step"] = 0 state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format) momentum_buffer = state["momentum_buffer"] state["step"] += 1 # Weight decay if group["weight_decay"] != 0: grad = grad.add(p, alpha=group["weight_decay"]) # Apply momentum momentum_buffer.mul_(group["momentum"]).add_(grad) # Orthogonalize momentum every update_period steps if state["step"] % group["update_period"] == 0 and momentum_buffer.numel() > 1: # Only orthogonalize if we have sufficient dimensions if momentum_buffer.dim() >= 2 and min(momentum_buffer.shape[-2:]) > 1: if group["backend"] == "newtonschulz": orthogonal_momentum = self._orthogonalize_newtonschulz(momentum_buffer) else: orthogonal_momentum = self._orthogonalize_svd(momentum_buffer) # Check for rank deficiency rank_ratio = torch.linalg.matrix_norm(orthogonal_momentum) / torch.linalg.matrix_norm(momentum_buffer) if rank_ratio < group["rank_deficiency_threshold"]: warnings.warn("Detected rank deficiency in momentum buffer") else: momentum_buffer.copy_(orthogonal_momentum) # Apply Nesterov acceleration if enabled if group["nesterov"]: update = grad.add(momentum_buffer, alpha=group["momentum"]) else: update = momentum_buffer # Apply update p.add_(update, alpha=-group["lr"]) return loss def configure_muon_optimizer( model: torch.nn.Module, lr: float = 1e-3, momentum: float = 0.95, weight_decay: float = 0.01, total_steps: Optional[int] = None, warmup_ratio: float = 0.1, nesterov: bool = False, backend: str = "newtonschulz", **muon_kwargs ) -> Tuple[Muon, Optional[torch.optim.lr_scheduler._LRScheduler]]: """ Configure Muon optimizer with OneCycle learning rate schedule. This function provides a drop-in replacement for BitTransformerLM's configure_optimizer function, using Muon instead of AdamW. Args: model: PyTorch model to optimize lr: Peak learning rate momentum: Momentum factor for Muon weight_decay: Weight decay coefficient total_steps: Total training steps for OneCycle schedule warmup_ratio: Fraction of steps for warmup nesterov: Enable Nesterov momentum backend: Orthogonalization backend **muon_kwargs: Additional arguments for Muon optimizer Returns: Tuple of (optimizer, scheduler) """ # Filter parameters that need weight decay decay_params = [] no_decay_params = [] for name, param in model.named_parameters(): if not param.requires_grad: continue # Apply weight decay to weights but not biases/norms if param.dim() >= 2: decay_params.append(param) else: no_decay_params.append(param) param_groups = [ {"params": decay_params, "weight_decay": weight_decay}, {"params": no_decay_params, "weight_decay": 0.0}, ] optimizer = Muon( param_groups, lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, **muon_kwargs ) scheduler = None if total_steps is not None and total_steps > 0: scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=lr, total_steps=total_steps, pct_start=warmup_ratio, anneal_strategy='cos', cycle_momentum=False, # Muon handles momentum internally div_factor=25.0, final_div_factor=1e4, ) return optimizer, scheduler def create_muon_training_config( lr: float = 1e-3, momentum: float = 0.95, weight_decay: float = 0.01, backend: str = "newtonschulz", nesterov: bool = False, **kwargs ) -> Dict[str, Any]: """ Create a training configuration dictionary for Muon optimizer. This can be used with BitTransformerLM's training scripts by passing the config to the training loop. Args: lr: Learning rate momentum: Momentum factor weight_decay: Weight decay coefficient backend: Orthogonalization backend nesterov: Enable Nesterov momentum **kwargs: Additional configuration options Returns: Dictionary containing training configuration """ config = { "optimizer_type": "muon", "optimizer_config": { "lr": lr, "momentum": momentum, "weight_decay": weight_decay, "backend": backend, "nesterov": nesterov, **kwargs }, "scheduler_type": "onecycle", } return config # Example usage and integration helpers def integrate_with_bittransformerlm(): """ Example of how to integrate Muon optimizer with BitTransformerLM training. Usage: from BTLM_Extensions.muon_optimizer import configure_muon_optimizer # Replace the standard optimizer configuration optimizer, scheduler = configure_muon_optimizer( model, lr=1e-3, momentum=0.95, total_steps=1000 ) # Use in training loop train_loop(model, data, optimizer=optimizer, scheduler=scheduler) """ pass if __name__ == "__main__": # Simple test of the optimizer import torch.nn as nn model = nn.Sequential( nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1) ) optimizer, scheduler = configure_muon_optimizer(model, lr=1e-3, total_steps=100) # Simple training step x = torch.randn(32, 10) y = torch.randn(32, 1) pred = model(x) loss = nn.functional.mse_loss(pred, y) loss.backward() optimizer.step() if scheduler: scheduler.step() print("Muon optimizer test completed successfully!") print(f"Loss: {loss.item():.4f}")