"""
Lion Optimizer for BitTransformerLM Extensions
==============================================

Implementation of the Lion optimizer (EvoLved Sign Momentum), introduced in
"Symbolic Discovery of Optimization Algorithms" (Chen et al., 2023).

Key features:
- Sign-based momentum updates
- Extremely memory efficient (stores only a single momentum buffer per parameter)
- Often matches or outperforms Adam/AdamW while using smaller learning rates
- Compatible with BitTransformerLM's training infrastructure
"""

import torch
from torch.optim.optimizer import Optimizer
from typing import Any, Dict, Optional, Tuple


class Lion(Optimizer):
    """
    Lion optimizer implementation.

    Lion uses the sign of the interpolated momentum for parameter updates,
    making it very memory efficient while maintaining competitive performance.

    Args:
        params: Iterable of parameters to optimize
        lr: Learning rate (default: 1e-4; typically needs to be smaller than Adam's)
        betas: Coefficients for computing momentum (default: (0.9, 0.99))
        weight_decay: Weight decay coefficient (default: 0.0)
        eps: Small constant accepted for API compatibility (default: 1e-8; not used by the update rule)
        maximize: Whether to maximize the objective (default: False)
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        eps: float = 1e-8,
        maximize: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")

        defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if group["maximize"]:
                    grad = -grad

                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]

                # Lazy state initialization: Lion keeps only one momentum buffer.
                if len(state) == 0:
                    state["momentum"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                momentum = state["momentum"]
                beta1, beta2 = group["betas"]

                # Decoupled weight decay, applied directly to the parameters.
                if group["weight_decay"] != 0:
                    p.mul_(1 - group["lr"] * group["weight_decay"])

                # Interpolate between the momentum and the current gradient.
                interpolated = momentum.mul(beta1).add_(grad, alpha=1 - beta1)

                # Update parameters using only the sign of the interpolation.
                p.add_(torch.sign(interpolated), alpha=-group["lr"])

                # Update the momentum buffer with the second beta.
                momentum.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss
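

# The update performed in Lion.step() can be summarized as:
#     c_t = beta1 * m_{t-1} + (1 - beta1) * g_t      (interpolation)
#     p_t = p_{t-1} - lr * sign(c_t)                 (sign-based step)
#     m_t = beta2 * m_{t-1} + (1 - beta2) * g_t      (momentum update)
# The sketch below is illustrative only and is not part of the BitTransformerLM
# API; all names are hypothetical. It simply restates the step() logic on plain
# tensors.
def _lion_single_step_reference(param, grad, momentum, lr=1e-4, beta1=0.9,
                                beta2=0.99, weight_decay=0.0):
    """Apply one Lion update to ``param`` in place and return the new momentum."""
    with torch.no_grad():
        if weight_decay != 0:
            param.mul_(1 - lr * weight_decay)      # decoupled weight decay
        update = (beta1 * momentum + (1 - beta1) * grad).sign()
        param.add_(update, alpha=-lr)              # sign-based parameter step
        momentum = beta2 * momentum + (1 - beta2) * grad
    return momentum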


def configure_lion_optimizer(
    model: torch.nn.Module,
    lr: float = 1e-4,
    betas: Tuple[float, float] = (0.9, 0.99),
    weight_decay: float = 0.01,
    total_steps: Optional[int] = None,
    warmup_ratio: float = 0.1,
    **lion_kwargs
) -> Tuple[Lion, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """
    Configure a Lion optimizer with a OneCycle learning rate schedule.

    This function provides a drop-in replacement for BitTransformerLM's
    configure_optimizer function, using Lion instead of AdamW.

    Note: Lion typically works well with learning rates about 3-10x smaller
    than Adam/AdamW, paired with a higher weight decay (0.01-0.1).

    Args:
        model: PyTorch model to optimize
        lr: Peak learning rate (typically smaller than Adam's)
        betas: Beta coefficients for momentum computation
        weight_decay: Weight decay coefficient (can be higher than Adam's)
        total_steps: Total training steps for the OneCycle schedule
        warmup_ratio: Fraction of steps used for warmup
        **lion_kwargs: Additional arguments for the Lion optimizer

    Returns:
        Tuple of (optimizer, scheduler)
    """
    # Apply weight decay only to weight matrices; biases and other 1-D
    # parameters (e.g. norm scales) are left un-decayed.
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if param.dim() >= 2:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    optimizer = Lion(
        param_groups,
        lr=lr,
        betas=betas,
        **lion_kwargs
    )

    scheduler = None
    if total_steps is not None and total_steps > 0:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=warmup_ratio,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=1e4,
        )

    return optimizer, scheduler
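

# A minimal sketch (not part of the original module) of how the returned
# optimizer/scheduler pair might be driven in a plain PyTorch loop. The model,
# dataloader, and loss function here are placeholders; BitTransformerLM's own
# training loop may differ.
def _example_lion_training_loop(model, dataloader, total_steps=1000):
    optimizer, scheduler = configure_lion_optimizer(
        model, lr=1e-4, weight_decay=0.01, total_steps=total_steps
    )
    step = 0
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        step += 1
        if step >= total_steps:
            break
    return model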


def create_lion_training_config(
    lr: float = 1e-4,
    betas: Tuple[float, float] = (0.9, 0.99),
    weight_decay: float = 0.01,
    **kwargs
) -> Dict[str, Any]:
    """
    Create a training configuration dictionary for Lion optimizer.

    This can be used with BitTransformerLM's training scripts by passing
    the config to the training loop.

    Args:
        lr: Learning rate
        betas: Beta coefficients for momentum
        weight_decay: Weight decay coefficient
        **kwargs: Additional configuration options

    Returns:
        Dictionary containing training configuration
    """
    config = {
        "optimizer_type": "lion",
        "optimizer_config": {
            "lr": lr,
            "betas": betas,
            "weight_decay": weight_decay,
            **kwargs
        },
        "scheduler_type": "onecycle",
    }

    return config
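

# For the defaults above, create_lion_training_config() produces:
#     {
#         "optimizer_type": "lion",
#         "optimizer_config": {"lr": 1e-4, "betas": (0.9, 0.99), "weight_decay": 0.01},
#         "scheduler_type": "onecycle",
#     }
# Extra keyword arguments (e.g. eps or maximize) are merged into "optimizer_config".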


class AdaptiveLion(Lion):
    """
    Enhanced Lion optimizer with adaptive learning rate scaling.

    This variant automatically adjusts the learning rate based on the
    magnitudes of the gradient and momentum, potentially improving stability.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        eps: float = 1e-8,
        maximize: bool = False,
        adaptive_scale: float = 0.1,
        min_scale: float = 0.01,
        max_scale: float = 10.0,
    ):
        """
        Args:
            adaptive_scale: Scaling factor for the adaptive adjustment
            min_scale: Minimum learning rate scale
            max_scale: Maximum learning rate scale
        """
        self.adaptive_scale = adaptive_scale
        self.min_scale = min_scale
        self.max_scale = max_scale

        super().__init__(params, lr, betas, weight_decay, eps, maximize)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform an optimization step with adaptive scaling."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if group["maximize"]:
                    grad = -grad

                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]

                if len(state) == 0:
                    state["momentum"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["step"] = 0

                momentum = state["momentum"]
                state["step"] += 1  # step counter kept for inspection; not used by the update
                beta1, beta2 = group["betas"]

                # Compare gradient and momentum magnitudes to derive a per-parameter scale.
                grad_norm = grad.norm().item()
                momentum_norm = momentum.norm().item()

                if momentum_norm > 1e-8:
                    scale = 1.0 + self.adaptive_scale * (grad_norm / momentum_norm - 1.0)
                    scale = min(max(scale, self.min_scale), self.max_scale)
                else:
                    scale = 1.0

                adaptive_lr = group["lr"] * scale

                # Decoupled weight decay with the scaled learning rate.
                if group["weight_decay"] != 0:
                    p.mul_(1 - adaptive_lr * group["weight_decay"])

                # Standard Lion update, using the adaptively scaled learning rate.
                interpolated = momentum.mul(beta1).add_(grad, alpha=1 - beta1)
                p.add_(torch.sign(interpolated), alpha=-adaptive_lr)
                momentum.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss
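

# Worked example of the adaptive scale (illustrative numbers only): with
# adaptive_scale=0.1, a gradient norm of 2.0 and a momentum norm of 1.0 give
#     scale = 1.0 + 0.1 * (2.0 / 1.0 - 1.0) = 1.1
# so the effective learning rate for that parameter is 1.1 * lr, clipped to
# the [min_scale, max_scale] range (defaults: [0.01, 10.0]).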


def configure_adaptive_lion_optimizer(
    model: torch.nn.Module,
    lr: float = 1e-4,
    adaptive_scale: float = 0.1,
    **kwargs
) -> Tuple[AdaptiveLion, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """Configure an AdaptiveLion optimizer with learning rate scheduling."""
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.dim() >= 2:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": kwargs.get("weight_decay", 0.01)},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    # Scheduler-related kwargs must not be forwarded to the optimizer constructor.
    optimizer_kwargs = {
        k: v for k, v in kwargs.items()
        if k not in {"weight_decay", "total_steps", "warmup_ratio"}
    }
    optimizer = AdaptiveLion(
        param_groups,
        lr=lr,
        adaptive_scale=adaptive_scale,
        **optimizer_kwargs
    )

    scheduler = None
    total_steps = kwargs.get("total_steps")
    if total_steps is not None and total_steps > 0:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=kwargs.get("warmup_ratio", 0.1),
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=1e4,
        )

    return optimizer, scheduler


def integrate_with_bittransformerlm():
    """
    Example of how to integrate Lion optimizer with BitTransformerLM training.

    Usage:
        from BTLM_Extensions.lion_optimizer import configure_lion_optimizer

        # Replace the standard optimizer configuration
        # Note: Lion typically needs smaller learning rates than Adam
        optimizer, scheduler = configure_lion_optimizer(
            model, lr=1e-4, weight_decay=0.01, total_steps=1000
        )

        # Use in training loop
        train_loop(model, data, optimizer=optimizer, scheduler=scheduler)

        # For adaptive version:
        from BTLM_Extensions.lion_optimizer import configure_adaptive_lion_optimizer

        optimizer, scheduler = configure_adaptive_lion_optimizer(
            model, lr=1e-4, adaptive_scale=0.1, total_steps=1000
        )
    """
    pass


if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    )

    print("Testing standard Lion optimizer...")
    optimizer, scheduler = configure_lion_optimizer(model, lr=1e-4, total_steps=100)

    x = torch.randn(32, 10)
    y = torch.randn(32, 1)

    pred = model(x)
    loss = nn.functional.mse_loss(pred, y)
    initial_loss = loss.item()
    loss.backward()

    optimizer.step()
    if scheduler:
        scheduler.step()

    print(f"Initial loss: {initial_loss:.4f}")

    print("Testing Adaptive Lion optimizer...")
    model2 = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    )

    optimizer2, scheduler2 = configure_adaptive_lion_optimizer(
        model2, lr=1e-4, adaptive_scale=0.1, total_steps=100
    )

    pred2 = model2(x)
    loss2 = nn.functional.mse_loss(pred2, y)
    loss2.backward()
    optimizer2.step()
    if scheduler2:
        scheduler2.step()

    print("Lion optimizer tests completed successfully!")
    print(f"Standard Lion initial loss: {initial_loss:.4f}")
    print(f"Adaptive Lion initial loss: {loss2.item():.4f}")