|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
from typing import List |
|
|
|
|
|
import torch |
|
|
from torch import Tensor |
|
|
from torch.optim.optimizer import Optimizer |
|
|
|
|
|
|
|
|
class MultiTensorApply(object):
    """Callable wrapper that prepends a fixed chunk size to a multi-tensor op call.

    Attributes:
        available: class-level flag, set to ``True`` whenever an instance is
            constructed.
        warned: class-level flag; nothing in this class reads or writes it.
    """

    available = False
    warned = False

    def __init__(self, chunk_size):
        """Remember ``chunk_size`` and mark the applier as available.

        Args:
            chunk_size: value forwarded as the first positional argument of
                every op invoked through this object.
        """
        # The previous version wrapped these two assignments in
        # ``try/except ImportError``; neither statement performs an import,
        # so that handler could never run and has been dropped.
        MultiTensorApply.available = True
        self.chunk_size = chunk_size

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        """Invoke ``op(chunk_size, noop_flag_buffer, tensor_lists, *args)``.

        Args:
            op: callable to run.
            noop_flag_buffer: passed through to ``op`` unchanged.
            tensor_lists: passed through to ``op`` unchanged.
            *args: extra positional arguments appended to the call.

        Returns:
            Whatever ``op`` returns.
        """
        return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
