import math
import torch
from torch import Tensor
from .optimizer import Optimizer
from typing import List, Optional

__all__ = ['AdamW', 'adamw']

class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
                \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
                \: \epsilon \text{ (epsilon)} \\
            &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
                \: \textit{maximize} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)},
                v_0 \leftarrow 0 \text{ (second moment)},
                \: \widehat{v_0}^{max} \leftarrow 0 \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g^2_t \\
            &\hspace{5mm}\widehat{m_t} \leftarrow m_t / \big(1 - \beta_1^t \big) \\
            &\hspace{5mm}\widehat{v_t} \leftarrow v_t / \big(1 - \beta_2^t \big) \\
            &\hspace{5mm}\textbf{if} \: amsgrad \\
            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow
                \mathrm{max}(\widehat{v_t}^{max}, \widehat{v_t}) \\
            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} /
                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} /
                \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of the gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        maximize (bool, optional): maximize the params based on the objective,
            instead of minimizing (default: False)
        foreach (bool, optional): whether the foreach implementation of the
            optimizer is used (default: None)
        capturable (bool, optional): whether this instance is safe to capture in a
            CUDA graph. Passing True can impair ungraphed performance, so if you
            don't intend to graph capture this instance, leave it False
            (default: False)
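
    A minimal usage sketch follows; ``model``, ``input``, ``target`` and
    ``loss_fn`` are placeholder names, not part of this module.

    Example::
        >>> optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()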

    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False, *, maximize: bool = False,
                 foreach: Optional[bool] = None,
                 capturable: bool = False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        foreach=foreach, maximize=maximize, capturable=capturable)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
            group.setdefault('capturable', False)
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
        if not step_is_tensor:
            for s in state_values:
                s['step'] = torch.tensor(float(s['step']))

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
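
        A minimal sketch of the closure pattern follows; ``model``, ``input``,
        ``target`` and ``loss_fn`` are placeholder names, not part of this module.

        Example::
            >>> def closure():
            ...     optimizer.zero_grad()
            ...     loss = loss_fn(model(input), target)
            ...     loss.backward()
            ...     return loss
            >>> optimizer.step(closure)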
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
                        if self.defaults['capturable'] else torch.tensor(0.)
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                state_steps.append(state['step'])

            adamw(params_with_grad,
                  grads,
                  exp_avgs,
                  exp_avg_sqs,
                  max_exp_avg_sqs,
                  state_steps,
                  amsgrad=amsgrad,
                  beta1=beta1,
                  beta2=beta2,
                  lr=group['lr'],
                  weight_decay=group['weight_decay'],
                  eps=group['eps'],
                  maximize=group['maximize'],
                  foreach=group['foreach'],
                  capturable=group['capturable'])

        return loss


def adamw(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          max_exp_avg_sqs: List[Tensor],
          state_steps: List[Tensor],
          # foreach and capturable stay positional with defaults (not keyword-only)
          # because TorchScript-compiled callers cannot handle kw-only args with defaults
          foreach: Optional[bool] = None,
          capturable: bool = False,
          *,
          amsgrad: bool,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float,
          maximize: bool):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.
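
    A minimal sketch of a direct call follows (all tensor names are placeholders);
    in normal use :meth:`AdamW.step` assembles these lists and calls this function.

    Example::
        >>> param, grad = torch.zeros(3), torch.ones(3)
        >>> exp_avg, exp_avg_sq = torch.zeros(3), torch.zeros(3)
        >>> adamw([param], [grad], [exp_avg], [exp_avg_sq], [], [torch.tensor(0.)],
        ...       amsgrad=False, beta1=0.9, beta2=0.999, lr=1e-3,
        ...       weight_decay=1e-2, eps=1e-8, maximize=False)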
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    if foreach is None:
        # Default to the single-tensor implementation when the caller does not choose
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamw
    else:
        func = _single_tensor_adamw

    func(params,
         grads,
         exp_avgs,
         exp_avg_sqs,
         max_exp_avg_sqs,
         state_steps,
         amsgrad=amsgrad,
         beta1=beta1,
         beta2=beta2,
         lr=lr,
         weight_decay=weight_decay,
         eps=eps,
         maximize=maximize,
         capturable=capturable)


def _single_tensor_adamw(params: List[Tensor],
                         grads: List[Tensor],
                         exp_avgs: List[Tensor],
                         exp_avg_sqs: List[Tensor],
                         max_exp_avg_sqs: List[Tensor],
                         state_steps: List[Tensor],
                         *,
                         amsgrad: bool,
                         beta1: float,
                         beta2: float,
                         lr: float,
                         weight_decay: float,
                         eps: float,
                         maximize: bool,
                         capturable: bool):

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        if capturable:
            assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors."

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            param = torch.view_as_real(param)

        # update step
        step_t += 1

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if capturable:
            step = step_t

            # Keep the bias corrections as tensors so the computation stays on-device
            # and remains valid under CUDA graph capture
            bias_correction1 = 1 - torch.pow(beta1, step)
            bias_correction2 = 1 - torch.pow(beta2, step)

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

                # Uses the max. for normalizing running avg. of gradient.
                # The tensor-valued step size is folded into the denominator because
                # addcdiv_ below only accepts a Python number for `value`.
                denom = (max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)
            else:
                denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)

            param.addcdiv_(exp_avg, denom)
        else:
            step = step_t.item()

            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = math.sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)


def _multi_tensor_adamw(params: List[Tensor],
                        grads: List[Tensor],
                        exp_avgs: List[Tensor],
                        exp_avg_sqs: List[Tensor],
                        max_exp_avg_sqs: List[Tensor],
                        state_steps: List[Tensor],
                        *,
                        amsgrad: bool,
                        beta1: float,
                        beta2: float,
                        lr: float,
                        weight_decay: float,
                        eps: float,
                        maximize: bool,
                        capturable: bool):
    if len(params) == 0:
        return

    if capturable:
        assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
            "If capturable=True, params and state_steps must be CUDA tensors."

    if maximize:
        grads = torch._foreach_neg(tuple(grads))

    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
    exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
    exp_avg_sqs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avg_sqs]
    params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]

    # update steps
    torch._foreach_add_(state_steps, 1)

    # Perform stepweight decay
    torch._foreach_mul_(params, 1 - lr * weight_decay)

    # Decay the first and second moment running average coefficient
    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)

    torch._foreach_mul_(exp_avg_sqs, beta2)
    torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)

    if capturable:
        bias_correction1 = [torch.pow(beta1, step) for step in state_steps]
        bias_correction2 = [torch.pow(beta2, step) for step in state_steps]

        # _foreach_sub doesn't allow a scalar as the first arg, so compute
        # 1 - beta**step as -(beta**step - 1)
        torch._foreach_sub_(bias_correction1, 1)
        torch._foreach_sub_(bias_correction2, 1)
        torch._foreach_neg_(bias_correction1)
        torch._foreach_neg_(bias_correction2)

        # _foreach_div doesn't allow a scalar as the first arg, so compute
        # -lr / bias_correction1 via a reciprocal and a negation
        step_size = torch._foreach_div(bias_correction1, lr)
        torch._foreach_reciprocal_(step_size)
        torch._foreach_neg_(step_size)

        bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)

        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)

            # Use the max. for normalizing running avg. of gradient.
            # The tensor-valued step size is folded into the denominator because
            # _foreach_addcdiv_ below cannot take tensors as the scalar values.
            max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
            torch._foreach_div_(max_exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size))
            eps_over_step_size = torch._foreach_div(step_size, eps)
            torch._foreach_reciprocal_(eps_over_step_size)
            denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
        else:
            exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
            torch._foreach_div_(exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size))
            eps_over_step_size = torch._foreach_div(step_size, eps)
            torch._foreach_reciprocal_(eps_over_step_size)
            denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)

        torch._foreach_addcdiv_(params, exp_avgs, denom)
    else:
        bias_correction1 = [1 - beta1 ** step.item() for step in state_steps]
        bias_correction2 = [1 - beta2 ** step.item() for step in state_steps]

        step_size = [(lr / bc) * -1 for bc in bias_correction1]

        bias_correction2_sqrt = [math.sqrt(bc) for bc in bias_correction2]

        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)

            # Use the max. for normalizing running avg. of gradient
            max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
            torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
            denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
        else:
            exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

        torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)