import torch
import math
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from typing import Union, Iterable, Tuple, Callable, List
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy
import random

inf = math.inf


def calculate_grad_norm(model: torch.nn.Module, norm_type=2) -> float:
    """
    Overview:
        Calculate the gradient norm over all parameters of the model whose gradients are not None.
    Arguments:
        - model (:obj:`torch.nn.Module`): the model whose gradient norm is calculated
        - norm_type (:obj:`int` or ``'inf'``): the order of the norm, ``'inf'`` for the infinity norm
    """
    parameters = list(filter(lambda p: p.grad is not None, model.parameters()))
    if len(parameters) == 0:
        return 0
    if norm_type == 'inf':
        total_norm = max(p.grad.data.abs().max() for p in parameters)
        return float(total_norm)
    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
        return float(total_norm)


def calculate_grad_norm_without_bias_two_norm(model: torch.nn.Module) -> float:
    """
    Overview:
        Calculate the gradient 2-norm over all non-bias parameters of the model; returns 0 if any such
        parameter has no gradient yet.
    Arguments:
        - model (:obj:`torch.nn.Module`): the model whose gradient norm is calculated
    """
    _list = []
    for name, param in model.named_parameters():
        if 'bias' not in name and param.requires_grad:
            if param.grad is None:
                return 0
            _list.append(param.grad.data.norm(2).item() ** 2)
    return float(sum(_list) ** (1. / 2))
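

# The following usage sketch is an editorial addition, not part of the original API:
# it shows how the two norm helpers above are typically called after a backward pass.
# The toy model shape and data are assumptions made for the demo.
def _example_grad_norm_helpers() -> None:
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    loss = model(torch.randn(16, 4)).pow(2).mean()
    loss.backward()
    print(calculate_grad_norm(model))  # 2-norm over all parameter gradients
    print(calculate_grad_norm(model, norm_type='inf'))  # largest absolute gradient entry
    print(calculate_grad_norm_without_bias_two_norm(model))  # 2-norm with biases excluded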


def grad_ignore_norm(parameters, max_norm, norm_type=2):
    """
    Overview:
        Zero out the gradients of an iterable of parameters if their total norm exceeds ``max_norm``.
    Arguments:
        - parameters (:obj:`Iterable`): an iterable of torch.Tensor
        - max_norm (:obj:`float`): the max norm of the gradients
        - norm_type (:obj:`float`): 2.0 means use the 2-norm
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.zero_()
    return total_norm


def grad_ignore_value(parameters, clip_value):
    """
    Overview:
        Zero out the gradients of an iterable of parameters if any gradient entry's absolute value
        reaches ``clip_value``.
    Arguments:
        - parameters (:obj:`Iterable`): an iterable of torch.Tensor
        - clip_value (:obj:`float`): the value at which gradients start to be ignored
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    clip_value = float(clip_value)
    flag = False
    for p in filter(lambda p: p.grad is not None, parameters):
        val = p.grad.data.abs().max()
        if val >= clip_value:
            flag = True
            break
    if flag:
        for p in filter(lambda p: p.grad is not None, parameters):
            p.grad.data.zero_()
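

# Hedged sketch (editorial addition, illustrative values): grad_ignore_value zeroes
# *all* gradients once any entry reaches clip_value, while grad_ignore_norm zeroes
# them when the total gradient norm exceeds max_norm.
def _example_grad_ignore() -> None:
    w = torch.nn.Parameter(torch.ones(3))
    w.grad = torch.tensor([0.1, 0.2, 10.0])
    grad_ignore_value([w], clip_value=5.0)
    print(w.grad)  # tensor([0., 0., 0.]) because 10.0 >= 5.0
    w.grad = torch.tensor([3.0, 4.0, 0.0])  # 2-norm is 5.0
    total_norm = grad_ignore_norm([w], max_norm=1.0)
    print(total_norm, w.grad)  # ~5.0, and the grads are zeroed since 1.0 / 5.0 < 1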


class Adam(torch.optim.Adam):
    """
    Overview:
        Rewritten Adam optimizer to support more features (gradient clipping and gradient ignoring).
    Interfaces:
        ``__init__``, ``step``, ``_state_init``, ``get_grad``
    """

    def __init__(
            self,
            params: Iterable,
            lr: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-8,
            weight_decay: float = 0,
            amsgrad: bool = False,
            optim_type: str = 'adam',
            grad_clip_type: str = None,
            clip_value: Union[float, None] = None,
            clip_coef: float = 5,
            clip_norm_type: float = 2.0,
            clip_momentum_timestep: int = 100,
            grad_norm_type: str = None,
            grad_ignore_type: str = None,
            ignore_value: Union[float, None] = None,
            ignore_coef: float = 5,
            ignore_norm_type: float = 2.0,
            ignore_momentum_timestep: int = 100,
    ):
        """
        Overview:
            init method of the refactored Adam class
        Arguments:
            - params (:obj:`Iterable`): an iterable of torch.Tensor s or dict s, \
                specifies what Tensors should be optimized
            - lr (:obj:`float`): learning rate, default set to 1e-3
            - betas (:obj:`Tuple[float, float]`): coefficients used for computing running averages of gradient \
                and its square, default set to (0.9, 0.999)
            - eps (:obj:`float`): term added to the denominator to improve numerical stability, default set to 1e-8
            - weight_decay (:obj:`float`): weight decay coefficient, default set to 0
            - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper \
                On the Convergence of Adam and Beyond <https://arxiv.org/abs/1904.09237>
            - optim_type (:obj:`str`): support ["adam", "adamw"]
            - grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm', \
                'clip_momentum_norm']
            - clip_value (:obj:`float`): the value at which gradients start to be clipped
            - clip_coef (:obj:`float`): the clipping coefficient
            - clip_norm_type (:obj:`float`): 2.0 means use the 2-norm to clip
            - clip_momentum_timestep (:obj:`int`): after how many steps the momentum clipping starts
            - grad_norm_type (:obj:`str`): support [None], reserved for gradient normalization
            - grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', \
                'ignore_momentum_norm']
            - ignore_value (:obj:`float`): the value at which gradients start to be ignored
            - ignore_coef (:obj:`float`): the ignoring coefficient
            - ignore_norm_type (:obj:`float`): 2.0 means use the 2-norm to ignore
            - ignore_momentum_timestep (:obj:`int`): after how many steps the momentum ignoring starts
        """
        self._support_type = {
            'optim': ['adam', 'adamw'],
            'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
            'grad_norm': [None],
            'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
        }

        assert optim_type in self._support_type['optim']
        assert grad_clip_type in self._support_type['grad_clip']
        assert grad_norm_type in self._support_type['grad_norm']
        assert grad_ignore_type in self._support_type['grad_ignore']
        if grad_clip_type:
            assert clip_value is not None
        if grad_ignore_type:
            assert ignore_value is not None

        self._optim_type = optim_type
        self._grad_clip_type = grad_clip_type
        self._grad_norm_type = grad_norm_type
        self._grad_ignore_type = grad_ignore_type
        self._clip_value = clip_value
        self._clip_norm_type = clip_norm_type
        self._clip_coef = clip_coef
        self._ignore_value = ignore_value
        self._ignore_norm_type = ignore_norm_type
        self._ignore_coef = ignore_coef
        self._clip_momentum_timestep = clip_momentum_timestep
        self._ignore_momentum_timestep = ignore_momentum_timestep

        if self._optim_type == 'adamw':
            self._weight_decay = weight_decay
            super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=0, amsgrad=amsgrad)
        elif self._optim_type == 'adam':
            super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        else:
            raise NotImplementedError(
                "optimizer type {} is not implemented, support type is {}".format(
                    self._optim_type, self._support_type['optim']
                )
            )

    def _state_init(self, p, amsgrad):
        """
        Overview:
            Initialize the state of the optimizer
        Arguments:
            - p (:obj:`torch.Tensor`): the parameter to be optimized
            - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper \
                On the Convergence of Adam and Beyond <https://arxiv.org/abs/1904.09237>
        """
        state = self.state[p]
        state['thre_exp_avg_sq'] = torch.zeros_like(p.data, device=p.data.device)
        if torch.__version__ < "1.12.0":
            # torch versions before 1.12 store `step` as a plain int
            state['step'] = 0
        else:
            state['step'] = torch.zeros((1, ), dtype=torch.float, device=p.device) \
                if self.defaults['capturable'] else torch.tensor(0.)
        # Exponential moving average of gradient values
        state['exp_avg'] = torch.zeros_like(p.data)
        # Exponential moving average of squared gradient values
        state['exp_avg_sq'] = torch.zeros_like(p.data)
        if amsgrad:
            # Maintains max of all exp. moving avg. of sq. grad. values
            state['max_exp_avg_sq'] = torch.zeros_like(p.data)

    def step(self, closure: Union[Callable, None] = None):
        """
        Overview:
            Performs a single optimization step
        Arguments:
            - closure (:obj:`callable`): A closure that reevaluates the model and returns the loss, default to None
        """
        # clipping
        new_params = [
            t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None
        ]
        if self._grad_clip_type == 'clip_value':
            clip_grad_value_(new_params, self._clip_value)
        elif self._grad_clip_type == 'clip_norm':
            clip_grad_norm_(new_params, self._clip_value, self._clip_norm_type)
        elif self._grad_clip_type == 'clip_momentum':
            '''
            This implementation mimics the clipping used by OpenAI, quote:
            'Gradients are additionally clipped per parameter to be within between ±5√v
            where v is the running estimate of the second moment of the (unclipped) gradient'
            '''
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use the same betas for every group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    if state['step'] >= self._clip_momentum_timestep:  # initial value is inaccurate
                        thre = (state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) * self._clip_coef
                        flag = grad.abs() > thre
                        # clip flagged entries to ±thre, preserving the gradient sign
                        clipped = (thre * grad.sign()).mul(flag)
                        grad.mul_(~flag).add_(clipped)
        elif self._grad_clip_type == 'clip_momentum_norm':
            # there may be multiple param_groups, so calculate each group separately
            for group in self.param_groups:
                total_norm = 0
                total_momentum_norm = 0
                step = inf
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use the same betas for every group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    # accumulate total_norm
                    param_norm = grad.norm(self._clip_norm_type)
                    total_norm += param_norm.item() ** self._clip_norm_type
                    # accumulate momentum_norm
                    momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
                                self._clip_coef).norm(self._clip_norm_type)
                    total_momentum_norm += momentum.item() ** self._clip_norm_type
                    step = min(step, state['step'])
                if step > self._clip_momentum_timestep:
                    total_norm = total_norm ** (1. / self._clip_norm_type)
                    total_momentum_norm = total_momentum_norm ** (1. / self._clip_norm_type)
                    clip_coef = total_momentum_norm / (total_norm + 1e-6)
                    if clip_coef < 1:
                        for p in group['params']:
                            if p.grad is None:
                                continue
                            p.grad.data.mul_(clip_coef)

        if self._grad_ignore_type == 'ignore_value':
            grad_ignore_value(new_params, self._ignore_value)
        elif self._grad_ignore_type == 'ignore_norm':
            grad_ignore_norm(new_params, self._ignore_value, self._ignore_norm_type)
        elif self._grad_ignore_type == 'ignore_momentum':
            flag = False
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use the same betas for every group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    if state['step'] >= self._ignore_momentum_timestep:  # initial value is inaccurate
                        thre = (state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) * self._ignore_coef
                        if (grad.abs() > thre).any():
                            flag = True
                            break
                else:
                    continue
                break
            if flag:
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is None:
                            continue
                        p.grad.zero_()
        elif self._grad_ignore_type == 'ignore_momentum_norm':
            # there may be multiple param_groups, so calculate each group separately
            step = inf
            for group in self.param_groups:
                total_norm = 0
                total_momentum_norm = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use the same betas for every group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    # accumulate total_norm
                    param_norm = grad.norm(self._ignore_norm_type)
                    total_norm += param_norm.item() ** self._ignore_norm_type
                    # accumulate momentum_norm
                    momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
                                self._ignore_coef).norm(self._ignore_norm_type)
                    total_momentum_norm += momentum.item() ** self._ignore_norm_type
                    step = min(step, state['step'])
                if step > self._ignore_momentum_timestep:
                    total_norm = total_norm ** (1. / self._ignore_norm_type)
                    total_momentum_norm = total_momentum_norm ** (1. / self._ignore_norm_type)
                    ignore_coef = total_momentum_norm / (total_norm + 1e-6)
                    if ignore_coef < 1:
                        for p in group['params']:
                            if p.grad is None:
                                continue
                            p.grad.zero_()

        # Adam optim type
        if self._optim_type == 'adamw':
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    # decoupled weight decay: p <- p - lr * weight_decay * p
                    p.data = p.data.add(p.data, alpha=-self._weight_decay * group['lr'])
            return super().step(closure=closure)
        elif self._optim_type == 'adam':
            return super().step(closure=closure)

    def get_grad(self) -> float:
        """
        Overview:
            Calculate the norm (without the final root) of the gradients that are not None in the optimizer.
        """
        total_norm = 0.
        params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None]
        for p in params:
            param_norm = p.grad.data.norm(self._clip_norm_type)
            total_norm += param_norm.item() ** self._clip_norm_type
        return total_norm
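

# Hedged usage sketch (editorial addition, assumed hyper-parameter values): an Adam
# instance with per-parameter momentum clipping. Once clip_momentum_timestep steps of
# second-moment statistics have accumulated, gradient entries larger in magnitude
# than clip_coef * sqrt(v_hat) are clipped. Note that clip_value must be given
# whenever grad_clip_type is set, even though momentum clipping only uses clip_coef.
def _example_adam_clip_momentum() -> None:
    model = nn.Linear(4, 2)
    optimizer = Adam(
        model.parameters(),
        lr=1e-3,
        grad_clip_type='clip_momentum',
        clip_value=0.5,
        clip_coef=5,
        clip_momentum_timestep=100,
    )
    for _ in range(3):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()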


class RMSprop(torch.optim.RMSprop):
    r"""
    Overview:
        Rewritten RMSprop optimizer to support more features (gradient clipping and gradient ignoring).
    Interfaces:
        ``__init__``, ``step``, ``_state_init``, ``get_grad``
    """

    def __init__(
            self,
            params: Iterable,
            lr: float = 1e-2,
            alpha: float = 0.99,
            eps: float = 1e-8,
            weight_decay: float = 0,
            momentum: float = 0,
            centered: bool = False,
            grad_clip_type: str = None,
            clip_value: Union[float, None] = None,
            clip_coef: float = 5,
            clip_norm_type: float = 2.0,
            clip_momentum_timestep: int = 100,
            grad_norm_type: str = None,
            grad_ignore_type: str = None,
            ignore_value: Union[float, None] = None,
            ignore_coef: float = 5,
            ignore_norm_type: float = 2.0,
            ignore_momentum_timestep: int = 100,
    ):
        """
        Overview:
            init method of the refactored RMSprop class
        Arguments:
            - params (:obj:`Iterable`): an iterable of torch.Tensor s or dict s, \
                specifies what Tensors should be optimized
            - lr (:obj:`float`): learning rate, default set to 1e-2
            - alpha (:obj:`float`): smoothing constant, default set to 0.99
            - eps (:obj:`float`): term added to the denominator to improve numerical stability, default set to 1e-8
            - weight_decay (:obj:`float`): weight decay coefficient, default set to 0
            - momentum (:obj:`float`): momentum factor, default set to 0
            - centered (:obj:`bool`): if True, compute the centered RMSprop, \
                the gradient is normalized by an estimation of its variance
            - grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm', \
                'clip_momentum_norm']
            - clip_value (:obj:`float`): the value at which gradients start to be clipped
            - clip_coef (:obj:`float`): the clipping coefficient
            - clip_norm_type (:obj:`float`): 2.0 means use the 2-norm to clip
            - clip_momentum_timestep (:obj:`int`): after how many steps the momentum clipping starts
            - grad_norm_type (:obj:`str`): support [None], reserved for gradient normalization
            - grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', \
                'ignore_momentum_norm']
            - ignore_value (:obj:`float`): the value at which gradients start to be ignored
            - ignore_coef (:obj:`float`): the ignoring coefficient
            - ignore_norm_type (:obj:`float`): 2.0 means use the 2-norm to ignore
            - ignore_momentum_timestep (:obj:`int`): after how many steps the momentum ignoring starts
        """
        self._support_type = {
            'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
            'grad_norm': [None],
            'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
        }

        assert grad_clip_type in self._support_type['grad_clip']
        assert grad_norm_type in self._support_type['grad_norm']
        assert grad_ignore_type in self._support_type['grad_ignore']
        if grad_clip_type:
            assert clip_value is not None
        if grad_ignore_type:
            assert ignore_value is not None

        self._grad_clip_type = grad_clip_type
        self._grad_norm_type = grad_norm_type
        self._grad_ignore_type = grad_ignore_type
        self._clip_value = clip_value
        self._clip_norm_type = clip_norm_type
        self._clip_coef = clip_coef
        self._ignore_value = ignore_value
        self._ignore_norm_type = ignore_norm_type
        self._ignore_coef = ignore_coef
        self._clip_momentum_timestep = clip_momentum_timestep
        self._ignore_momentum_timestep = ignore_momentum_timestep
        super(RMSprop, self).__init__(
            params, lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=momentum, centered=centered
        )

    def _state_init(self, p, momentum, centered):
        """
        Overview:
            Initialize the state of the optimizer
        Arguments:
            - p (:obj:`torch.Tensor`): the parameter to be optimized
            - momentum (:obj:`float`): the momentum coefficient
            - centered (:obj:`bool`): if True, compute the centered RMSprop, \
                the gradient is normalized by an estimation of its variance
        """
        state = self.state[p]
        state['step'] = 0
        state['thre_square_avg'] = torch.zeros_like(p.data, device=p.data.device)
        state['square_avg'] = torch.zeros_like(p.data, device=p.data.device)
        if momentum:
            state['momentum_buffer'] = torch.zeros_like(p.data, device=p.data.device)
        if centered:
            state['grad_avg'] = torch.zeros_like(p.data, device=p.data.device)

    def step(self, closure: Union[Callable, None] = None):
        """
        Overview:
            Performs a single optimization step
        Arguments:
            - closure (:obj:`callable`): A closure that reevaluates the model and returns the loss, default to None
        """
        # clipping
        new_params = [
            t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None
        ]
        if self._grad_clip_type == 'clip_value':
            clip_grad_value_(new_params, self._clip_value)
        elif self._grad_clip_type == 'clip_norm':
            clip_grad_norm_(new_params, self._clip_value, self._clip_norm_type)
        elif self._grad_clip_type == 'clip_momentum':
            '''
            This implementation mimics the clipping used by OpenAI, quote:
            'Gradients are additionally clipped per parameter to be within between ±5√v
            where v is the running estimate of the second moment of the (unclipped) gradient'
            '''
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['momentum'], group['centered'])
                    grad = p.grad.data
                    alpha = group['alpha']
                    state['thre_square_avg'].mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
                    if state['step'] >= self._clip_momentum_timestep:  # initial value is inaccurate
                        thre = state['thre_square_avg'].sqrt() * self._clip_coef
                        flag = grad.abs() > thre
                        # clip flagged entries to ±thre, preserving the gradient sign
                        clipped = (thre * grad.sign()).mul(flag)
                        grad.mul_(~flag).add_(clipped)
        elif self._grad_clip_type == 'clip_momentum_norm':
            # there may be multiple param_groups, so calculate each group separately
            for group in self.param_groups:
                total_norm = 0
                total_momentum_norm = 0
                step = inf
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['momentum'], group['centered'])
                    grad = p.grad.data
                    alpha = group['alpha']
                    state['thre_square_avg'].mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
                    # accumulate total_norm
                    param_norm = grad.norm(self._clip_norm_type)
                    total_norm += param_norm.item() ** self._clip_norm_type
                    # accumulate momentum_norm
                    momentum = (state['thre_square_avg'].sqrt() * self._clip_coef).norm(self._clip_norm_type)
                    total_momentum_norm += momentum.item() ** self._clip_norm_type
                    step = min(step, state['step'])
                if step > self._clip_momentum_timestep:
                    total_norm = total_norm ** (1. / self._clip_norm_type)
                    total_momentum_norm = total_momentum_norm ** (1. / self._clip_norm_type)
                    clip_coef = total_momentum_norm / (total_norm + 1e-6)
                    if clip_coef < 1:
                        for p in group['params']:
                            if p.grad is None:
                                continue
                            p.grad.data.mul_(clip_coef)

        if self._grad_ignore_type == 'ignore_value':
            grad_ignore_value(new_params, self._ignore_value)
        elif self._grad_ignore_type == 'ignore_norm':
            grad_ignore_norm(new_params, self._ignore_value, self._ignore_norm_type)
        elif self._grad_ignore_type == 'ignore_momentum':
            flag = False
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['momentum'], group['centered'])
                    grad = p.grad.data
                    alpha = group['alpha']
                    state['thre_square_avg'].mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
                    if state['step'] >= self._ignore_momentum_timestep:  # initial value is inaccurate
                        if (grad.abs() > state['thre_square_avg'].sqrt() * self._ignore_coef).any():
                            flag = True
                            break
                else:
                    continue
                break
            if flag:
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is None:
                            continue
                        p.grad.zero_()
        elif self._grad_ignore_type == 'ignore_momentum_norm':
            # there may be multiple param_groups, so calculate each group separately
            step = inf
            for group in self.param_groups:
                total_norm = 0
                total_momentum_norm = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['momentum'], group['centered'])
                    grad = p.grad.data
                    alpha = group['alpha']
                    state['thre_square_avg'].mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
                    # accumulate total_norm
                    param_norm = grad.norm(self._ignore_norm_type)
                    total_norm += param_norm.item() ** self._ignore_norm_type
                    # accumulate momentum_norm
                    momentum = (state['thre_square_avg'].sqrt() * self._ignore_coef).norm(self._ignore_norm_type)
                    total_momentum_norm += momentum.item() ** self._ignore_norm_type
                    step = min(step, state['step'])
                if step > self._ignore_momentum_timestep:
                    total_norm = total_norm ** (1. / self._ignore_norm_type)
                    total_momentum_norm = total_momentum_norm ** (1. / self._ignore_norm_type)
                    ignore_coef = total_momentum_norm / (total_norm + 1e-6)
                    if ignore_coef < 1:
                        for p in group['params']:
                            if p.grad is None:
                                continue
                            p.grad.zero_()
        return super().step(closure=closure)

    def get_grad(self) -> float:
        """
        Overview:
            Calculate the norm (without the final root) of the gradients that are not None in the optimizer.
        """
        total_norm = 0.
        params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None]
        for p in params:
            param_norm = p.grad.data.norm(self._clip_norm_type)
            total_norm += param_norm.item() ** self._clip_norm_type
        return total_norm
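

# Hedged usage sketch (editorial addition, assumed values): an RMSprop instance that
# skips the update entirely (zeroes all gradients) whenever the global gradient
# 2-norm exceeds ignore_value.
def _example_rmsprop_ignore_norm() -> None:
    model = nn.Linear(4, 2)
    optimizer = RMSprop(model.parameters(), lr=1e-2, grad_ignore_type='ignore_norm', ignore_value=10.0)
    optimizer.zero_grad()
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()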


class PCGrad:
    """
    Overview:
        PCGrad optimizer for multi-task learning; see the paper
        Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>.
    Interfaces:
        ``__init__``, ``zero_grad``, ``step``, ``pc_backward``
    Properties:
        - optimizer (:obj:`torch.optim`): the optimizer to be used
    """

    def __init__(self, optimizer, reduction='mean'):
        """
        Overview:
            Initialization of the PCGrad optimizer
        Arguments:
            - optimizer (:obj:`torch.optim`): the optimizer to be used
            - reduction (:obj:`str`): the reduction method, support ['mean', 'sum']
        """
        self._optim, self._reduction = optimizer, reduction

    @property
    def optimizer(self):
        """
        Overview:
            get the optimizer
        """
        return self._optim

    def zero_grad(self):
        """
        Overview:
            clear the gradients of the parameters
        """
        return self._optim.zero_grad(set_to_none=True)

    def step(self):
        """
        Overview:
            update the parameters with the gradients
        """
        return self._optim.step()

    def pc_backward(self, objectives):
        """
        Overview:
            calculate the gradients of the parameters
        Arguments:
            - objectives: a list of objectives (losses)
        """
        grads, shapes, has_grads = self._pack_grad(objectives)
        pc_grad = self._project_conflicting(grads, has_grads)
        pc_grad = self._unflatten_grad(pc_grad, shapes[0])
        self._set_grad(pc_grad)
        return

    def _project_conflicting(self, grads, has_grads, shapes=None):
        """
        Overview:
            project each conflicting gradient onto the normal plane of the gradient it conflicts with
        Arguments:
            - grads (:obj:`list`): a list of the gradients of the parameters
            - has_grads (:obj:`list`): a list of masks representing whether the parameter has a gradient
            - shapes (:obj:`list`): a list of the shapes of the parameters
        """
        shared = torch.stack(has_grads).prod(0).bool()
        pc_grad, num_task = copy.deepcopy(grads), len(grads)
        for g_i in pc_grad:
            # visit the other tasks in random order, as in the PCGrad paper
            random.shuffle(grads)
            for g_j in grads:
                g_i_g_j = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    # remove the component of g_i that conflicts with g_j
                    g_i -= (g_i_g_j) * g_j / (g_j.norm() ** 2)
        merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
        if self._reduction == 'mean':
            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
        elif self._reduction == 'sum':
            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)
        else:
            raise KeyError("invalid reduction method")
        merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)
        return merged_grad

    def _set_grad(self, grads):
        """
        Overview:
            set the modified gradients to the network
        Arguments:
            - grads (:obj:`list`): a list of the gradients of the parameters
        """
        idx = 0
        for group in self._optim.param_groups:
            for p in group['params']:
                p.grad = grads[idx]
                idx += 1
        return

    def _pack_grad(self, objectives):
        """
        Overview:
            pack the gradients of the network's parameters for each objective
        Arguments:
            - objectives: a list of objectives (losses)
        Returns:
            - grads: a list of the gradients of the parameters
            - shapes: a list of the shapes of the parameters
            - has_grads: a list of masks representing whether the parameter has a gradient
        """
        grads, shapes, has_grads = [], [], []
        for obj in objectives:
            self._optim.zero_grad(set_to_none=True)
            obj.backward(retain_graph=True)
            grad, shape, has_grad = self._retrieve_grad()
            grads.append(self._flatten_grad(grad, shape))
            has_grads.append(self._flatten_grad(has_grad, shape))
            shapes.append(shape)
        return grads, shapes, has_grads

    def _unflatten_grad(self, grads, shapes):
        """
        Overview:
            unflatten the gradients of the network's parameters
        Arguments:
            - grads (:obj:`torch.Tensor`): a flattened tensor of the gradients of the parameters
            - shapes (:obj:`list`): a list of the shapes of the parameters
        """
        unflatten_grad, idx = [], 0
        for shape in shapes:
            length = int(np.prod(shape))
            unflatten_grad.append(grads[idx:idx + length].view(shape).clone())
            idx += length
        return unflatten_grad

    def _flatten_grad(self, grads, shapes):
        """
        Overview:
            flatten the gradients of the network's parameters
        Arguments:
            - grads (:obj:`list`): a list of the gradients of the parameters
            - shapes (:obj:`list`): a list of the shapes of the parameters
        """
        flatten_grad = torch.cat([g.flatten() for g in grads])
        return flatten_grad

    def _retrieve_grad(self):
        """
        Overview:
            get the gradients of the network's parameters for a specific objective
        Returns:
            - grad: a list of the gradients of the parameters
            - shape: a list of the shapes of the parameters
            - has_grad: a list of masks representing whether the parameter has a gradient
        """
        grad, shape, has_grad = [], [], []
        for group in self._optim.param_groups:
            for p in group['params']:
                # tackle the multi-head scenario: parameters without gradients get zero placeholders
                if p.grad is None:
                    shape.append(p.shape)
                    grad.append(torch.zeros_like(p).to(p.device))
                    has_grad.append(torch.zeros_like(p).to(p.device))
                    continue
                shape.append(p.grad.shape)
                grad.append(p.grad.clone())
                has_grad.append(torch.ones_like(p).to(p.device))
        return grad, shape, has_grad
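

# Hedged multi-task sketch (editorial addition, toy tasks): PCGrad wraps a regular
# optimizer; pc_backward replaces the usual loss.backward() call and projects away
# the conflicting components between the per-task gradients before stepping.
def _example_pcgrad() -> None:
    model = nn.Linear(4, 2)
    pc_optim = PCGrad(optim.Adam(model.parameters(), lr=1e-3))
    out = model(torch.randn(8, 4))
    loss_task_a = out[:, 0].pow(2).mean()
    loss_task_b = (out[:, 1] - 1).pow(2).mean()
    pc_optim.zero_grad()
    pc_optim.pc_backward([loss_task_a, loss_task_b])
    pc_optim.step()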


def configure_weight_decay(model: nn.Module, weight_decay: float) -> List:
    """
    Overview:
        Separate all parameters of the model into two buckets: those that will experience weight decay
        for regularization, and those that won't (biases, and layer-norm or embedding weights).
    Arguments:
        - model (:obj:`nn.Module`): the given PyTorch model.
        - weight_decay (:obj:`float`): weight decay value for the optimizer.
    Returns:
        - optim_groups (:obj:`List`): the parameter groups to be passed to the optimizer.
    """
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, )
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
            # Because named_modules and named_parameters are recursive,
            # we will see the same tensors p many times. But doing it this way
            # allows us to know which parent module any tensor p belongs to.
            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)
            else:
                decay.add(fpn)

    decay = decay - no_decay
    # validate that we considered every parameter
    param_dict = {pn: p for pn, p in model.named_parameters()}
    union_params = decay | no_decay
    assert len(param_dict.keys() - union_params) == 0, \
        "parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params), )
    optim_groups = [
        {
            "params": [param_dict[pn] for pn in sorted(list(decay))],
            "weight_decay": weight_decay
        },
        {
            "params": [param_dict[pn] for pn in sorted(list(no_decay))],
            "weight_decay": 0.0
        },
    ]
    return optim_groups
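

# Hedged usage sketch (editorial addition, assumed model and values): split the
# parameters into decay / no-decay groups, then hand the groups to an optimizer.
# Linear weights are decayed; biases and LayerNorm weights are not.
def _example_configure_weight_decay() -> None:
    model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))
    optim_groups = configure_weight_decay(model, weight_decay=0.01)
    optimizer = optim.AdamW(optim_groups, lr=1e-3)
    print([g['weight_decay'] for g in optim_groups])  # [0.01, 0.0]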