sohv committed
Commit 1202b75 · verified · 1 Parent(s): f52cfc5

Upload src/optimizer.py

Files changed (1)
  1. src/optimizer.py +135 -0
src/optimizer.py ADDED
@@ -0,0 +1,135 @@
+ """
+ Muon Optimizer Implementation for nanoKimi
+
+ Based on the Muon optimizer described in the Kimi-K2 papers.
+ Combines momentum with adaptive learning rates for better convergence.
+ """
+
+ import torch
+ import torch.optim as optim
+ from typing import Any, Callable, Dict, Optional
+
+
+ class Muon(optim.Optimizer):
+     """
+     Muon optimizer: a momentum-based optimizer with adaptive learning rates.
+
+     This optimizer combines the benefits of momentum with adaptive learning rate
+     scaling, designed specifically for large language model training.
+
+     Args:
+         params: iterable of parameters to optimize
+         lr: learning rate (default: 1e-3)
+         momentum: momentum factor (default: 0.9)
+         weight_decay: weight decay (L2 penalty) (default: 0.01)
+         eps: term added to the denominator to improve numerical stability (default: 1e-8)
+         backend: backend to use ('torch' or 'triton'); stored in the defaults, but this file only implements the PyTorch update path (default: 'torch')
+     """
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-3,
+         momentum: float = 0.9,
+         weight_decay: float = 0.01,
+         eps: float = 1e-8,
+         backend: str = 'torch'
+     ):
+         if not 0.0 <= lr:
+             raise ValueError(f"Invalid learning rate: {lr}")
+         if not 0.0 <= eps:
+             raise ValueError(f"Invalid epsilon value: {eps}")
+         if not 0.0 <= momentum < 1.0:
+             raise ValueError(f"Invalid momentum value: {momentum}")
+         if not 0.0 <= weight_decay:
+             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+         defaults = dict(
+             lr=lr,
+             momentum=momentum,
+             weight_decay=weight_decay,
+             eps=eps,
+             backend=backend
+         )
+         super().__init__(params, defaults)
+
+     @torch.no_grad()
+     def step(self, closure: Optional[Callable] = None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             weight_decay = group['weight_decay']
+             momentum = group['momentum']
+             lr = group['lr']
+             eps = group['eps']
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 grad = p.grad
+                 if weight_decay != 0:
+                     grad = grad.add(p, alpha=weight_decay)
+
+                 param_state = self.state[p]
+
+                 # State initialization
+                 if len(param_state) == 0:
+                     param_state['step'] = 0
+                     # Exponential moving average of gradient values
+                     param_state['exp_avg'] = torch.zeros_like(p)
+                     # Exponential moving average of squared gradient values
+                     param_state['exp_avg_sq'] = torch.zeros_like(p)
+
+                 exp_avg, exp_avg_sq = param_state['exp_avg'], param_state['exp_avg_sq']
+                 param_state['step'] += 1
+
+                 # Update the first and second moment estimates; both share the same momentum coefficient
+                 exp_avg.mul_(momentum).add_(grad, alpha=1 - momentum)
+                 exp_avg_sq.mul_(momentum).addcmul_(grad, grad, value=1 - momentum)
+
+                 # Bias correction (identical for both moments, since they share one coefficient)
+                 step = param_state['step']
+                 bias_correction1 = 1 - momentum ** step
+                 bias_correction2 = 1 - momentum ** step
+
+                 # Compute the denominator
+                 denom = (exp_avg_sq / bias_correction2).sqrt_().add_(eps)
+
+                 # Compute the step size
+                 step_size = lr / bias_correction1
+
+                 # Update parameters
+                 p.addcdiv_(exp_avg, denom, value=-step_size)
+
+         return loss
+
+     def zero_grad(self, set_to_none: bool = True) -> None:
+         """Clear gradients."""
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is not None:
+                     if set_to_none:
+                         p.grad = None
+                     else:
+                         if p.grad.grad_fn is not None:
+                             p.grad.detach_()
+                         else:
+                             p.grad.requires_grad_(False)
+                         p.grad.zero_()
+
+
+ def create_muon_optimizer(model: torch.nn.Module, config: Dict[str, Any]) -> Muon:
+     """Create a Muon optimizer from a configuration dictionary."""
+     return Muon(
+         model.parameters(),
+         lr=config['learning_rate'],
+         momentum=config['momentum'],
+         weight_decay=config['weight_decay'],
+         eps=config['eps'],
+         backend=config.get('backend', 'torch')
+     )
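
For reviewers, a minimal usage sketch of the uploaded module (not part of the commit). The config keys below are exactly the ones create_muon_optimizer reads; the model, batch, and import path are placeholder assumptions about how nanoKimi wires things up.

import torch
import torch.nn.functional as F
from src.optimizer import create_muon_optimizer  # assumes the repo root is on PYTHONPATH

config = {
    'learning_rate': 1e-3,   # any non-negative value passes the constructor check
    'momentum': 0.9,         # must satisfy 0 <= momentum < 1
    'weight_decay': 0.01,
    'eps': 1e-8,
    # 'backend' is optional; create_muon_optimizer falls back to 'torch'
}

model = torch.nn.Linear(16, 4)                 # placeholder model
opt = create_muon_optimizer(model, config)

x, y = torch.randn(8, 16), torch.randn(8, 4)   # placeholder batch
loss = F.mse_loss(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()

Each call to step() applies an Adam-style update in which the first and second moment estimates, and both bias corrections, share the single momentum coefficient; weight decay is folded into the gradient before the moments are updated.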