|
|
""" |
|
|
Muon Optimizer for BitTransformerLM Extensions |
|
|
============================================== |
|
|
|
|
|
Implementation of the Muon optimizer with orthogonal momentum updates. |
|
|
Based on "Muon: MomentUm Orthogonalized by Newton-Schulz" research.
|
|
|
|
|
Key features: |
|
|
- Orthogonal momentum updates |
|
|
- Better convergence properties than Adam/AdamW |
|
|
- Memory efficient implementation |
|
|
- Compatible with BitTransformerLM's training infrastructure |
|
|
""" |
|
|
|
|
|
import math |
|
|
import torch |
|
|
from torch.optim.optimizer import Optimizer |
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
import warnings |
|
|
|
|
|
|
|
|
class Muon(Optimizer):
    """
    Muon optimizer with orthogonal momentum updates.

    Momentum is accumulated as usual and then periodically replaced by its
    nearest (semi-)orthogonal factor, computed either with a Newton-Schulz
    iteration or an SVD.  Orthogonalized updates keep the per-direction
    scale of the step uniform, which leads to more stable training dynamics.

    Args:
        params: Iterable of parameters to optimize (or param-group dicts)
        lr: Learning rate (default: 1e-3)
        momentum: Momentum factor in [0, 1] (default: 0.95)
        nesterov: Enable Nesterov momentum (default: False)
        backend: Backend for orthogonalization ('newtonschulz' or 'svd')
        update_period: Orthogonalize the buffer every this many steps
            (default: 1; must be >= 1)
        rank_deficiency_threshold: If the norm of the orthogonalized buffer
            falls below this fraction of the raw buffer norm, the buffer is
            considered rank deficient and is left un-orthogonalized
        eps: Small constant for numerical stability (default: 1e-8)
        weight_decay: Weight decay coefficient added to the gradient
            (default: 0.0)
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        momentum: float = 0.95,
        nesterov: bool = False,
        backend: str = "newtonschulz",
        update_period: int = 1,
        rank_deficiency_threshold: float = 1e-6,
        eps: float = 1e-8,
        weight_decay: float = 0.0,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= momentum <= 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if backend not in ["newtonschulz", "svd"]:
            raise ValueError(f"Invalid backend: {backend}")
        # `step % update_period` would raise ZeroDivisionError for 0 and
        # silently never orthogonalize for negative values.
        if update_period < 1:
            raise ValueError(f"Invalid update_period: {update_period}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            backend=backend,
            update_period=update_period,
            rank_deficiency_threshold=rank_deficiency_threshold,
            eps=eps,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    def _orthogonalize_newtonschulz(self, matrix: torch.Tensor, num_iterations: int = 5) -> torch.Tensor:
        """Return the nearest orthogonal factor of ``matrix`` via Newton-Schulz.

        Tensors with more than two dimensions are flattened to 2-D over the
        leading dimensions, orthogonalized, and reshaped back.
        """
        original_shape = matrix.shape
        if matrix.dim() > 2:
            # reshape (not view): the momentum buffer may be non-contiguous.
            matrix = matrix.reshape(-1, matrix.shape[-1])

        # The cubic Newton-Schulz iteration X <- X (1.5 I - 0.5 X^T X) only
        # converges to the orthogonal polar factor when every singular value
        # is below sqrt(3).  The Frobenius norm upper-bounds the spectral
        # norm, so pre-scaling by it guarantees convergence.
        X = matrix / (matrix.norm() + 1e-7)

        if X.shape[0] >= X.shape[1]:
            # Tall matrix: iterate using the smaller Gram matrix X^T X.
            eye = torch.eye(X.shape[1], device=X.device, dtype=X.dtype)
            for _ in range(num_iterations):
                X = X @ (1.5 * eye - 0.5 * (X.T @ X))
        else:
            # Wide matrix: use the (smaller) X X^T Gram matrix instead.
            eye = torch.eye(X.shape[0], device=X.device, dtype=X.dtype)
            for _ in range(num_iterations):
                X = (1.5 * eye - 0.5 * (X @ X.T)) @ X

        return X.reshape(original_shape)

    def _orthogonalize_svd(self, matrix: torch.Tensor) -> torch.Tensor:
        """Return the orthogonal polar factor U @ V^T of ``matrix`` via SVD."""
        original_shape = matrix.shape
        if matrix.dim() > 2:
            matrix = matrix.reshape(-1, matrix.shape[-1])

        try:
            U, _, Vt = torch.linalg.svd(matrix, full_matrices=False)
            return (U @ Vt).reshape(original_shape)
        except torch.linalg.LinAlgError:
            # Public exception type (torch._C._LinAlgError is private API).
            # SVD can fail to converge on ill-conditioned inputs.  Restore
            # the caller's original shape: `matrix` was flattened above.
            warnings.warn("SVD failed, falling back to Newton-Schulz")
            return self._orthogonalize_newtonschulz(matrix).reshape(original_shape)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    # Do the momentum arithmetic in float32 for stability.
                    grad = grad.float()

                state = self.state[p]

                if len(state) == 0:
                    state["step"] = 0
                    # Match the (possibly promoted) gradient dtype: an
                    # in-place float32 -> float16 accumulation would raise.
                    state["momentum_buffer"] = torch.zeros_like(
                        grad, memory_format=torch.preserve_format
                    )

                momentum_buffer = state["momentum_buffer"]
                state["step"] += 1

                if group["weight_decay"] != 0:
                    # Out-of-place add promotes dtypes safely.
                    grad = grad.add(p, alpha=group["weight_decay"])

                momentum_buffer.mul_(group["momentum"]).add_(grad)

                # Periodically replace the momentum with its orthogonal
                # polar factor; only meaningful for genuine matrices.
                if state["step"] % group["update_period"] == 0 and momentum_buffer.numel() > 1:
                    if momentum_buffer.dim() >= 2 and min(momentum_buffer.shape[-2:]) > 1:
                        if group["backend"] == "newtonschulz":
                            orthogonal_momentum = self._orthogonalize_newtonschulz(momentum_buffer)
                        else:
                            orthogonal_momentum = self._orthogonalize_svd(momentum_buffer)

                        # Full Frobenius norm (`.norm()`) yields a scalar for
                        # any rank, unlike matrix_norm which batches over the
                        # leading dims of >2-D buffers; eps keeps the ratio
                        # finite for an all-zero buffer.
                        rank_ratio = orthogonal_momentum.norm() / (
                            momentum_buffer.norm() + group["eps"]
                        )
                        if rank_ratio < group["rank_deficiency_threshold"]:
                            warnings.warn("Detected rank deficiency in momentum buffer")
                        else:
                            momentum_buffer.copy_(orthogonal_momentum)

                if group["nesterov"]:
                    update = grad.add(momentum_buffer, alpha=group["momentum"])
                else:
                    update = momentum_buffer

                # Cast back so the in-place update works on low-precision
                # parameters (no-op when dtypes already match).
                p.add_(update.to(p.dtype), alpha=-group["lr"])

        return loss
|
|
|
|
|
|
|
|
def configure_muon_optimizer(
    model: torch.nn.Module,
    lr: float = 1e-3,
    momentum: float = 0.95,
    weight_decay: float = 0.01,
    total_steps: Optional[int] = None,
    warmup_ratio: float = 0.1,
    nesterov: bool = False,
    backend: str = "newtonschulz",
    **muon_kwargs
) -> Tuple[Muon, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """
    Configure Muon optimizer with OneCycle learning rate schedule.

    Drop-in replacement for BitTransformerLM's ``configure_optimizer``,
    using Muon instead of AdamW.  Weight decay is applied only to matrix
    parameters (dim >= 2); vectors and scalars (biases, norms) are exempt.

    Args:
        model: PyTorch model to optimize
        lr: Peak learning rate
        momentum: Momentum factor for Muon
        weight_decay: Weight decay coefficient
        total_steps: Total training steps for OneCycle schedule
        warmup_ratio: Fraction of steps for warmup
        nesterov: Enable Nesterov momentum
        backend: Orthogonalization backend
        **muon_kwargs: Additional arguments for Muon optimizer

    Returns:
        Tuple of (optimizer, scheduler); scheduler is None unless
        ``total_steps`` is a positive integer.
    """
    # Partition trainable parameters: matrices get weight decay, 1-D
    # parameters (biases, norm scales) do not.
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        (decay if param.dim() >= 2 else no_decay).append(param)

    optimizer = Muon(
        [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=lr,
        momentum=momentum,
        nesterov=nesterov,
        backend=backend,
        **muon_kwargs
    )

    if total_steps is None or total_steps <= 0:
        return optimizer, None

    # cycle_momentum=False: Muon has no 'betas'/'momentum' cycling hooks
    # compatible with OneCycle's momentum scheduling.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        total_steps=total_steps,
        pct_start=warmup_ratio,
        anneal_strategy='cos',
        cycle_momentum=False,
        div_factor=25.0,
        final_div_factor=1e4,
    )
    return optimizer, scheduler
|
|
|
|
|
|
|
|
def create_muon_training_config(
    lr: float = 1e-3,
    momentum: float = 0.95,
    weight_decay: float = 0.01,
    backend: str = "newtonschulz",
    nesterov: bool = False,
    **kwargs
) -> Dict[str, Any]:
    """
    Create a training configuration dictionary for Muon optimizer.

    The returned dictionary can be passed to BitTransformerLM's training
    scripts to select Muon with a OneCycle schedule.

    Args:
        lr: Learning rate
        momentum: Momentum factor
        weight_decay: Weight decay coefficient
        backend: Orthogonalization backend
        nesterov: Enable Nesterov momentum
        **kwargs: Additional options merged into ``optimizer_config``
            (overriding the named arguments on key collision)

    Returns:
        Dictionary containing training configuration
    """
    optimizer_config: Dict[str, Any] = {
        "lr": lr,
        "momentum": momentum,
        "weight_decay": weight_decay,
        "backend": backend,
        "nesterov": nesterov,
    }
    optimizer_config.update(kwargs)

    return {
        "optimizer_type": "muon",
        "optimizer_config": optimizer_config,
        "scheduler_type": "onecycle",
    }
|
|
|
|
|
|
|
|
|
|
|
def integrate_with_bittransformerlm():
    """
    Example of how to integrate Muon optimizer with BitTransformerLM training.

    This is a documentation-only stub; it performs no work.

    Usage::

        from BTLM_Extensions.muon_optimizer import configure_muon_optimizer

        # Swap out the standard optimizer configuration
        optimizer, scheduler = configure_muon_optimizer(
            model, lr=1e-3, momentum=0.95, total_steps=1000
        )

        # Then drive the training loop as usual
        train_loop(model, data, optimizer=optimizer, scheduler=scheduler)
    """
    return None
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import torch.nn as nn

    # Smoke test: one forward/backward/step cycle on a tiny MLP.
    net = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 1),
    )

    optimizer, scheduler = configure_muon_optimizer(net, lr=1e-3, total_steps=100)

    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)

    loss = nn.functional.mse_loss(net(inputs), targets)
    loss.backward()

    optimizer.step()
    if scheduler:
        scheduler.step()

    print("Muon optimizer test completed successfully!")
    print(f"Loss: {loss.item():.4f}")