"""
Advanced Optimizers for Large Scale Training
Including Lion, AdamW variants, and SAM
"""
import torch
from torch.optim import Optimizer
from typing import Optional, Tuple
class Lion(Optimizer):
"""
Lion optimizer from "Symbolic Discovery of Optimization Algorithms"
More memory efficient than AdamW for large models
"""
def __init__(
self,
params,
lr: float = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
maximize: bool = False,
foreach: Optional[bool] = None,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
weight_decay=weight_decay,
maximize=maximize,
foreach=foreach,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("maximize", False)
group.setdefault("foreach", None)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step."""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
beta1, beta2 = group["betas"]
for p in group["params"]:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.dtype in {torch.float16, torch.bfloat16}:
grads.append(p.grad.float())
else:
grads.append(p.grad)
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avgs.append(state["exp_avg"])
lion(
params_with_grad,
grads,
exp_avgs,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
maximize=group["maximize"],
)
return loss
def lion(
params,
grads,
exp_avgs,
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
maximize: bool,
):
"""Functional API that performs Lion algorithm computation."""
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
        # Weight update: take the sign of the interpolated momentum/gradient
        update = exp_avg * beta1 + grad * (1 - beta1)
        param.add_(torch.sign(update), alpha=-lr)
        # Update the exponential moving average of the gradient
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
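# Usage sketch (illustrative only, not executed on import; `model`, `loss_fn`,
# and `loader` are placeholders). Lion is typically run with a ~3-10x smaller
# learning rate and a larger weight decay than AdamW:
#
#   optimizer = Lion(model.parameters(), lr=3e-5, weight_decay=0.1)
#   for x, y in loader:
#       loss_fn(model(x), y).backward()
#       optimizer.step()
#       optimizer.zero_grad()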
class AdamWScale(torch.optim.AdamW):
"""
AdamW with learning rate scaling based on parameter norm
Useful for very large models
"""
def __init__(self, *args, scale_lr: bool = True, **kwargs):
super().__init__(*args, **kwargs)
self.scale_lr = scale_lr
    def step(self, closure=None):
        if not self.scale_lr:
            return super().step(closure)
        # Scale each group's learning rate for this step based on the total
        # parameter norm; restore the base value afterwards so the scaling
        # does not compound across steps.
        saved_lrs = [group['lr'] for group in self.param_groups]
        for group in self.param_groups:
            total_norm = 0.0
            for p in group['params']:
                if p.grad is not None:
                    param_norm = p.data.norm()
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** 0.5
            # Shrink the learning rate when the parameter norm exceeds 1
            if total_norm > 0:
                scale = min(1.0, 1.0 / total_norm)
                group['lr'] = group['lr'] * scale
        loss = super().step(closure)
        for group, base_lr in zip(self.param_groups, saved_lrs):
            group['lr'] = base_lr
        return loss
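# Worked example of the scaling rule above (illustrative): with a total
# parameter norm of 4.0, scale = min(1.0, 1 / 4.0) = 0.25, so a base lr of
# 1e-4 becomes 2.5e-5 for that step; norms at or below 1.0 leave lr unchanged.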
class SAM(Optimizer):
"""
Sharpness-Aware Minimization (SAM) optimizer
Improves generalization by finding flatter minima
"""
def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
super(SAM, self).__init__(params, defaults)
self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
self.param_groups = self.base_optimizer.param_groups
self.defaults.update(self.base_optimizer.defaults)
@torch.no_grad()
def first_step(self, zero_grad=False):
grad_norm = self._grad_norm()
for group in self.param_groups:
scale = group["rho"] / (grad_norm + 1e-12)
for p in group["params"]:
if p.grad is None:
continue
self.state[p]["old_p"] = p.data.clone()
e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
p.add_(e_w) # climb to the local maximum "w + e(w)"
if zero_grad:
self.zero_grad()
@torch.no_grad()
def second_step(self, zero_grad=False):
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"
self.base_optimizer.step() # do the actual "sharpness-aware" update
if zero_grad:
self.zero_grad()
@torch.no_grad()
def step(self, closure=None):
assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
self.first_step(zero_grad=True)
closure()
self.second_step()
def _grad_norm(self):
shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
norm = torch.norm(
torch.stack([
((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(dtype=torch.float32).to(shared_device)
for group in self.param_groups for p in group["params"]
if p.grad is not None
]),
dtype=torch.float32
)
return norm
def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
self.base_optimizer.param_groups = self.param_groups
def state_dict(self):
return super().state_dict()
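# Usage sketch for the two-step SAM update (illustrative; `model`, `loss_fn`,
# and `loader` are placeholders). Either pass a closure to step(), or call
# first_step()/second_step() around two forward-backward passes:
#
#   optimizer = SAM(model.parameters(), torch.optim.AdamW, rho=0.05, lr=1e-4)
#   for x, y in loader:
#       loss_fn(model(x), y).backward()
#       optimizer.first_step(zero_grad=True)   # perturb to w + e(w)
#       loss_fn(model(x), y).backward()        # gradient at the perturbed point
#       optimizer.second_step(zero_grad=True)  # restore w, apply base update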
class Sophia(Optimizer):
"""
Sophia optimizer - Second-order clipped stochastic optimization
More efficient than Adam for large language models
"""
def __init__(
self,
params,
lr=1e-4,
betas=(0.965, 0.99),
rho=0.04,
weight_decay=1e-1,
*,
maximize: bool = False,
capturable: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= rho:
raise ValueError(f"Invalid rho parameter: {rho}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr, betas=betas, rho=rho, weight_decay=weight_decay,
maximize=maximize, capturable=capturable
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('maximize', False)
group.setdefault('capturable', False)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@torch.no_grad()
def update_hessian(self):
"""Update Hessian diagonal approximation"""
for group in self.param_groups:
for p in group['params']:
if p.grad is not None:
state = self.state[p]
if len(state) == 0:
state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
if group['capturable'] else torch.tensor(0.)
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
hessian = state['hessian']
beta2 = group['betas'][1]
                    # EMA of the squared gradient as a diagonal Hessian
                    # estimate (Gauss-Newton/Fisher-style approximation)
                    hessian.mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)
@torch.no_grad()
    def step(self, closure=None):
"""Performs a single optimization step."""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
hessians = []
state_steps = []
beta1, beta2 = group['betas']
for p in group['params']:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.dtype in {torch.float16, torch.bfloat16}:
grads.append(p.grad.float())
else:
grads.append(p.grad)
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
if group['capturable'] else torch.tensor(0.)
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avgs.append(state['exp_avg'])
hessians.append(state['hessian'])
state_steps.append(state['step'])
sophia(
params_with_grad,
grads,
exp_avgs,
hessians,
state_steps,
beta1=beta1,
beta2=beta2,
rho=group['rho'],
lr=group['lr'],
weight_decay=group['weight_decay'],
maximize=group['maximize'],
capturable=group['capturable'],
)
return loss
def sophia(
params,
grads,
exp_avgs,
hessians,
state_steps,
capturable: bool = False,
*,
beta1: float,
beta2: float,
rho: float,
lr: float,
weight_decay: float,
maximize: bool,
):
"""Functional API that performs Sophia algorithm computation."""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
hessian = hessians[i]
step_t = state_steps[i]
        if capturable:
            assert param.dtype == torch.float32
# update step
step_t += 1
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
# Decay the first moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
# Bias correction
bias_correction1 = 1 - beta1 ** step_t.item()
# Clipped update
k = hessian.abs().clamp_(min=1e-8)
u = (exp_avg / bias_correction1) / k.sqrt()
u.clamp_(min=-rho, max=rho)
param.add_(u, alpha=-lr)
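# Usage sketch (illustrative; `model`, `loss_fn`, `loader`, and the refresh
# cadence of 10 steps are placeholders). The diagonal Hessian EMA must be
# refreshed periodically via update_hessian(), using gradients from an extra
# backward pass:
#
#   optimizer = Sophia(model.parameters(), lr=1e-4, rho=0.04)
#   for step_idx, (x, y) in enumerate(loader):
#       loss_fn(model(x), y).backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       if step_idx % 10 == 0:
#           loss_fn(model(x), y).backward()    # fresh grads for the estimate
#           optimizer.update_hessian()
#           optimizer.zero_grad()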
def get_optimizer(model, config):
"""Get optimizer based on configuration"""
# Separate parameters for weight decay
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if param.ndim < 2 or 'bias' in name or 'norm' in name.lower():
no_decay_params.append(param)
else:
decay_params.append(param)
optim_groups = [
{'params': decay_params, 'weight_decay': config.weight_decay},
{'params': no_decay_params, 'weight_decay': 0.0}
]
optimizer_name = config.optimizer.lower()
if optimizer_name == "adamw":
optimizer = torch.optim.AdamW(
optim_groups,
lr=float(config.learning_rate),
betas=(float(config.beta1), float(config.beta2)),
eps=float(config.eps),
weight_decay=float(config.weight_decay),
)
elif optimizer_name == "adamw_scale":
optimizer = AdamWScale(
optim_groups,
lr=float(config.learning_rate),
betas=(float(config.beta1), float(config.beta2)),
eps=float(config.eps),
weight_decay=float(config.weight_decay),
scale_lr=True,
)
elif optimizer_name == "lion":
optimizer = Lion(
optim_groups,
lr=float(config.learning_rate) * 0.3, # Lion typically needs lower LR
betas=(float(config.beta1), float(config.beta2)),
weight_decay=float(config.weight_decay),
)
elif optimizer_name == "sophia":
optimizer = Sophia(
optim_groups,
lr=float(config.learning_rate),
betas=(float(config.beta1), float(config.beta2)),
rho=0.04,
weight_decay=float(config.weight_decay),
)
elif optimizer_name == "sam_adamw":
base_optimizer = torch.optim.AdamW
optimizer = SAM(
optim_groups,
base_optimizer,
rho=0.05,
adaptive=False,
lr=float(config.learning_rate),
betas=(float(config.beta1), float(config.beta2)),
eps=float(config.eps),
weight_decay=float(config.weight_decay),
)
else:
raise ValueError(f"Unknown optimizer: {optimizer_name}")
return optimizer
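# Minimal config sketch for get_optimizer (illustrative; attribute names match
# the accesses above, values are placeholders):
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(
#       optimizer="lion", learning_rate=3e-4, beta1=0.9, beta2=0.99,
#       eps=1e-8, weight_decay=0.1,
#   )
#   optimizer = get_optimizer(model, config)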
def get_scheduler(optimizer, config):
"""Get learning rate scheduler"""
scheduler_name = config.lr_scheduler.lower()
if scheduler_name == "cosine":
from torch.optim.lr_scheduler import CosineAnnealingLR
scheduler = CosineAnnealingLR(
optimizer,
T_max=config.total_steps - config.warmup_steps,
eta_min=getattr(config, 'lr_scheduler_kwargs', {}).get('eta_min', 0),
)
elif scheduler_name == "linear":
from torch.optim.lr_scheduler import LinearLR
scheduler = LinearLR(
optimizer,
start_factor=1.0,
end_factor=0.1,
total_iters=config.total_steps - config.warmup_steps,
)
elif scheduler_name == "polynomial":
from torch.optim.lr_scheduler import PolynomialLR
scheduler = PolynomialLR(
optimizer,
total_iters=config.total_steps - config.warmup_steps,
power=getattr(config, 'lr_scheduler_kwargs', {}).get('power', 1.0),
)
else:
return None
# Add warmup if specified
if config.warmup_steps > 0:
from torch.optim.lr_scheduler import LinearLR, SequentialLR
warmup_scheduler = LinearLR(
optimizer,
start_factor=1e-8,
end_factor=1.0,
total_iters=config.warmup_steps
)
scheduler = SequentialLR(
optimizer,
schedulers=[warmup_scheduler, scheduler],
milestones=[config.warmup_steps]
)
return scheduler
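# Wiring sketch (illustrative; extends the config example above with scheduler
# fields): linear warmup for `warmup_steps`, then cosine decay, stepped once
# per optimizer update:
#
#   config.lr_scheduler = "cosine"
#   config.total_steps = 10_000
#   config.warmup_steps = 500
#   scheduler = get_scheduler(optimizer, config)
#   for _ in range(config.total_steps):
#       ...  # forward/backward
#       optimizer.step()
#       scheduler.step()
#       optimizer.zero_grad()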