|
|
|
|
|
|
|
|
class Adan(Optimizer):
    """
    Implements a pytorch variant of Adan.

    Adan was proposed in
    Adan: Adaptive Nesterov Momentum Algorithm for
    Faster Optimizing Deep Models[J].arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677

    Arguments:
        params (iterable): iterable of parameters to optimize or
            dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): decay coefficients,
            each in [0, 1), for the running averages of the gradient, of
            the gradient difference and of the squared term, in that order.
            (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay
            (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip
            global grad norm (default: 0.0 no clip)
        no_prox (bool): how to perform the decoupled weight decay.
            If True the parameter is multiplied by ``1 - lr * weight_decay``
            before the update; otherwise it is divided by
            ``1 + lr * weight_decay`` after the update. (default: False)
        foreach (bool): if True would use torch._foreach implementation.
            It's faster but uses slightly more memory. (default: True)
        fused (bool, optional): whether fused implementation is used.
            Requires CUDA; ``step`` raises ``ValueError`` otherwise.
            (default: False)
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.98, 0.92, 0.99),
        eps=1e-8,
        weight_decay=0.0,
        max_grad_norm=0.0,
        no_prox=False,
        foreach: bool = True,
        fused: bool = False,
    ):
        """Validate the hyper-parameters and register them as group defaults.

        Raises:
            ValueError: if ``max_grad_norm``, ``lr`` or ``eps`` is negative,
                or if any of the three betas lies outside [0, 1).
        """
        if not 0.0 <= max_grad_norm:
            raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm))
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= eps:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        # Exactly three coefficients are checked, in index order.
        for index in range(3):
            if not 0.0 <= betas[index] < 1.0:
                raise ValueError(
                    'Invalid beta parameter at index {}: {}'.format(index, betas[index])
                )
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            max_grad_norm=max_grad_norm,
            no_prox=no_prox,
            foreach=foreach,
            fused=fused,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state and back-fill options missing from groups.

        ``step`` indexes ``group['no_prox']``, ``group['foreach']`` and
        ``group['fused']`` unconditionally, so a restored group that lacks any
        of them would otherwise fail with ``KeyError``.
        """
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('no_prox', False)
            # Prefer the value this optimizer was configured with; fall back
            # to the constructor defaults when ``defaults`` lacks the key too.
            group.setdefault('foreach', self.defaults.get('foreach', True))
            group.setdefault('fused', self.defaults.get('fused', False))

    @torch.no_grad()
    def restart_opt(self):
        """Reset the step counter and zero the three running averages.

        ``neg_pre_grad`` is not touched here; ``step`` re-creates it whenever
        a group's step counter equals 1, which is the case on the first call
        after this reset.
        """
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.requires_grad:
                    state = self.state[p]
                    # Exponential moving average of gradient values.
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of the squared term.
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    # Exponential moving average of gradient differences.
                    state['exp_avg_diff'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): re-evaluates the model under
                ``torch.enable_grad()`` and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            ValueError: if a group requests the fused kernels while CUDA is
                unavailable.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Clipping is configured from the optimizer-wide defaults rather than
        # per parameter group.
        if self.defaults['max_grad_norm'] > 0:
            device = self.param_groups[0]['params'][0].device
            global_grad_norm = torch.zeros(1, device=device)
            max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        global_grad_norm.add_(p.grad.pow(2).sum())
            global_grad_norm = torch.sqrt(global_grad_norm)
            # The stabiliser comes from the last parameter group. This used to
            # be read through the loop variable left over from the ``for``
            # above; the lookup is now explicit and yields the same value.
            norm_eps = self.param_groups[-1]['eps']
            clip_global_grad_norm = torch.clamp(
                max_grad_norm / (global_grad_norm + norm_eps), max=1.0
            ).item()
        else:
            clip_global_grad_norm = 1.0

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            exp_avg_diffs = []
            neg_pre_grads = []

            beta1, beta2, beta3 = group['betas']

            # The counter lives on the group (not per parameter) and advances
            # even when no parameter in the group has a gradient.
            group['step'] = group.get('step', 0) + 1

            bias_correction1 = 1.0 - beta1 ** group['step']
            bias_correction2 = 1.0 - beta2 ** group['step']
            bias_correction3 = 1.0 - beta3 ** group['step']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg_diff'] = torch.zeros_like(p)

                # Seed the "previous gradient" with the current (clipped)
                # gradient so the first gradient difference is zero.
                if 'neg_pre_grad' not in state or group['step'] == 1:
                    state['neg_pre_grad'] = p.grad.clone().mul_(-clip_global_grad_norm)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                exp_avg_diffs.append(state['exp_avg_diff'])
                neg_pre_grads.append(state['neg_pre_grad'])

            kwargs = dict(
                params=params_with_grad,
                grads=grads,
                exp_avgs=exp_avgs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avg_diffs=exp_avg_diffs,
                neg_pre_grads=neg_pre_grads,
                beta1=beta1,
                beta2=beta2,
                beta3=beta3,
                bias_correction1=bias_correction1,
                bias_correction2=bias_correction2,
                bias_correction3_sqrt=math.sqrt(bias_correction3),
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                eps=group['eps'],
                no_prox=group['no_prox'],
                clip_global_grad_norm=clip_global_grad_norm,
            )

            # Pick one of the four kernels from the (fused, foreach) flags.
            if group['fused']:
                if not torch.cuda.is_available():
                    raise ValueError('Fused Adan does not support CPU')
                if group['foreach']:
                    update = _fused_adan_multi_tensor
                else:
                    update = _fused_adan_single_tensor
            elif group['foreach']:
                update = _multi_tensor_adan
            else:
                update = _single_tensor_adan
            update(**kwargs)

        return loss
