"""
Muon Optimizer Implementation for nanoKimi

Based on the Muon optimizer described in the Kimi-K2 papers.
Combines momentum with adaptive learning rates for better convergence.
"""

import torch
from torch.optim import Optimizer


class Muon(Optimizer):
    """
    Muon optimizer: A momentum-based optimizer with adaptive learning rates
    
    This optimizer combines the benefits of momentum with adaptive learning rate
    scaling, designed specifically for large language model training.
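
    Update rule, as implemented in ``step`` (``g_t`` is the gradient, plus
    ``weight_decay * p`` when weight decay is enabled). A single momentum
    coefficient ``mu`` drives both moment estimates, making this an
    Adam-style update with ``beta1 == beta2 == momentum``::

        m_t = mu * m_{t-1} + (1 - mu) * g_t
        v_t = mu * v_{t-1} + (1 - mu) * g_t**2
        p_t = p_{t-1} - (lr / (1 - mu**t)) * m_t / (sqrt(v_t / (1 - mu**t)) + eps)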
    
    Args:
        params: iterable of parameters to optimize
        lr: learning rate (default: 1e-3)
        momentum: momentum factor (default: 0.9)
        weight_decay: weight decay (L2 penalty) (default: 0.01)
        eps: term added to the denominator to improve numerical stability (default: 1e-8)
        backend: backend to use, 'torch' or 'triton' (default: 'torch').
            Note that only the pure-PyTorch path is implemented in this
            file; the value is stored in the parameter-group defaults.
    """
    
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.01,
        eps: float = 1e-8,
        backend: str = 'torch'
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if backend not in ('torch', 'triton'):
            raise ValueError(f"Invalid backend: {backend}")
        
        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            eps=eps,
            backend=backend
        )
        super().__init__(params, defaults)
    
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step"""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            lr = group['lr']
            eps = group['eps']
            
            for p in group['params']:
                if p.grad is None:
                    continue
                
                grad = p.grad
                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)
                
                param_state = self.state[p]
                
                # State initialization
                if len(param_state) == 0:
                    param_state['step'] = 0
                    # Exponential moving average of gradient values
                    param_state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    param_state['exp_avg_sq'] = torch.zeros_like(p)
                
                exp_avg, exp_avg_sq = param_state['exp_avg'], param_state['exp_avg_sq']
                param_state['step'] += 1
                
                # Update the biased first- and second-moment running averages;
                # the same momentum coefficient drives both
                exp_avg.mul_(momentum).add_(grad, alpha=1 - momentum)
                exp_avg_sq.mul_(momentum).addcmul_(grad, grad, value=1 - momentum)
                
                # Bias correction: both moments share one momentum coefficient,
                # so a single correction term serves numerator and denominator
                step = param_state['step']
                bias_correction = 1 - momentum ** step
                
                # Adaptive denominator: bias-corrected RMS of the gradient
                denom = (exp_avg_sq / bias_correction).sqrt_().add_(eps)
                
                # Bias-corrected step size
                step_size = lr / bias_correction
                
                # Update parameters
                p.addcdiv_(exp_avg, denom, value=-step_size)
        
        return loss
    
    def zero_grad(self, set_to_none: bool = True) -> None:
        """Clear gradients"""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    if set_to_none:
                        p.grad = None
                    else:
                        if p.grad.grad_fn is not None:
                            p.grad.detach_()
                        else:
                            p.grad.requires_grad_(False)
                        p.grad.zero_()


def create_muon_optimizer(model, config):
    """Create a Muon optimizer from a configuration mapping.
    
    Expects the keys 'learning_rate', 'momentum', 'weight_decay', and
    'eps'; 'backend' is optional and defaults to 'torch'.
    """
    return Muon(
        model.parameters(),
        lr=config['learning_rate'],
        momentum=config['momentum'],
        weight_decay=config['weight_decay'],
        eps=config['eps'],
        backend=config.get('backend', 'torch')
    )
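

# Example usage: a minimal sketch showing how the factory above is wired
# into a training step. The toy model and config values are illustrative
# assumptions, not part of nanoKimi itself.
if __name__ == "__main__":
    import torch.nn.functional as F
    from torch import nn

    model = nn.Linear(16, 4)  # stand-in for a real model
    config = {
        'learning_rate': 1e-3,
        'momentum': 0.9,
        'weight_decay': 0.01,
        'eps': 1e-8,
    }
    opt = create_muon_optimizer(model, config)

    inputs = torch.randn(8, 16)
    targets = torch.randn(8, 4)

    opt.zero_grad()
    loss = F.mse_loss(model(inputs), targets)
    loss.backward()
    opt.step()
    print(f"training loss: {loss.item():.4f}")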