"""
Adafactor Optimizer for BitTransformerLM Extensions
===================================================
Implementation of the Adafactor optimizer with memory-efficient factorization.
Based on "Adafactor: Adaptive Learning Rates with Sublinear Memory Cost" research.
Key features:
- Factorized second moment estimates for memory efficiency
- Automatic scaling of learning rates
- Relative step size and clip threshold
- Compatible with BitTransformerLM's training infrastructure
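Update-rule sketch (a simplified description mirroring step() below, not the
paper's exact notation; G_t is the gradient, theta the parameters):
    beta2_t    = 1 - t**decay_rate
    R_t        = beta2_t * R_{t-1} + (1 - beta2_t) * (G_t**2 + eps2).mean(dim=-1)
    C_t        = beta2_t * C_{t-1} + (1 - beta2_t) * (G_t**2 + eps2).mean(dim=-2)
    V_t[i, j] ~= (R_t[i] / R_t.mean()) * C_t[j]
    U_t        = (G_t / sqrt(V_t)) scaled so that RMS(U_t) <= clipping_threshold
    theta_t    = theta_{t-1} - lr_t * U_t
where, with the default options and lr=None, lr_t = max(eps2, RMS(theta)) * min(1e-2, 1 / sqrt(t)).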
"""
import math
import torch
from torch.optim.optimizer import Optimizer
from typing import Any, Dict, Optional, Tuple
class Adafactor(Optimizer):
"""
Adafactor optimizer implementation.
Adafactor reduces memory usage by factorizing the second moment estimates
for parameters with 2 or more dimensions, making it highly memory efficient
for large transformer models.
Args:
params: Iterable of parameters to optimize
lr: External learning rate (default: None, uses automatic scaling)
eps2: Regularization constant for second moment (default: 1e-30)
        clipping_threshold: Threshold for adaptive update clipping (default: 1.0)
        decay_rate: Exponent controlling the second-moment decay, beta2_t = 1 - step**decay_rate (default: -0.8)
        beta1: Coefficient for the first-moment (momentum) running average; None disables it (default: None)
weight_decay: Weight decay coefficient (default: 0.0)
scale_parameter: If True, learning rate is scaled by root mean square of parameter (default: True)
relative_step_size: If True, use relative step size (default: True)
        warmup_init: If True, start from a small, linearly warmed-up relative step size (default: False)
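
    Example (an illustrative sketch, not taken from upstream docs):
        >>> import torch
        >>> model = torch.nn.Linear(16, 16)
        >>> opt = Adafactor(model.parameters(), lr=None)  # automatic step sizing
        >>> loss = model(torch.randn(4, 16)).pow(2).mean()
        >>> loss.backward()
        >>> opt.step()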
"""
def __init__(
self,
params,
lr: Optional[float] = None,
eps2: float = 1e-30,
        clipping_threshold: float = 1.0,
decay_rate: float = -0.8,
beta1: Optional[float] = None,
weight_decay: float = 0.0,
scale_parameter: bool = True,
relative_step_size: bool = True,
warmup_init: bool = False,
):
if lr is not None and lr <= 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
eps2=eps2,
            clipping_threshold=clipping_threshold,
decay_rate=decay_rate,
beta1=beta1,
weight_decay=weight_decay,
scale_parameter=scale_parameter,
relative_step_size=relative_step_size,
warmup_init=warmup_init,
)
super().__init__(params, defaults)
    def _get_lr(self, param_group, param_state):
        """Compute the step size used when no external learning rate is given."""
        # Relative step size: min(1e-2, 1/sqrt(step)), or a slow linear ramp of
        # 1e-6 * step when warmup_init is enabled.
        min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
        rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            # Scale the step by the parameter's RMS (floored at eps2) so updates
            # stay proportional to the parameter's magnitude.
            param_scale = max(param_group["eps2"], param_state["RMS"])
        return param_scale * rel_step_sz
def _get_options(self, param_group, param_shape):
"""Get optimization options for parameter."""
factored = len(param_shape) >= 2
use_first_moment = param_group["beta1"] is not None
return factored, use_first_moment
def _rms(self, tensor):
"""Root mean square."""
return tensor.norm(2) / (tensor.numel() ** 0.5)
    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        """Reconstruct the second-moment estimate from its row and column factors."""
        # Normalize the row factor by its mean so the rank-1 reconstruction has the
        # correct overall scale, then combine the rsqrt factors by broadcasting.
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_()
        c_factor = exp_avg_sq_col.rsqrt()
        return torch.mul(r_factor.unsqueeze(-1), c_factor.unsqueeze(-2))
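    # A rough sense of the memory saved by the factorization above (illustrative
    # arithmetic, not a measurement): for a 1024 x 4096 weight, a full second-moment
    # EMA would store 1024 * 4096 ~= 4.2M floats, while the factored form keeps only
    # the row means (1024) and column means (4096), ~5K floats, reconstructing
    # V[i, j] ~= (row[i] / mean(row)) * col[j] on the fly in _approx_sq_grad.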
@torch.no_grad()
def step(self, closure=None):
"""Perform a single optimization step."""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = self._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad).float()
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).float()
state["exp_avg_sq_col"] = torch.zeros(
grad_shape[:-2] + grad_shape[-1:]).float()
else:
state["exp_avg_sq"] = torch.zeros_like(grad).float()
state["RMS"] = 0
p_data_fp32 = p.data
if p.data.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
state["step"] += 1
state["RMS"] = self._rms(p_data_fp32)
lr = group["lr"]
if group["lr"] is None:
lr = self._get_lr(group, state)
                # Second-moment decay rate from the paper: beta2_t = 1 - t**decay_rate.
                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = grad**2 + group["eps2"]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row.mul_(beta2t).add_(
update.mean(dim=-1), alpha=1.0 - beta2t)
exp_avg_sq_col.mul_(beta2t).add_(
update.mean(dim=-2), alpha=1.0 - beta2t)
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
update = exp_avg_sq.rsqrt().mul_(grad)
                update.div_(max(1.0, self._rms(update) / group["clipping_threshold"]))
if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"])
update = exp_avg
if group["weight_decay"] != 0:
p_data_fp32.mul_(1 - group["weight_decay"] * lr)
p_data_fp32.add_(update, alpha=-lr)
if p.data.dtype in {torch.float16, torch.bfloat16}:
p.data.copy_(p_data_fp32)
return loss
def configure_adafactor_optimizer(
model: torch.nn.Module,
lr: Optional[float] = None,
weight_decay: float = 0.0,
total_steps: Optional[int] = None,
warmup_ratio: float = 0.1,
scale_parameter: bool = True,
relative_step_size: bool = True,
warmup_init: bool = False,
    clipping_threshold: float = 1.0,
decay_rate: float = -0.8,
beta1: Optional[float] = None,
eps2: float = 1e-30,
**adafactor_kwargs
) -> Tuple[Adafactor, Optional[torch.optim.lr_scheduler._LRScheduler]]:
"""
Configure Adafactor optimizer with optional learning rate scheduling.
This function provides a drop-in replacement for BitTransformerLM's
configure_optimizer function, using Adafactor instead of AdamW.
Args:
model: PyTorch model to optimize
lr: External learning rate (None for automatic scaling)
weight_decay: Weight decay coefficient
total_steps: Total training steps for scheduling
warmup_ratio: Fraction of steps for warmup
scale_parameter: Whether to scale learning rate by parameter RMS
relative_step_size: Whether to use relative step size
warmup_init: Whether to use warmup initialization
        clipping_threshold: Threshold for adaptive update clipping
decay_rate: Decay rate for second moment estimates
beta1: Coefficient for first moment (None to disable)
eps2: Regularization constant
**adafactor_kwargs: Additional arguments for Adafactor
Returns:
Tuple of (optimizer, scheduler)
"""
# Adafactor can handle all parameters in one group efficiently
params = [p for p in model.parameters() if p.requires_grad]
optimizer = Adafactor(
params,
lr=lr,
weight_decay=weight_decay,
scale_parameter=scale_parameter,
relative_step_size=relative_step_size,
warmup_init=warmup_init,
        clipping_threshold=clipping_threshold,
decay_rate=decay_rate,
beta1=beta1,
eps2=eps2,
**adafactor_kwargs
)
scheduler = None
# Adafactor has built-in learning rate scaling, but we can still use OneCycle
if total_steps is not None and total_steps > 0 and lr is not None:
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=lr,
total_steps=total_steps,
pct_start=warmup_ratio,
anneal_strategy='cos',
cycle_momentum=False, # Adafactor doesn't use momentum cycling
div_factor=25.0,
final_div_factor=1e4,
)
return optimizer, scheduler
class AdafactorScheduler(torch.optim.lr_scheduler._LRScheduler):
"""
Custom scheduler for Adafactor with warmup and polynomial decay.
This scheduler is specifically designed to work with Adafactor's
relative step size feature.
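
    Illustrative trajectory (assuming base lr 1e-3, warmup_steps=1000,
    total_steps=11000, polynomial_power=1.0, min_lr_ratio=0.1):
        step   500 -> 1e-3 * 500 / 1000          = 5e-4  (linear warmup)
        step  1000 -> 1e-3                               (warmup complete)
        step  6000 -> 1e-3 * (1 - 5000 / 10000)  = 5e-4  (polynomial decay)
        step 11000 -> 1e-3 * max(0.0, 0.1)       = 1e-4  (floored at min_lr_ratio)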
"""
def __init__(
self,
optimizer: Adafactor,
warmup_steps: int = 1000,
total_steps: Optional[int] = None,
min_lr_ratio: float = 0.1,
polynomial_power: float = 1.0,
last_epoch: int = -1,
):
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.min_lr_ratio = min_lr_ratio
self.polynomial_power = polynomial_power
super().__init__(optimizer, last_epoch)
def get_lr(self):
step = self.last_epoch + 1
if step < self.warmup_steps:
# Linear warmup
return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]
if self.total_steps is None:
# No decay after warmup
return self.base_lrs
# Polynomial decay
progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
progress = min(progress, 1.0)
decay_factor = (1 - progress) ** self.polynomial_power
decay_factor = max(decay_factor, self.min_lr_ratio)
return [base_lr * decay_factor for base_lr in self.base_lrs]
def configure_adafactor_with_scheduler(
model: torch.nn.Module,
lr: float = 1e-3,
warmup_steps: int = 1000,
total_steps: Optional[int] = None,
weight_decay: float = 0.0,
**kwargs
) -> Tuple[Adafactor, AdafactorScheduler]:
"""
Configure Adafactor optimizer with custom Adafactor scheduler.
Args:
model: PyTorch model to optimize
lr: Base learning rate
warmup_steps: Number of warmup steps
total_steps: Total training steps
weight_decay: Weight decay coefficient
**kwargs: Additional arguments for Adafactor
Returns:
Tuple of (optimizer, scheduler)
"""
params = [p for p in model.parameters() if p.requires_grad]
optimizer = Adafactor(
params,
lr=lr,
weight_decay=weight_decay,
relative_step_size=False, # We'll use external scheduler
**kwargs
)
scheduler = AdafactorScheduler(
optimizer,
warmup_steps=warmup_steps,
total_steps=total_steps,
)
return optimizer, scheduler
def create_adafactor_training_config(
lr: Optional[float] = None,
weight_decay: float = 0.0,
scale_parameter: bool = True,
relative_step_size: bool = True,
warmup_init: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
Create a training configuration dictionary for Adafactor optimizer.
Args:
lr: External learning rate (None for automatic)
weight_decay: Weight decay coefficient
scale_parameter: Whether to scale by parameter RMS
relative_step_size: Whether to use relative step size
warmup_init: Whether to use warmup initialization
**kwargs: Additional configuration options
Returns:
Dictionary containing training configuration
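
    Example (illustrative; approximate shape of the returned dict for
    lr=None, weight_decay=0.01):
        {
            "optimizer_type": "adafactor",
            "optimizer_config": {"lr": None, "weight_decay": 0.01,
                                 "scale_parameter": True, "relative_step_size": True,
                                 "warmup_init": False},
            "scheduler_type": "adafactor_custom",
        }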
"""
config = {
"optimizer_type": "adafactor",
"optimizer_config": {
"lr": lr,
"weight_decay": weight_decay,
"scale_parameter": scale_parameter,
"relative_step_size": relative_step_size,
"warmup_init": warmup_init,
**kwargs
},
"scheduler_type": "adafactor_custom" if lr is None else "onecycle",
}
return config
# Example usage and integration helpers
def integrate_with_bittransformerlm():
"""
Example of how to integrate Adafactor optimizer with BitTransformerLM training.
Usage:
from BTLM_Extensions.adafactor_optimizer import configure_adafactor_optimizer
# Option 1: Use Adafactor with automatic learning rate scaling
optimizer, scheduler = configure_adafactor_optimizer(
model, lr=None, total_steps=1000 # lr=None enables auto-scaling
)
# Option 2: Use Adafactor with fixed learning rate
optimizer, scheduler = configure_adafactor_optimizer(
model, lr=1e-3, total_steps=1000
)
# Option 3: Use Adafactor with custom scheduler
from BTLM_Extensions.adafactor_optimizer import configure_adafactor_with_scheduler
optimizer, scheduler = configure_adafactor_with_scheduler(
model, lr=1e-3, warmup_steps=100, total_steps=1000
)
# Use in training loop
train_loop(model, data, optimizer=optimizer, scheduler=scheduler)
"""
pass
def analyze_memory_usage(model: torch.nn.Module) -> Dict[str, float]:
"""
Analyze memory usage comparison between optimizers.
Args:
model: PyTorch model to analyze
Returns:
Dictionary with memory usage estimates in MB
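
    Rough illustration using this function's own accounting (not a measurement):
    a single 1000x1000 weight (1M parameters) comes to ~15.3 MB for AdamW
    (params + grads + two moment tensors) versus ~7.6 MB for Adafactor
    (params + grads + ~8 KB of factored second-moment state), roughly a 50% saving.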
"""
param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
param_bytes = param_count * 4 # Assume float32
# AdamW memory: parameters + gradients + 2 momentum states
adamw_memory = param_bytes * 4
# Adafactor memory estimation
adafactor_memory = param_bytes # parameters
adafactor_memory += param_bytes # gradients
    # For factored parameters (>= 2D), Adafactor stores a row factor of shape
    # p.shape[:-1] and a column factor of shape p.shape[:-2] + p.shape[-1:]
    # instead of a full second-moment tensor.
    factored_params = 0
    unfactored_params = 0
    for p in model.parameters():
        if p.requires_grad:
            if p.dim() >= 2:
                shape = tuple(p.shape)
                factored_params += math.prod(shape[:-1]) + math.prod(shape[:-2] + (shape[-1],))
            else:
                unfactored_params += p.numel()
    adafactor_memory += (factored_params + unfactored_params) * 4  # second-moment state
return {
"adamw_mb": adamw_memory / (1024 * 1024),
"adafactor_mb": adafactor_memory / (1024 * 1024),
"savings_mb": (adamw_memory - adafactor_memory) / (1024 * 1024),
"savings_percent": ((adamw_memory - adafactor_memory) / adamw_memory) * 100,
}
if __name__ == "__main__":
# Simple test of the optimizer
import torch.nn as nn
model = nn.Sequential(
nn.Linear(100, 200),
nn.ReLU(),
nn.Linear(200, 50),
nn.ReLU(),
nn.Linear(50, 1)
)
print("Testing Adafactor optimizer...")
# Test with automatic learning rate
optimizer, scheduler = configure_adafactor_optimizer(
model, lr=None, total_steps=100
)
# Simple training step
x = torch.randn(32, 100)
y = torch.randn(32, 1)
pred = model(x)
loss = nn.functional.mse_loss(pred, y)
initial_loss = loss.item()
loss.backward()
optimizer.step()
if scheduler:
scheduler.step()
# Test with fixed learning rate
optimizer2, scheduler2 = configure_adafactor_optimizer(
model, lr=1e-3, total_steps=100
)
    pred = model(x)
    loss = nn.functional.mse_loss(pred, y)
    optimizer2.zero_grad()  # clear gradients accumulated during the first test
    loss.backward()
    optimizer2.step()
if scheduler2:
scheduler2.step()
# Analyze memory usage
memory_analysis = analyze_memory_usage(model)
print("Adafactor optimizer test completed successfully!")
print(f"Initial loss: {initial_loss:.4f}")
print(f"Final loss: {loss.item():.4f}")
print(f"Memory analysis: {memory_analysis}")