|
|
|
|
|
|
|
|
def _single_tensor_adan( |
|
|
params: List[Tensor], |
|
|
grads: List[Tensor], |
|
|
exp_avgs: List[Tensor], |
|
|
exp_avg_sqs: List[Tensor], |
|
|
exp_avg_diffs: List[Tensor], |
|
|
neg_pre_grads: List[Tensor], |
|
|
*, |
|
|
beta1: float, |
|
|
beta2: float, |
|
|
beta3: float, |
|
|
bias_correction1: float, |
|
|
bias_correction2: float, |
|
|
bias_correction3_sqrt: float, |
|
|
lr: float, |
|
|
weight_decay: float, |
|
|
eps: float, |
|
|
no_prox: bool, |
|
|
clip_global_grad_norm: Tensor, |
|
|
): |
|
|
for i, param in enumerate(params): |
|
|
grad = grads[i] |
|
|
exp_avg = exp_avgs[i] |
|
|
exp_avg_sq = exp_avg_sqs[i] |
|
|
exp_avg_diff = exp_avg_diffs[i] |
|
|
neg_grad_or_diff = neg_pre_grads[i] |
|
|
|
|
|
grad.mul_(clip_global_grad_norm) |
|
|
|
|
|
|
|
|
|
|
|
neg_grad_or_diff.add_(grad) |
|
|
|
|
|
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) |
|
|
exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff, alpha=1 - beta2) |
|
|
|
|
|
neg_grad_or_diff.mul_(beta2).add_(grad) |
|
|
exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff, neg_grad_or_diff, value=1 - beta3) |
|
|
|
|
|
denom = ((exp_avg_sq).sqrt() / bias_correction3_sqrt).add_(eps) |
|
|
step_size_diff = lr * beta2 / bias_correction2 |
|
|
step_size = lr / bias_correction1 |
|
|
|
|
|
if no_prox: |
|
|
param.mul_(1 - lr * weight_decay) |
|
|
param.addcdiv_(exp_avg, denom, value=-step_size) |
|
|
param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff) |
|
|
else: |
|
|
param.addcdiv_(exp_avg, denom, value=-step_size) |
|
|
param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff) |
|
|
param.div_(1 + lr * weight_decay) |
|
|
|
|
|
neg_grad_or_diff.zero_().add_(grad, alpha=-1.0) |
|
|
|
|
|
|
|
|
def _multi_tensor_adan( |
|
|
params: List[Tensor], |
|
|
grads: List[Tensor], |
|
|
exp_avgs: List[Tensor], |
|
|
exp_avg_sqs: List[Tensor], |
|
|
exp_avg_diffs: List[Tensor], |
|
|
neg_pre_grads: List[Tensor], |
|
|
*, |
|
|
beta1: float, |
|
|
beta2: float, |
|
|
beta3: float, |
|
|
bias_correction1: float, |
|
|
bias_correction2: float, |
|
|
bias_correction3_sqrt: float, |
|
|
lr: float, |
|
|
weight_decay: float, |
|
|
eps: float, |
|
|
no_prox: bool, |
|
|
clip_global_grad_norm: Tensor, |
|
|
): |
|
|
if len(params) == 0: |
|
|
return |
|
|
|
|
|
torch._foreach_mul_(grads, clip_global_grad_norm) |
|
|
|
|
|
|
|
|
|
|
|
torch._foreach_add_(neg_pre_grads, grads) |
|
|
|
|
|
torch._foreach_mul_(exp_avgs, beta1) |
|
|
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) |
|
|
|
|
|
torch._foreach_mul_(exp_avg_diffs, beta2) |
|
|
torch._foreach_add_(exp_avg_diffs, neg_pre_grads, alpha=1 - beta2) |
|
|
|
|
|
torch._foreach_mul_(neg_pre_grads, beta2) |
|
|
torch._foreach_add_(neg_pre_grads, grads) |
|
|
torch._foreach_mul_(exp_avg_sqs, beta3) |
|
|
torch._foreach_addcmul_(exp_avg_sqs, neg_pre_grads, neg_pre_grads, value=1 - beta3) |
|
|
|
|
|
denom = torch._foreach_sqrt(exp_avg_sqs) |
|
|
torch._foreach_div_(denom, bias_correction3_sqrt) |
|
|
torch._foreach_add_(denom, eps) |
|
|
|
|
|
step_size_diff = lr * beta2 / bias_correction2 |
|
|
step_size = lr / bias_correction1 |
|
|
|
|
|
if no_prox: |
|
|
torch._foreach_mul_(params, 1 - lr * weight_decay) |
|
|
torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size) |
|
|
torch._foreach_addcdiv_(params, exp_avg_diffs, denom, value=-step_size_diff) |
|
|
else: |
|
|
torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size) |
|
|
torch._foreach_addcdiv_(params, exp_avg_diffs, denom, value=-step_size_diff) |
|
|
torch._foreach_div_(params, 1 + lr * weight_decay) |
|
|
torch._foreach_zero_(neg_pre_grads) |
|
|
torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0) |
|
|
|
|
|
|
|
|
def _fused_adan_multi_tensor(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    exp_avg_diffs: List[Tensor],
    neg_pre_grads: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    beta3: float,
    bias_correction1: float,
    bias_correction2: float,
    bias_correction3_sqrt: float,
    lr: float,
    weight_decay: float,
    eps: float,
    no_prox: bool,
    clip_global_grad_norm: Tensor,
):
    """Run one Adan update through ``fused_adan.adan_multi_tensor``.

    The six tensor lists are handed to the extension as a single
    list-of-lists, together with a one-element CUDA int buffer and the scalar
    hyper-parameters in the order of this function's keyword arguments.
    Afterwards ``neg_pre_grads`` is overwritten with the negated ``grads``.

    NOTE(review): what the kernel itself writes (presumably params and the
    running averages, in place) is defined by the ``fused_adan`` extension,
    not by this file — confirm there.

    Args mirror the non-fused multi-tensor implementation.
    """
    import fused_adan

    tensor_lists = [params, grads, exp_avgs, exp_avg_sqs, exp_avg_diffs, neg_pre_grads]
    scalar_args = (
        beta1,
        beta2,
        beta3,
        bias_correction1,
        bias_correction2,
        bias_correction3_sqrt,
        lr,
        weight_decay,
        eps,
        no_prox,
        clip_global_grad_norm,
    )

    # Chunk size handed to the kernel: 2048 * 32 elements.
    applier = MultiTensorApply(2048 * 32)
    # Zero-initialised flag buffer required by the applier's call signature.
    overflow_flag = torch.cuda.IntTensor([0])
    applier(fused_adan.adan_multi_tensor, overflow_flag, tensor_lists, *scalar_args)

    # Leave -g_t behind for the next step's gradient difference.
    torch._foreach_zero_(neg_pre_grads)
    torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0)
