import torch
from torch.optim import Optimizer


class Nadam(Optimizer):
    """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).

    It has been proposed in `Incorporating Nesterov Momentum into Adam`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        schedule_decay (float, optional): momentum schedule decay (default: 4e-3)

    __ http://cs229.stanford.edu/proj2015/054_report.pdf
    __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf

    Originally taken from: https://github.com/pytorch/pytorch/pull/1408
    NOTE: Has potential issues but does work well on some problems.
    """

    def __init__(
        self,
        params,
        lr=2e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        schedule_decay=4e-3,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            schedule_decay=schedule_decay,
        )
        super(Nadam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # Per-parameter state initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["m_schedule"] = 1.0
                    # Exponential moving averages of the gradient and its square
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                m_schedule = state["m_schedule"]
                schedule_decay = group["schedule_decay"]
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]
                eps = group["eps"]
                state["step"] += 1
                t = state["step"]

                if group["weight_decay"] != 0:
                    grad = grad.add(p.data, alpha=group["weight_decay"])

                # Warming momentum schedule: Nesterov momentum coefficients for
                # the current step t and the next step t + 1
                momentum_cache_t = beta1 * (1.0 - 0.5 * (0.96 ** (t * schedule_decay)))
                momentum_cache_t_1 = beta1 * (
                    1.0 - 0.5 * (0.96 ** ((t + 1) * schedule_decay))
                )
                m_schedule_new = m_schedule * momentum_cache_t
                m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
                state["m_schedule"] = m_schedule_new

                # Decay the first and second moment running averages, then
                # bias-correct the second moment
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                exp_avg_sq_prime = exp_avg_sq / (1.0 - beta2 ** t)
                denom = exp_avg_sq_prime.sqrt_().add_(eps)

                # Nesterov-style update: combine the current gradient and the
                # first moment estimate, each with its own bias correction
                p.data.addcdiv_(
                    grad,
                    denom,
                    value=-group["lr"] * (1.0 - momentum_cache_t) / (1.0 - m_schedule_new),
                )
                p.data.addcdiv_(
                    exp_avg,
                    denom,
                    value=-group["lr"] * momentum_cache_t_1 / (1.0 - m_schedule_next),
                )

        return loss