| """ PyTorch MADGRAD optimizer |
| |
| MADGRAD: https://arxiv.org/abs/2101.11075 |
| |
| Code from: https://github.com/facebookresearch/madgrad |
| """ |
| |
| |
| |
| |
|
|
import math
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
import torch.optim

if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any


class MADGRAD(torch.optim.Optimizer):
    """
    MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
    Optimization.

    .. _MADGRAD: https://arxiv.org/abs/2101.11075

    MADGRAD is a general purpose optimizer that can be used in place of SGD or
    Adam, and may converge faster and generalize better. Currently GPU-only.
    Typically, the same learning rate schedule that is used for SGD or Adam may
    be used. The overall learning rate is not comparable to either method and
    should be determined by a hyper-parameter sweep.

    MADGRAD requires less weight decay than other methods, often as little as
    zero. Momentum values used for SGD or Adam's beta1 should work here also.

    On sparse problems both weight_decay and momentum should be set to 0.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-2).
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float):
            Weight decay, i.e. an L2 penalty (default: 0).
        eps (float):
            Term added to the denominator outside of the root operation to improve numerical stability (default: 1e-6).
        decoupled_decay (bool):
            Apply weight decay directly to the weights (decoupled, AdamW-style) instead of adding it to the gradient (default: False).
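
    Example:
        A minimal usage sketch; ``model``, ``data_loader`` and ``loss_fn`` below are
        placeholders for the caller's own module, data and loss, not part of this file::

            optimizer = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9)
            for input, target in data_loader:
                optimizer.zero_grad()
                loss = loss_fn(model(input), target)
                loss.backward()
                optimizer.step()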
| """ |
|
|
    def __init__(
            self,
            params: _params_t,
            lr: float = 1e-2,
            momentum: float = 0.9,
            weight_decay: float = 0,
            eps: float = 1e-6,
            decoupled_decay: bool = False,
    ):
        if momentum < 0 or momentum >= 1:
            raise ValueError(f"Momentum {momentum} must be in the range [0, 1)")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError(f"Eps {eps} must be non-negative")

        defaults = dict(
            lr=lr,
            eps=eps,
            momentum=momentum,
            weight_decay=weight_decay,
            decoupled_decay=decoupled_decay,
        )
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    @property
    def supports_flat_params(self) -> bool:
        return True

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            eps = group['eps']
            lr = group['lr'] + eps
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            ck = 1 - momentum

| for p in group["params"]: |
| if p.grad is None: |
| continue |
| grad = p.grad |
| if momentum != 0.0 and grad.is_sparse: |
| raise RuntimeError("momentum != 0 is not compatible with sparse gradients") |
|
|
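                # Lazily initialize per-parameter state on the first step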
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['grad_sum_sq'] = torch.zeros_like(p)
                    state['s'] = torch.zeros_like(p)
                    if momentum != 0:
                        state['x0'] = torch.clone(p).detach()

                state['step'] += 1
                grad_sum_sq = state['grad_sum_sq']
                s = state['s']
                lamb = lr * math.sqrt(state['step'])

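                # Apply weight decay: decoupled (applied directly to p) or coupled (added to the gradient)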
                if weight_decay != 0:
                    if group['decoupled_decay']:
                        p.mul_(1.0 - group['lr'] * weight_decay)
                    else:
                        if grad.is_sparse:
                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                        grad.add_(p, alpha=weight_decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

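                    # Reconstruct x_0 for the masked entries from p, s and the accumulated second moments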
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)

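                    # Accumulate the lamb-weighted squared gradient into the dense state and its masked view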
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

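                    # Dual-averaging step on the masked entries: x_0 - s / rms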
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
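                    # Write the updated masked values back into the dense p via a sparse add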
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
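                        # Reconstruct x_0 from p, s and the accumulated second moments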
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.addcdiv(s, rms, value=1)
                    else:
                        x0 = state['x0']

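                    # Accumulate the lamb-weighted squared gradient (second moments)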
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

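                    # Update the dual-averaging gradient sum s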
                    s.add_(grad, alpha=lamb)

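                    # Take the step: the dual-averaging iterate is x_0 - s / rms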
                    if momentum == 0:
                        p.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

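                        # p is an exponential moving average of z with weight ck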
                        p.mul_(1 - ck).add_(z, alpha=ck)

        return loss