"""
Advanced optimizers for large-scale training.

Includes Lion, AdamW variants, SAM, and Sophia.
"""

import torch
import torch.nn as nn
from torch.optim import Optimizer
from typing import Optional, Tuple


class Lion(Optimizer):
    """
    Lion optimizer from "Symbolic Discovery of Optimization Algorithms".

    Keeps only a single momentum buffer per parameter, so it is more
    memory-efficient than AdamW for large models.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        maximize: bool = False,
        foreach: Optional[bool] = None,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            maximize=maximize,
            foreach=foreach,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []

            beta1, beta2 = group["betas"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                # Work in fp32 for half-precision gradients to avoid underflow.
                if p.grad.dtype in {torch.float16, torch.bfloat16}:
                    grads.append(p.grad.float())
                else:
                    grads.append(p.grad)

                state = self.state[p]

                # State initialization: Lion only needs a momentum buffer.
                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state["exp_avg"])

            lion(
                params_with_grad,
                grads,
                exp_avgs,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                maximize=group["maximize"],
            )

        return loss


def lion(
    params,
    grads,
    exp_avgs,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
):
    """Functional API that performs the Lion update."""

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]

        # Decoupled weight decay (as in AdamW).
        param.mul_(1 - lr * weight_decay)

        # Parameter update: interpolate momentum and gradient, then apply the sign.
        update = exp_avg * beta1 + grad * (1 - beta1)
        param.add_(torch.sign(update), alpha=-lr)

        # Momentum update uses beta2.
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
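

# Usage sketch (illustrative only, not part of this module's API): Lion's sign-based
# update has a larger effective step size than AdamW's, so it is typically run with a
# noticeably smaller learning rate and larger weight decay. The toy model and
# hyperparameters below are placeholder assumptions.
def _example_lion_usage():
    model = nn.Linear(16, 4)
    optimizer = Lion(model.parameters(), lr=3e-5, betas=(0.9, 0.99), weight_decay=0.1)

    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()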


class AdamWScale(torch.optim.AdamW):
    """
    AdamW with per-group learning-rate scaling based on the parameter norm.

    Useful for very large models where groups with large weight norms should
    take relatively smaller steps.
    """

    def __init__(self, *args, scale_lr: bool = True, **kwargs):
        super().__init__(*args, **kwargs)
        self.scale_lr = scale_lr

    def step(self, closure=None):
        if not self.scale_lr:
            return super().step(closure)

        # Scale each group's lr by min(1, 1 / ||params||) for this step only,
        # then restore the base lr so the scaling does not compound across steps.
        base_lrs = [group['lr'] for group in self.param_groups]

        for group in self.param_groups:
            total_norm = 0.0
            for p in group['params']:
                if p.grad is not None:
                    param_norm = p.data.norm()
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** 0.5

            if total_norm > 0:
                scale = min(1.0, 1.0 / total_norm)
                group['lr'] = group['lr'] * scale

        try:
            loss = super().step(closure)
        finally:
            for group, base_lr in zip(self.param_groups, base_lrs):
                group['lr'] = base_lr

        return loss
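

# Usage sketch (illustrative): AdamWScale is a drop-in replacement for AdamW; the lr
# is rescaled internally each step, while the configured base lr stays untouched.
# The toy model and hyperparameters are placeholder assumptions.
def _example_adamw_scale_usage():
    model = nn.Linear(16, 4)
    optimizer = AdamWScale(model.parameters(), lr=1e-3, weight_decay=0.01, scale_lr=True)

    loss = nn.functional.mse_loss(model(torch.randn(8, 16)), torch.randn(8, 4))
    loss.backward()
    optimizer.step()
    assert optimizer.param_groups[0]['lr'] == 1e-3  # base lr restored after the step
    optimizer.zero_grad()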


class SAM(Optimizer):
    """
    Sharpness-Aware Minimization (SAM) optimizer.

    Wraps a base optimizer and perturbs the weights towards the locally
    highest loss before applying the actual update, which biases training
    towards flatter minima and tends to improve generalization.
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super().__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Climb to the local worst-case point: w <- w + rho * g / ||g||."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the original weights and apply the base optimizer update."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "SAM requires a closure that re-evaluates the loss, but none was provided"
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        # Gather norms on a single device in case of model parallelism.
        shared_device = self.param_groups[0]["params"][0].device
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(dtype=torch.float32).to(shared_device)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]),
            dtype=torch.float32,
        )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

    def state_dict(self):
        return super().state_dict()
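

# Usage sketch (illustrative; model, loss_fn, batch, and targets are placeholder
# assumptions). SAM needs two forward/backward passes per update: the first to find
# the ascent direction, the second, driven by the closure passed to step(), to
# compute gradients at the perturbed weights.
def _example_sam_training_step(model, loss_fn, batch, targets, sam_optimizer):
    def closure():
        sam_optimizer.zero_grad()
        loss = loss_fn(model(batch), targets)
        loss.backward()
        return loss

    # First pass computes gradients at the current weights; step() then perturbs the
    # weights, re-evaluates via the closure, and applies the base optimizer update.
    loss = closure()
    sam_optimizer.step(closure)
    return loss.item()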


class Sophia(Optimizer):
    """
    Sophia optimizer: second-order clipped stochastic optimization.

    Preconditions the update with an EMA of a diagonal Hessian estimate and is
    reported to be more efficient than Adam for large language models.
    """

    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.965, 0.99),
        rho=0.04,
        weight_decay=1e-1,
        *,
        maximize: bool = False,
        capturable: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= rho:
            raise ValueError(f"Invalid rho parameter: {rho}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr, betas=betas, rho=rho, weight_decay=weight_decay,
            maximize=maximize, capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('maximize', False)
            group.setdefault('capturable', False)
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
        if not step_is_tensor:
            for s in state_values:
                s['step'] = torch.tensor(float(s['step']))

    @torch.no_grad()
    def update_hessian(self):
        """Update the EMA of the diagonal Hessian estimate."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]

                if len(state) == 0:
                    state['step'] = (
                        torch.zeros((1,), dtype=torch.float, device=p.device)
                        if group['capturable'] else torch.tensor(0.)
                    )
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                hessian = state['hessian']
                beta2 = group['betas'][1]

                # Diagonal estimate: EMA of the element-wise squared gradient.
                hessian.mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)

    @torch.no_grad()
    def step(self, closure=None, bs=5120):
        """Performs a single optimization step. NOTE: bs is accepted but unused here."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            hessians = []
            state_steps = []
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.dtype in {torch.float16, torch.bfloat16}:
                    grads.append(p.grad.float())
                else:
                    grads.append(p.grad)

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = (
                        torch.zeros((1,), dtype=torch.float, device=p.device)
                        if group['capturable'] else torch.tensor(0.)
                    )
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                hessians.append(state['hessian'])
                state_steps.append(state['step'])

            sophia(
                params_with_grad,
                grads,
                exp_avgs,
                hessians,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                rho=group['rho'],
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                maximize=group['maximize'],
                capturable=group['capturable'],
            )

        return loss


def sophia(
    params,
    grads,
    exp_avgs,
    hessians,
    state_steps,
    capturable: bool = False,
    *,
    beta1: float,
    beta2: float,
    rho: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
):
    """Functional API that performs the Sophia update."""

    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        hessian = hessians[i]
        step_t = state_steps[i]

        if capturable:
            assert param.dtype == torch.float32

        step_t += 1

        # Decoupled weight decay.
        param.mul_(1 - lr * weight_decay)

        # Momentum update.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

        bias_correction1 = 1 - beta1 ** step_t.item()

        # Precondition by the Hessian estimate and clip the update to [-rho, rho].
        k = hessian.abs().clamp_(min=1e-8)
        u = (exp_avg / bias_correction1) / k.sqrt()
        u.clamp_(min=-rho, max=rho)

        param.add_(u, alpha=-lr)
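

# Usage sketch (illustrative; model, loss_fn, data_iter, and hessian_update_interval
# are placeholder assumptions). Sophia takes an ordinary step every iteration and only
# refreshes its diagonal Hessian estimate every few steps via update_hessian(), which
# keeps the second-order bookkeeping cheap.
def _example_sophia_training_loop(model, loss_fn, data_iter, num_steps=100,
                                  hessian_update_interval=10):
    optimizer = Sophia(model.parameters(), lr=1e-4, betas=(0.965, 0.99), rho=0.04)
    for step_idx in range(num_steps):
        inputs, targets = next(data_iter)
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step_idx % hessian_update_interval == 0:
            # Re-estimate the diagonal Hessian from a fresh backward pass.
            inputs, targets = next(data_iter)
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.update_hessian()
            optimizer.zero_grad()
    return optimizer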


def get_optimizer(model, config):
    """Build an optimizer from the training configuration."""

    # Split parameters: biases, norms, and other low-dimensional tensors get no weight decay.
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim < 2 or 'bias' in name or 'norm' in name.lower():
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optim_groups = [
        {'params': decay_params, 'weight_decay': config.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    optimizer_name = config.optimizer.lower()

    if optimizer_name == "adamw":
        optimizer = torch.optim.AdamW(
            optim_groups,
            lr=float(config.learning_rate),
            betas=(float(config.beta1), float(config.beta2)),
            eps=float(config.eps),
            weight_decay=float(config.weight_decay),
        )
    elif optimizer_name == "adamw_scale":
        optimizer = AdamWScale(
            optim_groups,
            lr=float(config.learning_rate),
            betas=(float(config.beta1), float(config.beta2)),
            eps=float(config.eps),
            weight_decay=float(config.weight_decay),
            scale_lr=True,
        )
    elif optimizer_name == "lion":
        optimizer = Lion(
            optim_groups,
            # Lion's sign-based update is larger in magnitude, so scale the lr down.
            lr=float(config.learning_rate) * 0.3,
            betas=(float(config.beta1), float(config.beta2)),
            weight_decay=float(config.weight_decay),
        )
    elif optimizer_name == "sophia":
        optimizer = Sophia(
            optim_groups,
            lr=float(config.learning_rate),
            betas=(float(config.beta1), float(config.beta2)),
            rho=0.04,
            weight_decay=float(config.weight_decay),
        )
    elif optimizer_name == "sam_adamw":
        base_optimizer = torch.optim.AdamW
        optimizer = SAM(
            optim_groups,
            base_optimizer,
            rho=0.05,
            adaptive=False,
            lr=float(config.learning_rate),
            betas=(float(config.beta1), float(config.beta2)),
            eps=float(config.eps),
            weight_decay=float(config.weight_decay),
        )
    else:
        raise ValueError(f"Unknown optimizer: {optimizer_name}")

    return optimizer


def get_scheduler(optimizer, config):
    """Get learning rate scheduler"""
    scheduler_name = config.lr_scheduler.lower()

    if scheduler_name == "cosine":
        from torch.optim.lr_scheduler import CosineAnnealingLR
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=config.total_steps - config.warmup_steps,
            eta_min=getattr(config, 'lr_scheduler_kwargs', {}).get('eta_min', 0),
        )
    elif scheduler_name == "linear":
        from torch.optim.lr_scheduler import LinearLR
        scheduler = LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=0.1,
            total_iters=config.total_steps - config.warmup_steps,
        )
    elif scheduler_name == "polynomial":
        from torch.optim.lr_scheduler import PolynomialLR
        scheduler = PolynomialLR(
            optimizer,
            total_iters=config.total_steps - config.warmup_steps,
            power=getattr(config, 'lr_scheduler_kwargs', {}).get('power', 1.0),
        )
    else:
        return None

    if config.warmup_steps > 0:
        from torch.optim.lr_scheduler import LinearLR, SequentialLR
        warmup_scheduler = LinearLR(
            optimizer,
            start_factor=1e-8,
            end_factor=1.0,
            total_iters=config.warmup_steps,
        )

        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, scheduler],
            milestones=[config.warmup_steps],
        )

    return scheduler
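

# End-to-end usage sketch (illustrative). The config object below is a stand-in for
# whatever training configuration this module is used with; its field names mirror
# those read by get_optimizer / get_scheduler above but are otherwise assumptions.
def _example_build_training_objects():
    from types import SimpleNamespace

    config = SimpleNamespace(
        optimizer="adamw",
        learning_rate=3e-4,
        beta1=0.9,
        beta2=0.95,
        eps=1e-8,
        weight_decay=0.1,
        lr_scheduler="cosine",
        total_steps=1000,
        warmup_steps=100,
    )

    model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32))
    optimizer = get_optimizer(model, config)
    scheduler = get_scheduler(optimizer, config)

    # One dummy training step followed by a scheduler step.
    loss = model(torch.randn(4, 32)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if scheduler is not None:
        scheduler.step()
    return optimizer, scheduler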