WCNegentropy commited on
Commit
1e5dcf2
·
verified ·
1 Parent(s): 99980ad

🚀 Refined BitTransformerLM: Organized codebase with best practices

Browse files
bit_transformer/BTLM_Extensions/muon_optimizer.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Muon Optimizer for BitTransformerLM Extensions
3
+ ==============================================
4
+
5
+ Implementation of the Muon optimizer with orthogonal momentum updates.
6
+ Based on "Muon: Momentum Orthogonalized by Newton's method" research.
7
+
8
+ Key features:
9
+ - Orthogonal momentum updates
10
+ - Better convergence properties than Adam/AdamW
11
+ - Memory efficient implementation
12
+ - Compatible with BitTransformerLM's training infrastructure
13
+ """
14
+
15
+ import math
16
+ import torch
17
+ from torch.optim.optimizer import Optimizer
18
+ from typing import Any, Dict, List, Optional, Tuple, Union
19
+ import warnings
20
+
21
+
22
+ class Muon(Optimizer):
23
+ """
24
+ Muon optimizer with orthogonal momentum updates.
25
+
26
+ This implementation provides momentum updates that are orthogonalized using
27
+ Newton's method, leading to more stable training dynamics.
28
+
29
+ Args:
30
+ params: Iterable of parameters to optimize
31
+ lr: Learning rate (default: 1e-3)
32
+ momentum: Momentum factor (default: 0.95)
33
+ nesterov: Enable Nesterov momentum (default: False)
34
+ backend: Backend for orthogonalization ('newtonschulz' or 'svd')
35
+ update_period: Period for updating orthogonalization (default: 1)
36
+ rank_deficiency_threshold: Threshold for rank deficiency detection
37
+ eps: Small constant for numerical stability (default: 1e-8)
38
+ weight_decay: Weight decay coefficient (default: 0.0)
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ params,
44
+ lr: float = 1e-3,
45
+ momentum: float = 0.95,
46
+ nesterov: bool = False,
47
+ backend: str = "newtonschulz",
48
+ update_period: int = 1,
49
+ rank_deficiency_threshold: float = 1e-6,
50
+ eps: float = 1e-8,
51
+ weight_decay: float = 0.0,
52
+ ):
53
+ if not 0.0 <= lr:
54
+ raise ValueError(f"Invalid learning rate: {lr}")
55
+ if not 0.0 <= momentum <= 1.0:
56
+ raise ValueError(f"Invalid momentum value: {momentum}")
57
+ if not 0.0 <= weight_decay:
58
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
59
+ if backend not in ["newtonschulz", "svd"]:
60
+ raise ValueError(f"Invalid backend: {backend}")
61
+
62
+ defaults = dict(
63
+ lr=lr,
64
+ momentum=momentum,
65
+ nesterov=nesterov,
66
+ backend=backend,
67
+ update_period=update_period,
68
+ rank_deficiency_threshold=rank_deficiency_threshold,
69
+ eps=eps,
70
+ weight_decay=weight_decay,
71
+ )
72
+ super().__init__(params, defaults)
73
+
74
+ def _orthogonalize_newtonschulz(self, matrix: torch.Tensor, num_iterations: int = 5) -> torch.Tensor:
75
+ """Orthogonalize matrix using Newton-Schulz iteration."""
76
+ # Handle different shapes
77
+ original_shape = matrix.shape
78
+ if matrix.dim() > 2:
79
+ matrix = matrix.view(-1, matrix.shape[-1])
80
+
81
+ if matrix.shape[0] >= matrix.shape[1]:
82
+ # Tall matrix - orthogonalize columns
83
+ X = matrix.clone()
84
+ for _ in range(num_iterations):
85
+ A = X.T @ X
86
+ X = X @ (1.5 * torch.eye(A.shape[0], device=A.device, dtype=A.dtype) - 0.5 * A)
87
+ else:
88
+ # Wide matrix - orthogonalize rows
89
+ X = matrix.clone()
90
+ for _ in range(num_iterations):
91
+ A = X @ X.T
92
+ X = (1.5 * torch.eye(A.shape[0], device=A.device, dtype=A.dtype) - 0.5 * A) @ X
93
+
94
+ return X.view(original_shape)
95
+
96
+ def _orthogonalize_svd(self, matrix: torch.Tensor) -> torch.Tensor:
97
+ """Orthogonalize matrix using SVD decomposition."""
98
+ original_shape = matrix.shape
99
+ if matrix.dim() > 2:
100
+ matrix = matrix.view(-1, matrix.shape[-1])
101
+
102
+ try:
103
+ U, _, Vt = torch.linalg.svd(matrix, full_matrices=False)
104
+ orthogonal = U @ Vt
105
+ return orthogonal.view(original_shape)
106
+ except torch._C._LinAlgError:
107
+ # Fallback to Newton-Schulz if SVD fails
108
+ warnings.warn("SVD failed, falling back to Newton-Schulz")
109
+ return self._orthogonalize_newtonschulz(matrix)
110
+
111
+ @torch.no_grad()
112
+ def step(self, closure=None):
113
+ """Perform a single optimization step."""
114
+ loss = None
115
+ if closure is not None:
116
+ with torch.enable_grad():
117
+ loss = closure()
118
+
119
+ for group in self.param_groups:
120
+ for p in group["params"]:
121
+ if p.grad is None:
122
+ continue
123
+
124
+ grad = p.grad
125
+ if grad.dtype in {torch.float16, torch.bfloat16}:
126
+ grad = grad.float()
127
+
128
+ state = self.state[p]
129
+
130
+ # State initialization
131
+ if len(state) == 0:
132
+ state["step"] = 0
133
+ state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
134
+
135
+ momentum_buffer = state["momentum_buffer"]
136
+ state["step"] += 1
137
+
138
+ # Weight decay
139
+ if group["weight_decay"] != 0:
140
+ grad = grad.add(p, alpha=group["weight_decay"])
141
+
142
+ # Apply momentum
143
+ momentum_buffer.mul_(group["momentum"]).add_(grad)
144
+
145
+ # Orthogonalize momentum every update_period steps
146
+ if state["step"] % group["update_period"] == 0 and momentum_buffer.numel() > 1:
147
+ # Only orthogonalize if we have sufficient dimensions
148
+ if momentum_buffer.dim() >= 2 and min(momentum_buffer.shape[-2:]) > 1:
149
+ if group["backend"] == "newtonschulz":
150
+ orthogonal_momentum = self._orthogonalize_newtonschulz(momentum_buffer)
151
+ else:
152
+ orthogonal_momentum = self._orthogonalize_svd(momentum_buffer)
153
+
154
+ # Check for rank deficiency
155
+ rank_ratio = torch.linalg.matrix_norm(orthogonal_momentum) / torch.linalg.matrix_norm(momentum_buffer)
156
+ if rank_ratio < group["rank_deficiency_threshold"]:
157
+ warnings.warn("Detected rank deficiency in momentum buffer")
158
+ else:
159
+ momentum_buffer.copy_(orthogonal_momentum)
160
+
161
+ # Apply Nesterov acceleration if enabled
162
+ if group["nesterov"]:
163
+ update = grad.add(momentum_buffer, alpha=group["momentum"])
164
+ else:
165
+ update = momentum_buffer
166
+
167
+ # Apply update
168
+ p.add_(update, alpha=-group["lr"])
169
+
170
+ return loss
171
+
172
+
173
+ def configure_muon_optimizer(
174
+ model: torch.nn.Module,
175
+ lr: float = 1e-3,
176
+ momentum: float = 0.95,
177
+ weight_decay: float = 0.01,
178
+ total_steps: Optional[int] = None,
179
+ warmup_ratio: float = 0.1,
180
+ nesterov: bool = False,
181
+ backend: str = "newtonschulz",
182
+ **muon_kwargs
183
+ ) -> Tuple[Muon, Optional[torch.optim.lr_scheduler._LRScheduler]]:
184
+ """
185
+ Configure Muon optimizer with OneCycle learning rate schedule.
186
+
187
+ This function provides a drop-in replacement for BitTransformerLM's
188
+ configure_optimizer function, using Muon instead of AdamW.
189
+
190
+ Args:
191
+ model: PyTorch model to optimize
192
+ lr: Peak learning rate
193
+ momentum: Momentum factor for Muon
194
+ weight_decay: Weight decay coefficient
195
+ total_steps: Total training steps for OneCycle schedule
196
+ warmup_ratio: Fraction of steps for warmup
197
+ nesterov: Enable Nesterov momentum
198
+ backend: Orthogonalization backend
199
+ **muon_kwargs: Additional arguments for Muon optimizer
200
+
201
+ Returns:
202
+ Tuple of (optimizer, scheduler)
203
+ """
204
+ # Filter parameters that need weight decay
205
+ decay_params = []
206
+ no_decay_params = []
207
+
208
+ for name, param in model.named_parameters():
209
+ if not param.requires_grad:
210
+ continue
211
+ # Apply weight decay to weights but not biases/norms
212
+ if param.dim() >= 2:
213
+ decay_params.append(param)
214
+ else:
215
+ no_decay_params.append(param)
216
+
217
+ param_groups = [
218
+ {"params": decay_params, "weight_decay": weight_decay},
219
+ {"params": no_decay_params, "weight_decay": 0.0},
220
+ ]
221
+
222
+ optimizer = Muon(
223
+ param_groups,
224
+ lr=lr,
225
+ momentum=momentum,
226
+ nesterov=nesterov,
227
+ backend=backend,
228
+ **muon_kwargs
229
+ )
230
+
231
+ scheduler = None
232
+ if total_steps is not None and total_steps > 0:
233
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
234
+ optimizer,
235
+ max_lr=lr,
236
+ total_steps=total_steps,
237
+ pct_start=warmup_ratio,
238
+ anneal_strategy='cos',
239
+ cycle_momentum=False, # Muon handles momentum internally
240
+ div_factor=25.0,
241
+ final_div_factor=1e4,
242
+ )
243
+
244
+ return optimizer, scheduler
245
+
246
+
247
+ def create_muon_training_config(
248
+ lr: float = 1e-3,
249
+ momentum: float = 0.95,
250
+ weight_decay: float = 0.01,
251
+ backend: str = "newtonschulz",
252
+ nesterov: bool = False,
253
+ **kwargs
254
+ ) -> Dict[str, Any]:
255
+ """
256
+ Create a training configuration dictionary for Muon optimizer.
257
+
258
+ This can be used with BitTransformerLM's training scripts by passing
259
+ the config to the training loop.
260
+
261
+ Args:
262
+ lr: Learning rate
263
+ momentum: Momentum factor
264
+ weight_decay: Weight decay coefficient
265
+ backend: Orthogonalization backend
266
+ nesterov: Enable Nesterov momentum
267
+ **kwargs: Additional configuration options
268
+
269
+ Returns:
270
+ Dictionary containing training configuration
271
+ """
272
+ config = {
273
+ "optimizer_type": "muon",
274
+ "optimizer_config": {
275
+ "lr": lr,
276
+ "momentum": momentum,
277
+ "weight_decay": weight_decay,
278
+ "backend": backend,
279
+ "nesterov": nesterov,
280
+ **kwargs
281
+ },
282
+ "scheduler_type": "onecycle",
283
+ }
284
+
285
+ return config
286
+
287
+
288
+ # Example usage and integration helpers
289
+ def integrate_with_bittransformerlm():
290
+ """
291
+ Example of how to integrate Muon optimizer with BitTransformerLM training.
292
+
293
+ Usage:
294
+ from BTLM_Extensions.muon_optimizer import configure_muon_optimizer
295
+
296
+ # Replace the standard optimizer configuration
297
+ optimizer, scheduler = configure_muon_optimizer(
298
+ model, lr=1e-3, momentum=0.95, total_steps=1000
299
+ )
300
+
301
+ # Use in training loop
302
+ train_loop(model, data, optimizer=optimizer, scheduler=scheduler)
303
+ """
304
+ pass
305
+
306
+
307
+ if __name__ == "__main__":
308
+ # Simple test of the optimizer
309
+ import torch.nn as nn
310
+
311
+ model = nn.Sequential(
312
+ nn.Linear(10, 20),
313
+ nn.ReLU(),
314
+ nn.Linear(20, 1)
315
+ )
316
+
317
+ optimizer, scheduler = configure_muon_optimizer(model, lr=1e-3, total_steps=100)
318
+
319
+ # Simple training step
320
+ x = torch.randn(32, 10)
321
+ y = torch.randn(32, 1)
322
+
323
+ pred = model(x)
324
+ loss = nn.functional.mse_loss(pred, y)
325
+ loss.backward()
326
+
327
+ optimizer.step()
328
+ if scheduler:
329
+ scheduler.step()
330
+
331
+ print("Muon optimizer test completed successfully!")
332
+ print(f"Loss: {loss.item():.4f}")