""" |
|
|
Adafactor Optimizer for BitTransformerLM Extensions |
|
|
=================================================== |
|
|
|
|
|
Implementation of the Adafactor optimizer with memory-efficient factorization. |
|
|
Based on "Adafactor: Adaptive Learning Rates with Sublinear Memory Cost" research. |
|
|
|
|
|
Key features: |
|
|
- Factorized second moment estimates for memory efficiency |
|
|
- Automatic scaling of learning rates |
|
|
- Relative step size and clip threshold |
|
|
- Compatible with BitTransformerLM's training infrastructure |
|
|
""" |
|
|
|
|
|
import math |
|
|
import torch |
|
|
from torch.optim.optimizer import Optimizer |
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
|
|
|
class Adafactor(Optimizer):
    """
    Adafactor optimizer implementation.

    Adafactor reduces memory usage by factorizing the second-moment estimates
    for parameters with 2 or more dimensions, making it highly memory efficient
    for large transformer models.

    Args:
        params: Iterable of parameters to optimize
        lr: External learning rate (default: None, uses automatic scaling)
        eps2: Regularization constant for the second moment (default: 1e-30)
        clipping_threshold: Threshold for adaptive clipping (default: 1.0)
        decay_rate: Coefficient used for computing running averages (default: -0.8)
        beta1: Coefficient used for computing running averages of the gradient (default: None)
        weight_decay: Weight decay coefficient (default: 0.0)
        scale_parameter: If True, the learning rate is scaled by the root mean square of the parameter (default: True)
        relative_step_size: If True, use a relative step size (default: True)
        warmup_init: If True, warm up the learning rate (default: False)
    """

    def __init__(
        self,
        params,
        lr: Optional[float] = None,
        eps2: float = 1e-30,
        clipping_threshold: float = 1.0,
        decay_rate: float = -0.8,
        beta1: Optional[float] = None,
        weight_decay: float = 0.0,
        scale_parameter: bool = True,
        relative_step_size: bool = True,
        warmup_init: bool = False,
    ):
        if lr is not None and lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            eps2=eps2,
            clipping_threshold=clipping_threshold,
            decay_rate=decay_rate,
            beta1=beta1,
            weight_decay=weight_decay,
            scale_parameter=scale_parameter,
            relative_step_size=relative_step_size,
            warmup_init=warmup_init,
        )
        super().__init__(params, defaults)

    def _get_lr(self, param_group, param_state):
        """Compute learning rate for parameter group."""
        min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
        rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps2"], param_state["RMS"])
        return param_scale * rel_step_sz

    def _get_options(self, param_group, param_shape):
        """Get optimization options for parameter."""
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    def _rms(self, tensor):
        """Root mean square."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        """Approximation of exponential moving average of square of gradient."""
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_()
        c_factor = exp_avg_sq_col.rsqrt()
        # unsqueeze(-2) (rather than unsqueeze(0)) keeps broadcasting correct
        # for parameters with more than two dimensions.
        return torch.mul(r_factor.unsqueeze(-1), c_factor.unsqueeze(-2))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)

                # State initialization.
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        state["exp_avg"] = torch.zeros_like(grad).float()
                    if factored:
                        # Factored second moments are kept as row and column
                        # statistics on the same device as the gradient.
                        state["exp_avg_sq_row"] = torch.zeros(
                            grad_shape[:-1], device=grad.device, dtype=torch.float32)
                        state["exp_avg_sq_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:],
                            device=grad.device, dtype=torch.float32)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad).float()

                    state["RMS"] = 0

                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)

                lr = group["lr"]
                if group["lr"] is None:
                    lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = grad**2 + group["eps2"]

                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(
                        update.mean(dim=-1), alpha=1.0 - beta2t)
                    exp_avg_sq_col.mul_(beta2t).add_(
                        update.mean(dim=-2), alpha=1.0 - beta2t)

                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]
                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Adaptive update clipping.
                update.div_(max(1.0, self._rms(update) / group["clipping_threshold"]))

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"])
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p_data_fp32.mul_(1 - group["weight_decay"] * lr)

                p_data_fp32.add_(update, alpha=-lr)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
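

# Illustrative sketch (not part of the original module): shows the factored
# second-moment approximation that the class above relies on, and direct
# construction of the optimizer. The toy tensors, hyperparameters, and the
# helper name `_demo_factored_second_moment` are assumptions made for the
# example only.
def _demo_factored_second_moment() -> None:
    # For a 2-D gradient G, this implementation tracks only the row means R and
    # column means C of G**2 and reconstructs the full second moment as
    # outer(R, C) / mean(R), instead of storing an elementwise buffer.
    g = torch.randn(4, 6)
    v = g**2
    row = v.mean(dim=-1)  # row statistics, shape (4,)
    col = v.mean(dim=-2)  # column statistics, shape (6,)
    v_hat = torch.outer(row, col) / row.mean()
    print("mean |V - V_hat|:", (v - v_hat).abs().mean().item())

    # Direct use of the Adafactor class with automatic (relative) step sizes.
    layer = torch.nn.Linear(6, 4)
    opt = Adafactor(layer.parameters(), lr=None, beta1=None)
    loss = layer(torch.randn(8, 6)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
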
def configure_adafactor_optimizer(
    model: torch.nn.Module,
    lr: Optional[float] = None,
    weight_decay: float = 0.0,
    total_steps: Optional[int] = None,
    warmup_ratio: float = 0.1,
    scale_parameter: bool = True,
    relative_step_size: bool = True,
    warmup_init: bool = False,
    clipping_threshold: float = 1.0,
    decay_rate: float = -0.8,
    beta1: Optional[float] = None,
    eps2: float = 1e-30,
    **adafactor_kwargs,
) -> Tuple[Adafactor, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """
    Configure an Adafactor optimizer with optional learning rate scheduling.

    This function provides a drop-in replacement for BitTransformerLM's
    configure_optimizer function, using Adafactor instead of AdamW.

    Args:
        model: PyTorch model to optimize
        lr: External learning rate (None for automatic scaling)
        weight_decay: Weight decay coefficient
        total_steps: Total training steps for scheduling
        warmup_ratio: Fraction of steps used for warmup
        scale_parameter: Whether to scale the learning rate by the parameter RMS
        relative_step_size: Whether to use a relative step size
        warmup_init: Whether to use warmup initialization
        clipping_threshold: Threshold for adaptive clipping
        decay_rate: Decay rate for second-moment estimates
        beta1: Coefficient for the first moment (None to disable)
        eps2: Regularization constant
        **adafactor_kwargs: Additional arguments for Adafactor

    Returns:
        Tuple of (optimizer, scheduler)
    """
    params = [p for p in model.parameters() if p.requires_grad]

    optimizer = Adafactor(
        params,
        lr=lr,
        weight_decay=weight_decay,
        scale_parameter=scale_parameter,
        relative_step_size=relative_step_size,
        warmup_init=warmup_init,
        clipping_threshold=clipping_threshold,
        decay_rate=decay_rate,
        beta1=beta1,
        eps2=eps2,
        **adafactor_kwargs,
    )

    scheduler = None

    # A OneCycleLR schedule only makes sense when an explicit learning rate is given.
    if total_steps is not None and total_steps > 0 and lr is not None:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=warmup_ratio,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=1e4,
        )

    return optimizer, scheduler
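

# Illustrative sketch (not part of the original module): a minimal training step
# using configure_adafactor_optimizer. The two-layer model, batch shapes, and
# the helper name `_demo_configure_adafactor` are assumptions for the example.
def _demo_configure_adafactor() -> None:
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1))

    # With an explicit lr and total_steps, a OneCycleLR schedule is attached;
    # with lr=None the optimizer falls back to Adafactor's automatic scaling
    # and no scheduler is created.
    optimizer, scheduler = configure_adafactor_optimizer(
        model, lr=1e-3, total_steps=10
    )

    x, y = torch.randn(16, 32), torch.randn(16, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    optimizer.zero_grad()
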
class AdafactorScheduler(torch.optim.lr_scheduler._LRScheduler):
    """
    Custom scheduler for Adafactor with warmup and polynomial decay.

    This scheduler is specifically designed to work with Adafactor's
    relative step size feature.
    """

    def __init__(
        self,
        optimizer: Adafactor,
        warmup_steps: int = 1000,
        total_steps: Optional[int] = None,
        min_lr_ratio: float = 0.1,
        polynomial_power: float = 1.0,
        last_epoch: int = -1,
    ):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr_ratio = min_lr_ratio
        self.polynomial_power = polynomial_power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1

        # Linear warmup.
        if step < self.warmup_steps:
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]

        # Without a known horizon, hold the base learning rate.
        if self.total_steps is None:
            return self.base_lrs

        # Polynomial decay towards min_lr_ratio * base_lr.
        progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        progress = min(progress, 1.0)
        decay_factor = (1 - progress) ** self.polynomial_power
        decay_factor = max(decay_factor, self.min_lr_ratio)

        return [base_lr * decay_factor for base_lr in self.base_lrs]
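

# Illustrative sketch (not part of the original module): prints the learning
# rate produced by AdafactorScheduler at a few steps so the warmup and
# polynomial decay are visible. The tiny model and the step counts used here
# are assumptions for the example.
def _demo_adafactor_scheduler() -> None:
    layer = torch.nn.Linear(4, 4)
    optimizer = Adafactor(layer.parameters(), lr=1e-3, relative_step_size=False)
    scheduler = AdafactorScheduler(optimizer, warmup_steps=5, total_steps=20)

    for step in range(20):
        # In a real loop, optimizer.step() would be called before scheduler.step().
        scheduler.step()
        if step % 5 == 0:
            print(f"step {step:2d}: lr = {scheduler.get_last_lr()[0]:.2e}")
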
def configure_adafactor_with_scheduler(
    model: torch.nn.Module,
    lr: float = 1e-3,
    warmup_steps: int = 1000,
    total_steps: Optional[int] = None,
    weight_decay: float = 0.0,
    **kwargs,
) -> Tuple[Adafactor, AdafactorScheduler]:
    """
    Configure Adafactor optimizer with custom Adafactor scheduler.

    Args:
        model: PyTorch model to optimize
        lr: Base learning rate
        warmup_steps: Number of warmup steps
        total_steps: Total training steps
        weight_decay: Weight decay coefficient
        **kwargs: Additional arguments for Adafactor

    Returns:
        Tuple of (optimizer, scheduler)
    """
    params = [p for p in model.parameters() if p.requires_grad]

    optimizer = Adafactor(
        params,
        lr=lr,
        weight_decay=weight_decay,
        relative_step_size=False,
        **kwargs,
    )

    scheduler = AdafactorScheduler(
        optimizer,
        warmup_steps=warmup_steps,
        total_steps=total_steps,
    )

    return optimizer, scheduler
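

# Illustrative sketch (not part of the original module): a short training loop
# built on configure_adafactor_with_scheduler. The regression data, loop length,
# and the helper name `_demo_adafactor_with_scheduler` are assumptions.
def _demo_adafactor_with_scheduler() -> None:
    import torch.nn as nn

    model = nn.Linear(8, 1)
    optimizer, scheduler = configure_adafactor_with_scheduler(
        model, lr=1e-3, warmup_steps=10, total_steps=50
    )

    x, y = torch.randn(64, 8), torch.randn(64, 1)
    for _ in range(50):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        scheduler.step()
    print(f"final loss: {loss.item():.4f}")
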
def create_adafactor_training_config(
    lr: Optional[float] = None,
    weight_decay: float = 0.0,
    scale_parameter: bool = True,
    relative_step_size: bool = True,
    warmup_init: bool = False,
    **kwargs,
) -> Dict[str, Any]:
    """
    Create a training configuration dictionary for Adafactor optimizer.

    Args:
        lr: External learning rate (None for automatic)
        weight_decay: Weight decay coefficient
        scale_parameter: Whether to scale by parameter RMS
        relative_step_size: Whether to use relative step size
        warmup_init: Whether to use warmup initialization
        **kwargs: Additional configuration options

    Returns:
        Dictionary containing training configuration
    """
    config = {
        "optimizer_type": "adafactor",
        "optimizer_config": {
            "lr": lr,
            "weight_decay": weight_decay,
            "scale_parameter": scale_parameter,
            "relative_step_size": relative_step_size,
            "warmup_init": warmup_init,
            **kwargs,
        },
        "scheduler_type": "adafactor_custom" if lr is None else "onecycle",
    }

    return config
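

# Illustrative sketch (not part of the original module): shows how the config
# dictionary returned above might be consumed to build an optimizer. The helper
# name `_demo_training_config` and the consuming logic are assumptions, not an
# existing BitTransformerLM API.
def _demo_training_config() -> None:
    config = create_adafactor_training_config(lr=None, weight_decay=0.01)
    print(config["scheduler_type"])  # "adafactor_custom" because lr is None

    layer = torch.nn.Linear(4, 4)
    if config["optimizer_type"] == "adafactor":
        optimizer = Adafactor(layer.parameters(), **config["optimizer_config"])
        print(type(optimizer).__name__)
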
def integrate_with_bittransformerlm():
    """
    Example of how to integrate the Adafactor optimizer with BitTransformerLM training.

    Usage:
        from BTLM_Extensions.adafactor_optimizer import configure_adafactor_optimizer

        # Option 1: Use Adafactor with automatic learning rate scaling
        optimizer, scheduler = configure_adafactor_optimizer(
            model, lr=None, total_steps=1000  # lr=None enables auto-scaling
        )

        # Option 2: Use Adafactor with a fixed learning rate
        optimizer, scheduler = configure_adafactor_optimizer(
            model, lr=1e-3, total_steps=1000
        )

        # Option 3: Use Adafactor with the custom scheduler
        from BTLM_Extensions.adafactor_optimizer import configure_adafactor_with_scheduler

        optimizer, scheduler = configure_adafactor_with_scheduler(
            model, lr=1e-3, warmup_steps=100, total_steps=1000
        )

        # Use in the training loop
        train_loop(model, data, optimizer=optimizer, scheduler=scheduler)
    """
    pass


def analyze_memory_usage(model: torch.nn.Module) -> Dict[str, float]:
    """
    Analyze memory usage comparison between optimizers.

    Args:
        model: PyTorch model to analyze

    Returns:
        Dictionary with memory usage estimates in MB
    """
    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    param_bytes = param_count * 4  # fp32 parameters

    # AdamW keeps full-size exp_avg and exp_avg_sq buffers in addition to the
    # parameters and gradients: roughly 4x the parameter memory.
    adamw_memory = param_bytes * 4

    # Adafactor keeps the parameters and gradients ...
    adafactor_memory = param_bytes       # parameters
    adafactor_memory += param_bytes      # gradients

    # ... plus factored second-moment statistics: row + column vectors for
    # matrices, and full-size buffers only for 1-D parameters.
    factored_params = 0
    unfactored_params = 0

    for p in model.parameters():
        if p.requires_grad:
            if len(p.shape) >= 2:
                factored_params += p.shape[0] + p.shape[1]
            else:
                unfactored_params += p.numel()

    adafactor_memory += (factored_params + unfactored_params) * 4

    return {
        "adamw_mb": adamw_memory / (1024 * 1024),
        "adafactor_mb": adafactor_memory / (1024 * 1024),
        "savings_mb": (adamw_memory - adafactor_memory) / (1024 * 1024),
        "savings_percent": ((adamw_memory - adafactor_memory) / adamw_memory) * 100,
    }
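

# Illustrative sketch (not part of the original module): a back-of-the-envelope
# check of the factorization savings for a single 1000x1000 weight matrix,
# using the accounting from analyze_memory_usage above. The layer size and the
# helper name `_demo_memory_savings` are assumptions for the example.
def _demo_memory_savings() -> None:
    layer = torch.nn.Linear(1000, 1000, bias=False)
    report = analyze_memory_usage(layer)
    # AdamW's elementwise second moment needs 1000*1000 floats; Adafactor's
    # factored statistics need only 1000 + 1000.
    print(f"AdamW:     {report['adamw_mb']:.2f} MB")
    print(f"Adafactor: {report['adafactor_mb']:.2f} MB "
          f"({report['savings_percent']:.1f}% smaller)")
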
if __name__ == "__main__":
    # Smoke test on a small MLP.
    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(100, 200),
        nn.ReLU(),
        nn.Linear(200, 50),
        nn.ReLU(),
        nn.Linear(50, 1),
    )

    print("Testing Adafactor optimizer...")

    # Option 1: automatic learning rate scaling (lr=None), no scheduler.
    optimizer, scheduler = configure_adafactor_optimizer(
        model, lr=None, total_steps=100
    )

    x = torch.randn(32, 100)
    y = torch.randn(32, 1)

    pred = model(x)
    loss = nn.functional.mse_loss(pred, y)
    initial_loss = loss.item()
    loss.backward()

    optimizer.step()
    if scheduler:
        scheduler.step()
    optimizer.zero_grad()

    # Option 2: fixed learning rate with a OneCycleLR schedule.
    optimizer2, scheduler2 = configure_adafactor_optimizer(
        model, lr=1e-3, total_steps=100
    )

    pred = model(x)
    loss = nn.functional.mse_loss(pred, y)
    loss.backward()
    optimizer2.step()
    if scheduler2:
        scheduler2.step()

    # Compare estimated optimizer memory footprints.
    memory_analysis = analyze_memory_usage(model)

    print("Adafactor optimizer test completed successfully!")
    print(f"Initial loss: {initial_loss:.4f}")
    print(f"Final loss: {loss.item():.4f}")
    print(f"Memory analysis: {memory_analysis}")