"""
Adapted from: https://pytorch.org/docs/1.6.0/_modules/torch/optim/adam.html
"""
import math

import torch
from torch.optim import Optimizer


def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1):
    """Linearly interpolate from alpha_start to alpha_end over the first `warmup` steps."""
    if step < warmup:
        a = step / float(warmup)
        return (1.0 - a) * alpha_start + a * alpha_end
    return alpha_end


def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
    """Warm up beta by interpolating linearly in half-life space rather than in beta space."""

    def f(beta, eps=1e-8):
        # Half-life of an EMA with coefficient beta: the t solving beta ** (t + 1) == 0.5.
        return math.log(0.5) / math.log(beta + eps) - 1

    def f_inv(t):
        # Inverse of f: the beta whose half-life is t.
        return math.pow(0.5, 1 / (t + 1))

    if step < warmup:
        a = step / float(warmup)
        return f_inv((1.0 - a) * f(beta_start) + a * f(beta_end))
    return beta_end
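

# Worked example for the half-life mapping above (numbers are approximate and
# given only for intuition): f(0.9) ~ 5.6 and f(0.999) ~ 692, so warming beta3
# up from 0.9 to 0.999 with linear_hl_warmup_scheduler grows the EMA half-life
# roughly linearly from ~5.6 to ~692 steps, whereas a linear ramp on beta
# itself would keep the half-life small until the very end of the warmup.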


class AdEMAMix(Optimizer):
    r"""Implements the AdEMAMix algorithm.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for computing
            running averages of the gradient and its square, corresponding to
            beta_1, beta_2, beta_3 in AdEMAMix (default: (0.9, 0.95, 0.999))
        alpha (float): AdEMAMix alpha coefficient mixing the slow and fast EMAs (default: 8.0)
        beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None)
        alpha_warmup (int, optional): number of warmup steps used to increase alpha (default: None)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        normalize_alpha (bool, optional): if True, scale the denominator by (1 + alpha) (default: False)
        weight_decay (float, optional): weight decay as in AdamW (default: 0)
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.95, 0.999), alpha=8.0,
                 beta3_warmup=None, alpha_warmup=None, eps=1e-8, normalize_alpha=False,
                 weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))
        self.normalize_alpha = normalize_alpha
        defaults = dict(lr=lr, betas=betas, eps=eps, alpha=alpha, beta3_warmup=beta3_warmup,
                        alpha_warmup=alpha_warmup, weight_decay=weight_decay)
        super(AdEMAMix, self).__init__(params, defaults)
    def __setstate__(self, state):
        super(AdEMAMix, self).__setstate__(state)
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            lr = group["lr"]
            lmbda = group["weight_decay"]
            eps = group["eps"]
            beta1, beta2, beta3_final = group["betas"]
            beta3_warmup = group["beta3_warmup"]
            alpha_final = group["alpha"]
            alpha_warmup = group["alpha_warmup"]

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('AdEMAMix does not support sparse gradients.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    if beta1 != 0.0:  # save memory in case beta1 is 0.0
                        state['exp_avg_fast'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    else:
                        state['exp_avg_fast'] = None
                    state['exp_avg_slow'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg_fast, exp_avg_slow, exp_avg_sq = state['exp_avg_fast'], state['exp_avg_slow'], state['exp_avg_sq']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Compute the effective alpha and beta3 in case warmup is used
                if alpha_warmup is not None:
                    alpha = linear_warmup_scheduler(state["step"], alpha_end=alpha_final, alpha_start=0, warmup=alpha_warmup)
                else:
                    alpha = alpha_final

                if beta3_warmup is not None:
                    beta3 = linear_hl_warmup_scheduler(state["step"], beta_end=beta3_final, beta_start=beta1, warmup=beta3_warmup)
                else:
                    beta3 = beta3_final

                # Decay the first and second moment running average coefficients
                if beta1 != 0.0:
                    exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1)
                else:
                    exp_avg_fast = grad
                exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                if self.normalize_alpha:
                    denom = denom * (1.0 + alpha)

                update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom

                # Decoupled (AdamW-style) weight decay
                update.add_(p, alpha=lmbda)

                p.add_(-lr * update)

        return loss
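

# Sketch of the update computed by AdEMAMix.step above, for a parameter theta
# with gradient g_t at step t (weight decay lambda, learning rate lr):
#
#   m1_t = beta1 * m1_{t-1} + (1 - beta1) * g_t        # fast EMA
#   m2_t = beta3 * m2_{t-1} + (1 - beta3) * g_t        # slow EMA
#   nu_t = beta2 * nu_{t-1} + (1 - beta2) * g_t ** 2   # second moment
#
#   theta_t = theta_{t-1} - lr * ( (m1_t / (1 - beta1 ** t) + alpha * m2_t)
#                                  / (sqrt(nu_t / (1 - beta2 ** t)) + eps)
#                                  + lambda * theta_{t-1} )
#
# With normalize_alpha=True the denominator is additionally multiplied by (1 + alpha).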


if __name__ == "__main__":  # small dummy test
    x = torch.randn((10, 7))
    model = torch.nn.Linear(7, 1, bias=False)
    opt = AdEMAMix(params=model.parameters(), lr=1e-2, betas=(0.9, 0.999, 0.9999),
                   alpha=2.0, beta3_warmup=45, alpha_warmup=45, weight_decay=0.1)
    print(model.weight)
    for itr in range(50):
        y = model(x).mean()
        opt.zero_grad()
        y.backward()
        opt.step()
    print(model.weight)
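
    # A second, purely illustrative configuration (these hyperparameter values
    # are arbitrary, not recommendations): betas[0] = 0.0 exercises the
    # memory-saving path that allocates no fast-EMA buffer, and
    # normalize_alpha=True rescales the denominator by (1 + alpha).
    model2 = torch.nn.Linear(7, 1, bias=False)
    opt2 = AdEMAMix(params=model2.parameters(), lr=1e-2, betas=(0.0, 0.999, 0.9999),
                    alpha=2.0, normalize_alpha=True, weight_decay=0.0)
    for itr in range(50):
        y2 = model2(x).mean()
        opt2.zero_grad()
        y2.backward()
        opt2.step()
    print(model2.weight)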