WCNegentropy committed on
Commit 6ddf8d6 · verified · 1 Parent(s): d580d32

🚀 Refined BitTransformerLM: Organized codebase with best practices

bit_transformer/BTLM_Extensions/adafactor_optimizer.py ADDED
@@ -0,0 +1,502 @@
"""
Adafactor Optimizer for BitTransformerLM Extensions
===================================================

Implementation of the Adafactor optimizer with memory-efficient factorization.
Based on "Adafactor: Adaptive Learning Rates with Sublinear Memory Cost"
(Shazeer & Stern, 2018).

Key features:
- Factorized second moment estimates for memory efficiency
- Automatic scaling of learning rates
- Relative step size and clip threshold
- Compatible with BitTransformerLM's training infrastructure
"""

import math
import torch
from torch.optim.optimizer import Optimizer
from typing import Any, Dict, List, Optional, Tuple, Union

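# ---------------------------------------------------------------------------
# Note (illustrative, not part of the original file): the memory saving comes
# from storing only row and column statistics of the squared gradients for
# matrices instead of a full elementwise second moment. For an (n, m) weight
# this is n + m accumulator values rather than n * m, e.g. a 4096 x 4096
# matrix needs 8,192 entries instead of ~16.8 million.
# ---------------------------------------------------------------------------
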
class Adafactor(Optimizer):
    """
    Adafactor optimizer implementation.

    Adafactor reduces memory usage by factorizing the second moment estimates
    for parameters with 2 or more dimensions, making it highly memory efficient
    for large transformer models.

    Args:
        params: Iterable of parameters to optimize
        lr: External learning rate (default: None, uses automatic scaling)
        eps2: Regularization constant for the second moment (default: 1e-30)
        clipping_threshold: Threshold for adaptive update clipping (default: 1.0)
        decay_rate: Exponent controlling the decay schedule of the second-moment
            running average (default: -0.8)
        beta1: Coefficient for the running average of the gradient; None disables
            the first moment (default: None)
        weight_decay: Weight decay coefficient (default: 0.0)
        scale_parameter: If True, the learning rate is scaled by the root mean
            square of the parameter (default: True)
        relative_step_size: If True, use a relative step size (default: True)
        warmup_init: If True, warm up the learning rate (default: False)
    """

    def __init__(
        self,
        params,
        lr: Optional[float] = None,
        eps2: float = 1e-30,
        clipping_threshold: float = 1.0,
        decay_rate: float = -0.8,
        beta1: Optional[float] = None,
        weight_decay: float = 0.0,
        scale_parameter: bool = True,
        relative_step_size: bool = True,
        warmup_init: bool = False,
    ):
        if lr is not None and lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            eps2=eps2,
            clipping_threshold=clipping_threshold,
            decay_rate=decay_rate,
            beta1=beta1,
            weight_decay=weight_decay,
            scale_parameter=scale_parameter,
            relative_step_size=relative_step_size,
            warmup_init=warmup_init,
        )
        super().__init__(params, defaults)

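    # Minimal usage sketch (illustrative, not part of the original file):
    #
    #     optimizer = Adafactor(model.parameters())           # automatic step size
    #     optimizer = Adafactor(model.parameters(), lr=1e-3)   # fixed external lr
    #
    # With lr=None the step size is derived per parameter from the step count
    # and the parameter RMS (see _get_lr below), so no separate warmup schedule
    # is strictly required.
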
    def _get_lr(self, param_group, param_state):
        """Compute the learning rate for a parameter group."""
        # warmup_init ramps the relative step linearly before the 1/sqrt(step) decay
        min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
        rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps2"], param_state["RMS"])
        return param_scale * rel_step_sz

    def _get_options(self, param_group, param_shape):
        """Get optimization options for a parameter."""
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    def _rms(self, tensor):
        """Root mean square of a tensor."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        """Approximate the exponential moving average of the squared gradient
        from its factored row and column statistics."""
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_()
        c_factor = exp_avg_sq_col.rsqrt()
        # unsqueeze(-2) (rather than 0) keeps broadcasting correct for parameters
        # with more than two dimensions.
        return torch.mul(r_factor.unsqueeze(-1), c_factor.unsqueeze(-2))

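    # Illustrative shape check (not part of the original file): for a gradient of
    # shape (4, 3), `exp_avg_sq_row` has shape (4,) and `exp_avg_sq_col` shape (3,).
    # `_approx_sq_grad` then forms a (4, 1) * (1, 3) broadcast, recovering a full
    # (4, 3) preconditioner from only 4 + 3 stored values instead of 12.
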
    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)

                # State initialization
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(grad).float()
                    if factored:
                        # Factored second moment: row and column statistics,
                        # allocated on the same device as the gradient.
                        state["exp_avg_sq_row"] = torch.zeros(
                            grad_shape[:-1], device=grad.device)
                        state["exp_avg_sq_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:], device=grad.device)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad).float()

                    state["RMS"] = 0

                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)

                # Use the externally supplied learning rate if given, otherwise
                # fall back to Adafactor's relative step size.
                lr = group["lr"]
                if group["lr"] is None:
                    lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = grad**2 + group["eps2"]

                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(
                        update.mean(dim=-1), alpha=1.0 - beta2t)
                    exp_avg_sq_col.mul_(beta2t).add_(
                        update.mean(dim=-2), alpha=1.0 - beta2t)

                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]
                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Adaptive clipping: rescale the update if its RMS exceeds the threshold.
                update.div_((self._rms(update) / group["clipping_threshold"]).clamp_(min=1.0))

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"])
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p_data_fp32.mul_(1 - group["weight_decay"] * lr)

                p_data_fp32.add_(update, alpha=-lr)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss

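# Worked example (added for clarity, not in the original file): with the default
# decay_rate = -0.8 the second-moment coefficient beta2t = 1 - step**(-0.8)
# starts at 0.0 and approaches 1.0 as training proceeds:
#   step = 1    -> beta2t ~ 0.000
#   step = 10   -> beta2t ~ 0.842
#   step = 100  -> beta2t ~ 0.975
#   step = 1000 -> beta2t ~ 0.996
# so early steps rely almost entirely on the current squared gradient and later
# steps average over an increasingly long history.
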
def configure_adafactor_optimizer(
    model: torch.nn.Module,
    lr: Optional[float] = None,
    weight_decay: float = 0.0,
    total_steps: Optional[int] = None,
    warmup_ratio: float = 0.1,
    scale_parameter: bool = True,
    relative_step_size: bool = True,
    warmup_init: bool = False,
    clipping_threshold: float = 1.0,
    decay_rate: float = -0.8,
    beta1: Optional[float] = None,
    eps2: float = 1e-30,
    **adafactor_kwargs
) -> Tuple[Adafactor, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """
    Configure the Adafactor optimizer with optional learning rate scheduling.

    This function provides a drop-in replacement for BitTransformerLM's
    configure_optimizer function, using Adafactor instead of AdamW.

    Args:
        model: PyTorch model to optimize
        lr: External learning rate (None for automatic scaling)
        weight_decay: Weight decay coefficient
        total_steps: Total training steps for scheduling
        warmup_ratio: Fraction of steps used for warmup
        scale_parameter: Whether to scale the learning rate by parameter RMS
        relative_step_size: Whether to use a relative step size
        warmup_init: Whether to use warmup initialization
        clipping_threshold: Threshold for adaptive clipping
        decay_rate: Decay rate for second moment estimates
        beta1: Coefficient for the first moment (None to disable)
        eps2: Regularization constant
        **adafactor_kwargs: Additional arguments for Adafactor

    Returns:
        Tuple of (optimizer, scheduler)
    """
    # Adafactor can handle all parameters in one group efficiently
    params = [p for p in model.parameters() if p.requires_grad]

    optimizer = Adafactor(
        params,
        lr=lr,
        weight_decay=weight_decay,
        scale_parameter=scale_parameter,
        relative_step_size=relative_step_size,
        warmup_init=warmup_init,
        clipping_threshold=clipping_threshold,
        decay_rate=decay_rate,
        beta1=beta1,
        eps2=eps2,
        **adafactor_kwargs
    )

    scheduler = None
    # Adafactor has built-in learning rate scaling, but OneCycle can still be
    # layered on top when an explicit learning rate is provided.
    if total_steps is not None and total_steps > 0 and lr is not None:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=warmup_ratio,
            anneal_strategy='cos',
            cycle_momentum=False,  # Adafactor doesn't use momentum cycling
            div_factor=25.0,
            final_div_factor=1e4,
        )

    return optimizer, scheduler

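# Usage sketch (illustrative, not part of the original file):
#
#     # automatic relative step size -- no external scheduler is created
#     optimizer, scheduler = configure_adafactor_optimizer(model, lr=None)
#     assert scheduler is None
#
#     # explicit learning rate -- OneCycleLR wraps the fixed lr
#     optimizer, scheduler = configure_adafactor_optimizer(
#         model, lr=1e-3, total_steps=10_000, warmup_ratio=0.1
#     )
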
class AdafactorScheduler(torch.optim.lr_scheduler._LRScheduler):
    """
    Custom scheduler for Adafactor with warmup and polynomial decay.

    This scheduler is designed to be used in place of Adafactor's built-in
    relative step size (the optimizer is constructed with
    relative_step_size=False and an explicit lr).
    """

    def __init__(
        self,
        optimizer: Adafactor,
        warmup_steps: int = 1000,
        total_steps: Optional[int] = None,
        min_lr_ratio: float = 0.1,
        polynomial_power: float = 1.0,
        last_epoch: int = -1,
    ):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr_ratio = min_lr_ratio
        self.polynomial_power = polynomial_power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1

        if step < self.warmup_steps:
            # Linear warmup
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]

        if self.total_steps is None:
            # No decay after warmup
            return self.base_lrs

        # Polynomial decay
        progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        progress = min(progress, 1.0)
        decay_factor = (1 - progress) ** self.polynomial_power
        decay_factor = max(decay_factor, self.min_lr_ratio)

        return [base_lr * decay_factor for base_lr in self.base_lrs]

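# Worked example (added for clarity, not in the original file): with
# warmup_steps=1000, total_steps=10000, min_lr_ratio=0.1, polynomial_power=1.0,
# a base lr of 1e-3 evolves roughly as:
#   step 500    -> 0.5e-3   (linear warmup, 500/1000)
#   step 1000   -> 1.0e-3   (warmup complete)
#   step 5500   -> 0.5e-3   (progress 0.5, decay_factor 0.5)
#   step 10000  -> 0.1e-3   (decay_factor floored at min_lr_ratio)
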
def configure_adafactor_with_scheduler(
    model: torch.nn.Module,
    lr: float = 1e-3,
    warmup_steps: int = 1000,
    total_steps: Optional[int] = None,
    weight_decay: float = 0.0,
    **kwargs
) -> Tuple[Adafactor, AdafactorScheduler]:
    """
    Configure the Adafactor optimizer with the custom Adafactor scheduler.

    Args:
        model: PyTorch model to optimize
        lr: Base learning rate
        warmup_steps: Number of warmup steps
        total_steps: Total training steps
        weight_decay: Weight decay coefficient
        **kwargs: Additional arguments for Adafactor

    Returns:
        Tuple of (optimizer, scheduler)
    """
    params = [p for p in model.parameters() if p.requires_grad]

    optimizer = Adafactor(
        params,
        lr=lr,
        weight_decay=weight_decay,
        relative_step_size=False,  # We'll use the external scheduler
        **kwargs
    )

    scheduler = AdafactorScheduler(
        optimizer,
        warmup_steps=warmup_steps,
        total_steps=total_steps,
    )

    return optimizer, scheduler

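# Usage sketch (illustrative, not part of the original file). The scheduler is
# intended to be stepped once per optimizer step, not once per epoch:
#
#     optimizer, scheduler = configure_adafactor_with_scheduler(
#         model, lr=1e-3, warmup_steps=100, total_steps=1000
#     )
#     for batch in loader:
#         loss = compute_loss(model, batch)   # hypothetical loss helper
#         loss.backward()
#         optimizer.step()
#         scheduler.step()
#         optimizer.zero_grad()
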
def create_adafactor_training_config(
    lr: Optional[float] = None,
    weight_decay: float = 0.0,
    scale_parameter: bool = True,
    relative_step_size: bool = True,
    warmup_init: bool = False,
    **kwargs
) -> Dict[str, Any]:
    """
    Create a training configuration dictionary for the Adafactor optimizer.

    Args:
        lr: External learning rate (None for automatic)
        weight_decay: Weight decay coefficient
        scale_parameter: Whether to scale by parameter RMS
        relative_step_size: Whether to use a relative step size
        warmup_init: Whether to use warmup initialization
        **kwargs: Additional configuration options

    Returns:
        Dictionary containing the training configuration
    """
    config = {
        "optimizer_type": "adafactor",
        "optimizer_config": {
            "lr": lr,
            "weight_decay": weight_decay,
            "scale_parameter": scale_parameter,
            "relative_step_size": relative_step_size,
            "warmup_init": warmup_init,
            **kwargs
        },
        "scheduler_type": "adafactor_custom" if lr is None else "onecycle",
    }

    return config

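# Example output (illustrative, not part of the original file):
#
#     create_adafactor_training_config(lr=None, weight_decay=0.01)
#     # -> {
#     #        "optimizer_type": "adafactor",
#     #        "optimizer_config": {
#     #            "lr": None,
#     #            "weight_decay": 0.01,
#     #            "scale_parameter": True,
#     #            "relative_step_size": True,
#     #            "warmup_init": False,
#     #        },
#     #        "scheduler_type": "adafactor_custom",
#     #    }
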
# Example usage and integration helpers
def integrate_with_bittransformerlm():
    """
    Example of how to integrate the Adafactor optimizer with BitTransformerLM training.

    Usage:
        from BTLM_Extensions.adafactor_optimizer import configure_adafactor_optimizer

        # Option 1: Use Adafactor with automatic learning rate scaling
        optimizer, scheduler = configure_adafactor_optimizer(
            model, lr=None, total_steps=1000  # lr=None enables auto-scaling
        )

        # Option 2: Use Adafactor with a fixed learning rate
        optimizer, scheduler = configure_adafactor_optimizer(
            model, lr=1e-3, total_steps=1000
        )

        # Option 3: Use Adafactor with the custom scheduler
        from BTLM_Extensions.adafactor_optimizer import configure_adafactor_with_scheduler

        optimizer, scheduler = configure_adafactor_with_scheduler(
            model, lr=1e-3, warmup_steps=100, total_steps=1000
        )

        # Use in the training loop
        train_loop(model, data, optimizer=optimizer, scheduler=scheduler)
    """
    pass

def analyze_memory_usage(model: torch.nn.Module) -> Dict[str, float]:
    """
    Estimate and compare optimizer memory usage for AdamW and Adafactor.

    Args:
        model: PyTorch model to analyze

    Returns:
        Dictionary with memory usage estimates in MB
    """
    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    param_bytes = param_count * 4  # assume float32

    # AdamW memory: parameters + gradients + 2 momentum states
    adamw_memory = param_bytes * 4

    # Adafactor memory estimation (assumes beta1=None, i.e. no first-moment buffer)
    adafactor_memory = param_bytes  # parameters
    adafactor_memory += param_bytes  # gradients

    # For factored (>=2D) parameters, Adafactor stores only row and column means
    factored_params = 0
    unfactored_params = 0

    for p in model.parameters():
        if p.requires_grad:
            if len(p.shape) >= 2:
                factored_params += p.shape[0] + p.shape[1]  # row + column means
            else:
                unfactored_params += p.numel()

    adafactor_memory += (factored_params + unfactored_params) * 4  # second moments

    return {
        "adamw_mb": adamw_memory / (1024 * 1024),
        "adafactor_mb": adafactor_memory / (1024 * 1024),
        "savings_mb": (adamw_memory - adafactor_memory) / (1024 * 1024),
        "savings_percent": ((adamw_memory - adafactor_memory) / adamw_memory) * 100,
    }

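# Worked example (added for clarity, not in the original file): a single
# 768 x 3072 weight matrix has 2,359,296 parameters. AdamW keeps two full
# moment buffers (~4.7M extra floats, ~18 MB at float32), while Adafactor's
# factored second moment needs only 768 + 3072 = 3,840 floats (~15 KB).
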
if __name__ == "__main__":
    # Simple smoke test of the optimizer
    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(100, 200),
        nn.ReLU(),
        nn.Linear(200, 50),
        nn.ReLU(),
        nn.Linear(50, 1)
    )

    print("Testing Adafactor optimizer...")

    # Test with automatic learning rate
    optimizer, scheduler = configure_adafactor_optimizer(
        model, lr=None, total_steps=100
    )

    # Simple training step
    x = torch.randn(32, 100)
    y = torch.randn(32, 1)

    pred = model(x)
    loss = nn.functional.mse_loss(pred, y)
    initial_loss = loss.item()
    loss.backward()

    optimizer.step()
    if scheduler:
        scheduler.step()
    # Clear gradients before the second test so they do not accumulate
    optimizer.zero_grad()

    # Test with a fixed learning rate
    optimizer2, scheduler2 = configure_adafactor_optimizer(
        model, lr=1e-3, total_steps=100
    )

    pred = model(x)
    loss = nn.functional.mse_loss(pred, y)
    loss.backward()
    optimizer2.step()
    if scheduler2:
        scheduler2.step()

    # Analyze memory usage
    memory_analysis = analyze_memory_usage(model)

    print("Adafactor optimizer test completed successfully!")
    print(f"Initial loss: {initial_loss:.4f}")
    print(f"Loss after one step: {loss.item():.4f}")
    print(f"Memory analysis: {memory_analysis}")