|
|
|
|
|
|
|
|
def _fused_adan_single_tensor(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    exp_avg_diffs: List[Tensor],
    neg_pre_grads: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    beta3: float,
    bias_correction1: float,
    bias_correction2: float,
    bias_correction3_sqrt: float,
    lr: float,
    weight_decay: float,
    eps: float,
    no_prox: bool,
    clip_global_grad_norm: Tensor,
):
    """Run one Adan update per parameter through ``fused_adan.adan_single_tensor``.

    For each parameter the extension receives a ``float()`` view/copy of the
    parameter data, the parameter data itself, the gradient, the three
    running averages and the negated previous gradient, followed by the
    scalar hyper-parameters in the order of this function's keyword
    arguments. The call is made with the parameter's CUDA device selected.
    Afterwards the ``neg_pre_grads`` entry is overwritten with the negated
    gradient.

    NOTE(review): which of those tensors the kernel writes to is defined by
    the ``fused_adan`` extension, not by this file — confirm there.

    Args mirror the non-fused single-tensor implementation.
    """
    scalar_args = (
        beta1,
        beta2,
        beta3,
        bias_correction1,
        bias_correction2,
        bias_correction3_sqrt,
        lr,
        weight_decay,
        eps,
        no_prox,
        clip_global_grad_norm,
    )
    for idx, param in enumerate(params):
        weights_fp32 = param.data.float()
        weights_out = param.data
        g = grads[idx]
        first_moment = exp_avgs[idx]
        second_moment = exp_avg_sqs[idx]
        diff_moment = exp_avg_diffs[idx]
        prev_neg_grad = neg_pre_grads[idx]

        with torch.cuda.device(param.device):
            # Imported lazily so the extension is only required when a
            # parameter is actually processed by this path.
            import fused_adan

            fused_adan.adan_single_tensor(
                weights_fp32,
                weights_out,
                g,
                first_moment,
                second_moment,
                diff_moment,
                prev_neg_grad,
                *scalar_args,
            )

        # Leave -g_t behind for the next step's gradient difference.
        prev_neg_grad.zero_().add_(g, alpha=-1.0)
|
|
|