"""
Muon Optimizer Implementation for nanoKimi

Based on the Muon optimizer described in Kimi-K2 papers.
Combines momentum with adaptive learning rates for better convergence.
"""

import torch
from torch.optim import Optimizer
from typing import Any, Dict


class Muon(Optimizer):
    """
    Muon optimizer: a momentum-based optimizer with adaptive learning rates.

    This optimizer combines the benefits of momentum with adaptive learning rate
    scaling, designed specifically for large language model training.

    Args:
        params: iterable of parameters to optimize
        lr: learning rate (default: 1e-3)
        momentum: momentum factor (default: 0.9)
        weight_decay: weight decay (L2 penalty) (default: 0.01)
        eps: term added to the denominator to improve numerical stability (default: 1e-8)
        backend: backend to use, 'torch' or 'triton' (default: 'torch'); this
            implementation only records it in the param-group defaults
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.01,
        eps: float = 1e-8,
        backend: str = 'torch',
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            eps=eps,
            backend=backend,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            lr = group['lr']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if weight_decay != 0:
                    # L2 regularization folded into the gradient (coupled weight decay).
                    grad = grad.add(p, alpha=weight_decay)

                param_state = self.state[p]

                # Lazy state initialization on the first step for this parameter.
                if len(param_state) == 0:
                    param_state['step'] = 0
                    param_state['exp_avg'] = torch.zeros_like(p)
                    param_state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = param_state['exp_avg'], param_state['exp_avg_sq']
                param_state['step'] += 1

                # Exponential moving averages of the gradient and its square,
                # both driven by the single momentum coefficient.
                exp_avg.mul_(momentum).add_(grad, alpha=1 - momentum)
                exp_avg_sq.mul_(momentum).addcmul_(grad, grad, value=1 - momentum)

                # Bias corrections; identical here because both moments share momentum.
                step = param_state['step']
                bias_correction1 = 1 - momentum ** step
                bias_correction2 = 1 - momentum ** step

                denom = (exp_avg_sq / bias_correction2).sqrt_().add_(eps)
                step_size = lr / bias_correction1

                # p <- p - step_size * exp_avg / denom
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Clear gradients of all optimized parameters."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    if set_to_none:
                        p.grad = None
                    else:
                        if p.grad.grad_fn is not None:
                            p.grad.detach_()
                        else:
                            p.grad.requires_grad_(False)
                        p.grad.zero_()


def create_muon_optimizer(model: torch.nn.Module, config: Dict[str, Any]) -> Muon:
    """Create a Muon optimizer from a config dict with keys 'learning_rate',
    'momentum', 'weight_decay', 'eps', and optionally 'backend'."""
    return Muon(
        model.parameters(),
        lr=config['learning_rate'],
        momentum=config['momentum'],
        weight_decay=config['weight_decay'],
        eps=config['eps'],
        backend=config.get('backend', 'torch'),
    )
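

# Minimal usage sketch (not part of the optimizer itself): the model and the
# config values below are illustrative placeholders, not nanoKimi's actual
# training configuration.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(16, 4)  # stand-in for a real nanoKimi model
    config = {
        'learning_rate': 1e-3,
        'momentum': 0.9,
        'weight_decay': 0.01,
        'eps': 1e-8,
    }
    opt = create_muon_optimizer(model, config)

    # One dummy training step on random data.
    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(f"loss: {loss.item():.4f}")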