WCNegentropy committed on
Commit
99980ad
·
verified ·
1 Parent(s): 6ddf8d6

🚀 Refined BitTransformerLM: Organized codebase with best practices

Browse files
bit_transformer/BTLM_Extensions/lion_optimizer.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Lion Optimizer for BitTransformerLM Extensions
3
+ ==============================================
4
+
5
+ Implementation of the Lion optimizer (EvoLved Sign Momentum).
6
+ Based on "Symbolic Discovery of Optimization Algorithms" research.
7
+
8
+ Key features:
9
+ - Sign-based momentum updates
10
+ - Extremely memory efficient (only stores momentum)
11
+ - Often matches or outperforms Adam/AdamW when run with a smaller learning rate and a larger weight decay
12
+ - Compatible with BitTransformerLM's training infrastructure
13
+ """
14
+
15
+ import torch
16
+ from torch.optim.optimizer import Optimizer
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+
20
class Lion(Optimizer):
    """Sign-momentum optimizer (Lion).

    Every step moves each parameter by exactly ``lr`` per element, in the
    direction given by the sign of a blend of the stored momentum and the
    current gradient.  The momentum buffer is then refreshed with a second
    blend that uses a different coefficient.  The only per-parameter state
    is that single momentum buffer.

    Args:
        params: Iterable of parameters (or parameter-group dicts) to optimize.
        lr: Step size (default: 1e-4).
        betas: ``(beta1, beta2)``.  ``beta1`` weights the momentum when the
            update direction is formed; ``beta2`` weights it when the
            momentum buffer itself is refreshed (default: (0.9, 0.99)).
        weight_decay: Decoupled weight-decay coefficient (default: 0.0).
        eps: Validated and stored in the group defaults, but not read by
            ``step`` (default: 1e-8).
        maximize: Negate gradients so the objective is maximized
            (default: False).

    Raises:
        ValueError: If ``lr``, ``weight_decay`` or ``eps`` fails the
            ``>= 0`` check, or either beta lies outside ``[0, 1)``.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        eps: float = 1e-8,
        maximize: bool = False,
    ):
        # Checks run in the same order as the arguments are declared.
        if not lr >= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        for index in (0, 1):
            if not 0.0 <= betas[index] < 1.0:
                raise ValueError(
                    f"Invalid beta parameter at index {index}: {betas[index]}"
                )
        if not weight_decay >= 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not eps >= 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")

        super().__init__(
            params,
            {
                "lr": lr,
                "betas": betas,
                "weight_decay": weight_decay,
                "eps": eps,
                "maximize": maximize,
            },
        )

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one Lion update to every parameter that has a gradient.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without one.
        """
        if closure is None:
            loss = None
        else:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for param in group["params"]:
                raw_grad = param.grad
                if raw_grad is None:
                    continue

                grad = -raw_grad if group["maximize"] else raw_grad
                # Half-precision gradients are widened before the blends.
                if grad.dtype in (torch.float16, torch.bfloat16):
                    grad = grad.float()

                state = self.state[param]
                if not state:
                    # First visit: start from a zero momentum buffer.
                    state["momentum"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                momentum = state["momentum"]
                beta1, beta2 = group["betas"]

                # Decoupled decay: shrink the weights directly, leaving the
                # gradient untouched.
                decay = group["weight_decay"]
                if decay != 0:
                    param.mul_(1 - group["lr"] * decay)

                # c_t = beta1 * m_{t-1} + (1 - beta1) * g_t
                direction = momentum * beta1
                direction.add_(grad, alpha=1 - beta1)

                # theta_t = theta_{t-1} - lr * sign(c_t)
                param.add_(direction.sign(), alpha=-group["lr"])

                # m_t = beta2 * m_{t-1} + (1 - beta2) * g_t
                momentum.mul_(beta2)
                momentum.add_(grad, alpha=1 - beta2)

        return loss
111
+
112
+
113
def configure_lion_optimizer(
    model: torch.nn.Module,
    lr: float = 1e-4,
    betas: Tuple[float, float] = (0.9, 0.99),
    weight_decay: float = 0.01,
    total_steps: Optional[int] = None,
    warmup_ratio: float = 0.1,
    **lion_kwargs
) -> Tuple[Lion, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """Build a ``Lion`` optimizer and, optionally, a OneCycle LR schedule.

    Trainable parameters are split into two groups: tensors with two or
    more dimensions receive ``weight_decay``; everything else (such as
    biases and other 1-D tensors) is exempt.

    Args:
        model: Module whose trainable parameters will be optimized.
        lr: Learning rate given to the optimizer; also used as the
            schedule's ``max_lr``.
        betas: Beta coefficients forwarded to ``Lion``.
        weight_decay: Decay coefficient for the >=2-D parameter group.
        total_steps: Length of the OneCycle schedule.  ``None`` or a
            non-positive value disables the scheduler.
        warmup_ratio: Fraction of ``total_steps`` spent increasing the
            learning rate (passed as ``pct_start``).
        **lion_kwargs: Extra keyword arguments forwarded to ``Lion``.

    Returns:
        ``(optimizer, scheduler)``; ``scheduler`` is ``None`` when no
        schedule was requested.
    """
    trainable = [
        param for _, param in model.named_parameters() if param.requires_grad
    ]
    # Matrices and higher-rank tensors are decayed; lower-rank ones are not.
    decayed = [param for param in trainable if param.dim() >= 2]
    exempt = [param for param in trainable if param.dim() < 2]

    optimizer = Lion(
        [
            {"params": decayed, "weight_decay": weight_decay},
            {"params": exempt, "weight_decay": 0.0},
        ],
        lr=lr,
        betas=betas,
        **lion_kwargs
    )

    if total_steps is None or not total_steps > 0:
        return optimizer, None

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        total_steps=total_steps,
        pct_start=warmup_ratio,
        anneal_strategy='cos',
        cycle_momentum=False,  # Lion doesn't use cycling momentum
        div_factor=25.0,
        final_div_factor=1e4,
    )
    return optimizer, scheduler
182
+
183
+
184
def create_lion_training_config(
    lr: float = 1e-4,
    betas: Tuple[float, float] = (0.9, 0.99),
    weight_decay: float = 0.01,
    **kwargs
) -> Dict[str, Any]:
    """Assemble a training-config dictionary that selects the Lion optimizer.

    Args:
        lr: Learning rate stored under ``optimizer_config``.
        betas: Beta coefficients stored under ``optimizer_config``.
        weight_decay: Weight-decay coefficient stored under
            ``optimizer_config``.
        **kwargs: Extra entries appended to ``optimizer_config`` after the
            three named options.

    Returns:
        A dict with the keys ``optimizer_type`` (``"lion"``),
        ``optimizer_config`` and ``scheduler_type`` (``"onecycle"``).
    """
    optimizer_config: Dict[str, Any] = dict(
        lr=lr,
        betas=betas,
        weight_decay=weight_decay,
    )
    optimizer_config.update(kwargs)

    return {
        "optimizer_type": "lion",
        "optimizer_config": optimizer_config,
        "scheduler_type": "onecycle",
    }
217
+
218
+
219
class AdaptiveLion(Lion):
    """Lion variant that rescales the learning rate per parameter tensor.

    Before each update the ratio between the gradient norm and the stored
    momentum norm is turned into a multiplicative factor

        scale = 1 + adaptive_scale * (||g|| / ||m|| - 1)

    which is clamped to ``[min_scale, max_scale]`` and applied to the group
    learning rate for both the weight-decay shrink and the sign update.
    While the momentum norm is (near) zero, e.g. on the first step, the
    factor is 1.

    The three scaling options are stored as instance attributes rather than
    per-group defaults, so they apply to every parameter group.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        eps: float = 1e-8,
        maximize: bool = False,
        adaptive_scale: float = 0.1,
        min_scale: float = 0.01,
        max_scale: float = 10.0,
    ):
        """
        Args:
            params: Iterable of parameters (or parameter-group dicts).
            lr: Base learning rate before adaptive scaling.
            betas: Beta coefficients, as for ``Lion``.
            weight_decay: Decoupled weight-decay coefficient.
            eps: Forwarded to ``Lion``.
            maximize: Negate gradients so the objective is maximized.
            adaptive_scale: Scaling factor for adaptive adjustment
            min_scale: Minimum learning rate scale
            max_scale: Maximum learning rate scale

        Raises:
            ValueError: If the scale bounds are negative or inverted, or if
                a base-class argument check fails.
        """
        # A negative lower bound would allow a negative step size, and an
        # inverted range has no consistent meaning, so reject both up front.
        if not 0.0 <= min_scale <= max_scale:
            raise ValueError(
                f"Invalid scale bounds: min_scale={min_scale}, max_scale={max_scale}"
            )

        self.adaptive_scale = adaptive_scale
        self.min_scale = min_scale
        self.max_scale = max_scale

        super().__init__(params, lr, betas, weight_decay, eps, maximize)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step with adaptive learning-rate scaling.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without one.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if group["maximize"]:
                    grad = -grad

                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]

                if len(state) == 0:
                    state["momentum"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["step"] = 0

                momentum = state["momentum"]
                state["step"] += 1
                beta1, beta2 = group["betas"]

                # Both norms are pulled out as Python floats so the scale
                # below is plain scalar arithmetic.
                grad_norm = grad.norm().item()
                momentum_norm = momentum.norm().item()

                # Scale learning rate based on gradient/momentum ratio
                if momentum_norm > 1e-8:
                    scale = 1.0 + self.adaptive_scale * (grad_norm / momentum_norm - 1.0)
                    # Scalar clamp with builtins: no temporary tensor and no
                    # float32 rounding of the scale.
                    scale = min(max(scale, self.min_scale), self.max_scale)
                else:
                    # No usable momentum yet: leave the learning rate as is.
                    scale = 1.0

                adaptive_lr = group["lr"] * scale

                # Decoupled weight decay, using the scaled learning rate.
                if group["weight_decay"] != 0:
                    p.mul_(1 - adaptive_lr * group["weight_decay"])

                # Lion update with adaptive learning rate:
                #   c_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
                #   theta_t = theta_{t-1} - adaptive_lr * sign(c_t)
                #   m_t     = beta2 * m_{t-1} + (1 - beta2) * g_t
                interpolated = momentum.mul(beta1).add_(grad, alpha=1 - beta1)
                p.add_(torch.sign(interpolated), alpha=-adaptive_lr)
                momentum.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss
304
+
305
+
306
def configure_adaptive_lion_optimizer(
    model: torch.nn.Module,
    lr: float = 1e-4,
    adaptive_scale: float = 0.1,
    **kwargs
) -> Tuple[AdaptiveLion, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """Build an ``AdaptiveLion`` optimizer and, optionally, a OneCycle schedule.

    Trainable parameters are split into two groups: tensors with two or
    more dimensions receive weight decay; all others are exempt.

    Args:
        model: Module whose trainable parameters will be optimized.
        lr: Learning rate given to the optimizer; also used as the
            schedule's ``max_lr``.
        adaptive_scale: Forwarded to ``AdaptiveLion``.
        **kwargs: Recognised configuration keys are consumed here:
            ``weight_decay`` (default 0.01, applied to the >=2-D group),
            ``total_steps`` (schedule length; ``None`` or non-positive
            disables the scheduler) and ``warmup_ratio`` (default 0.1,
            passed as ``pct_start``).  Every remaining key is forwarded to
            the ``AdaptiveLion`` constructor.

    Returns:
        ``(optimizer, scheduler)``; ``scheduler`` is ``None`` when no
        schedule was requested.
    """
    # Peel off the options handled by this function.  The optimizer
    # constructor does not accept scheduler settings, so forwarding
    # ``total_steps`` / ``warmup_ratio`` to it would raise ``TypeError``.
    weight_decay = kwargs.pop("weight_decay", 0.01)
    total_steps = kwargs.pop("total_steps", None)
    warmup_ratio = kwargs.pop("warmup_ratio", 0.1)

    decay_params = []
    no_decay_params = []

    for _, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Decay weight matrices and higher-rank tensors; skip 1-D tensors.
        if param.dim() >= 2:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    optimizer = AdaptiveLion(
        param_groups,
        lr=lr,
        adaptive_scale=adaptive_scale,
        **kwargs
    )

    scheduler = None
    if total_steps is not None and total_steps > 0:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=warmup_ratio,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=1e4,
        )

    return optimizer, scheduler
352
+
353
+
354
+ # Example usage and integration helpers
355
def integrate_with_bittransformerlm():
    """Documentation-only helper: how to plug Lion into BitTransformerLM training.

    Calling this function does nothing; it exists to carry the usage notes
    below.

    Standard Lion::

        from BTLM_Extensions.lion_optimizer import configure_lion_optimizer

        # Swap in for the standard optimizer configuration.
        # Lion typically needs smaller learning rates than Adam.
        optimizer, scheduler = configure_lion_optimizer(
            model, lr=1e-4, weight_decay=0.01, total_steps=1000
        )

        # Hand both objects to the training loop.
        train_loop(model, data, optimizer=optimizer, scheduler=scheduler)

    Adaptive variant::

        from BTLM_Extensions.lion_optimizer import configure_adaptive_lion_optimizer

        optimizer, scheduler = configure_adaptive_lion_optimizer(
            model, lr=1e-4, adaptive_scale=0.1, total_steps=1000
        )
    """
379
+
380
+
381
if __name__ == "__main__":
    # Smoke test: run one optimizer step (and one scheduler step) with each
    # optimizer variant on a tiny MLP and random data, then print the losses
    # measured before those steps.
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))

    print("Testing standard Lion optimizer...")
    optimizer, scheduler = configure_lion_optimizer(
        model, lr=1e-4, total_steps=100
    )

    # One random regression batch, shared by both runs.
    x = torch.randn(32, 10)
    y = torch.randn(32, 1)

    loss = nn.functional.mse_loss(model(x), y)
    initial_loss = loss.item()
    loss.backward()

    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    print(f"Initial loss: {initial_loss:.4f}")

    # Same procedure with the adaptive variant on a fresh model.
    print("Testing Adaptive Lion optimizer...")
    model2 = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))

    optimizer2, scheduler2 = configure_adaptive_lion_optimizer(
        model2, lr=1e-4, adaptive_scale=0.1, total_steps=100
    )

    loss2 = nn.functional.mse_loss(model2(x), y)
    loss2.backward()
    optimizer2.step()
    if scheduler2 is not None:
        scheduler2.step()

    print("Lion optimizers test completed successfully!")
    print(f"Standard Lion loss: {initial_loss:.4f}")
    print(f"Adaptive Lion loss: {loss2.item():.4f}")