dragon/optimizers/Ademamix.py
"""
Adapted from: https://pytorch.org/docs/1.6.0/_modules/torch/optim/adam.html
"""
import math
import torch
from torch.optim import Optimizer
def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1):
    """Linearly interpolate from alpha_start to alpha_end over `warmup` steps, then hold alpha_end."""
    if step < warmup:
        a = step / float(warmup)
        return (1.0 - a) * alpha_start + a * alpha_end
    return alpha_end
def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
    """Warm up a beta coefficient by interpolating linearly in half-life space rather than in beta space."""
    def f(beta, eps=1e-8):
        # Map a beta to the half-life (in steps) of its EMA, i.e. the t such that beta**(t + 1) = 0.5.
        return math.log(0.5) / math.log(beta + eps) - 1
    def f_inv(t):
        # Inverse mapping: recover the beta whose EMA half-life is t steps.
        return math.pow(0.5, 1 / (t + 1))
    if step < warmup:
        a = step / float(warmup)
        return f_inv((1.0 - a) * f(beta_start) + a * f(beta_end))
    return beta_end
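# A quick, illustrative check of the half-life interpolation above (the numbers below are
# examples, not defaults used elsewhere in this file): with beta_start=0.9 and beta_end=0.999,
# f(0.9) ~ 5.6 steps and f(0.999) ~ 692 steps of half-life, so halfway through the warmup the
# scheduler returns roughly f_inv((5.6 + 692) / 2) ~ 0.998 rather than the 0.95 that a naive
# linear interpolation in beta space would give.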
class AdEMAMix(Optimizer):
r"""Implements the AdEMAMix algorithm.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients (beta_1, beta_2, beta_3) used for
            the fast EMA of the gradient, the EMA of the squared gradient, and the slow EMA of the
            gradient, respectively (default: (0.9, 0.95, 0.999))
        alpha (float): AdEMAMix alpha coefficient mixing the slow and fast EMAs (default: 8.0)
        beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None)
        alpha_warmup (int, optional): number of warmup steps used to increase alpha (default: None)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        normalize_alpha (bool, optional): if True, divide the update by (1 + alpha) so its scale
            stays comparable as alpha grows (default: False)
        weight_decay (float, optional): weight decay as in AdamW (default: 0)
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.95, 0.999), alpha=8.0, #0.999
beta3_warmup=None, alpha_warmup=None, eps=1e-8, normalize_alpha=False,
weight_decay=0):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= betas[2] < 1.0:
raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not 0.0 <= alpha:
raise ValueError("Invalid alpha value: {}".format(alpha))
        self.normalize_alpha = normalize_alpha  # stored on the optimizer (not in the per-group defaults), so it applies to all param groups
defaults = dict(lr=lr, betas=betas, eps=eps, alpha=alpha, beta3_warmup=beta3_warmup,
alpha_warmup=alpha_warmup, weight_decay=weight_decay)
super(AdEMAMix, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdEMAMix, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
lr = group["lr"]
lmbda = group["weight_decay"]
eps = group["eps"]
beta1, beta2, beta3_final = group["betas"]
beta3_warmup = group["beta3_warmup"]
alpha_final = group["alpha"]
alpha_warmup = group["alpha_warmup"]
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError('AdEMAMix does not support sparse gradients.')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
if beta1 != 0.0: # save memory in case beta1 is 0.0
state['exp_avg_fast'] = torch.zeros_like(p, memory_format=torch.preserve_format)
else:
state['exp_avg_fast'] = None
state['exp_avg_slow'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avg_fast, exp_avg_slow, exp_avg_sq = state['exp_avg_fast'], state['exp_avg_slow'], state['exp_avg_sq']
state['step'] += 1
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
# Compute the effective alpha and beta3 in case warmup is used
if alpha_warmup is not None:
alpha = linear_warmup_scheduler(state["step"], alpha_end=alpha_final, alpha_start=0, warmup=alpha_warmup)
else:
alpha = alpha_final
if beta3_warmup is not None:
beta3 = linear_hl_warmup_scheduler(state["step"], beta_end=beta3_final, beta_start=beta1, warmup=beta3_warmup)
else:
beta3 = beta3_final
# Decay the first and second moment running average coefficient
if beta1 != 0.0:
exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1)
else:
exp_avg_fast = grad
exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                if self.normalize_alpha:
                    # Rescale the denominator so the mixed update keeps a magnitude comparable
                    # to Adam's as alpha grows.
                    denom = denom * (1.0 + alpha)
update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom
                # Decoupled (AdamW-style) weight decay, folded into the update.
                update.add_(p, alpha=lmbda)
                p.add_(update, alpha=-lr)
return loss
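# Note on optimizer state: compared to Adam, AdEMAMix as implemented above keeps one extra
# per-parameter buffer (exp_avg_slow) on top of exp_avg_fast and exp_avg_sq. Setting beta1 = 0.0
# drops exp_avg_fast and falls back to using the raw gradient as the fast term, which brings the
# state back down to two buffers per parameter.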
if __name__ == "__main__": # small dummy test
x = torch.randn((10,7))
model = torch.nn.Linear(7, 1, bias=False)
opt = AdEMAMix(params=model.parameters(), lr=1e-2, betas=(0.9, 0.999, 0.9999), alpha=2.0, beta3_warmup=45, alpha_warmup=45, weight_decay=0.1)
print(model.weight)
for itr in range(50):
y = model(x).mean()
opt.zero_grad()
y.backward()
opt.step()
print(model.weight)