| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """PyTorch optimization for BERT model.""" |
| |
|
| | import math |
| | import warnings |
| | from functools import partial |
| | from typing import Callable, Iterable, Optional, Tuple, Union |
| |
|
| | import torch |
| | from torch import nn |
| | from torch.optim import Optimizer |
| | from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau |
| |
|
| | from .trainer_utils import SchedulerType |
| | from .utils import logging |
| | from .utils.versions import require_version |
| |
|
| |
|
# Module-level logger named after this module's import path (obtained through the project's logging helper).
logger = logging.get_logger(__name__)
| |
|
| |
|
| | def _get_constant_lambda(_=None): |
| | return 1 |
| |
|
| |
|
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
    """
    Build a scheduler that keeps the learning rate fixed at the value configured in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` whose multiplier is always 1.
    """
    scheduler = LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
    return scheduler
| |
|
| |
|
def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
    """
    Create a schedule with a constant learning rate that decreases when a metric has stopped improving.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        kwargs (*optional*):
            Extra keyword arguments forwarded unchanged to `torch.optim.lr_scheduler.ReduceLROnPlateau`
            (for example `mode`, `factor` or `patience`). When none are given, PyTorch's defaults are used,
            exactly as before.

    Return:
        `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
    """
    return ReduceLROnPlateau(optimizer, **kwargs)
| |
|
| |
|
| | def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int): |
| | if current_step < num_warmup_steps: |
| | return float(current_step) / float(max(1.0, num_warmup_steps)) |
| | return 1.0 |
| |
|
| |
|
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
    """
    Build a scheduler that warms the learning rate up linearly from 0 to the optimizer's initial lr over
    `num_warmup_steps` steps and then holds it constant.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # A `partial` of a module-level function (rather than a closure) keeps the multiplier a plain callable object.
    warmup_multiplier = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
    scheduler = LambdaLR(optimizer, warmup_multiplier, last_epoch=last_epoch)
    return scheduler
| |
|
| |
|
| | def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int): |
| | if current_step < num_warmup_steps: |
| | return float(current_step) / float(max(1, num_warmup_steps)) |
| | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) |
| |
|
| |
|
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Build a scheduler that raises the learning rate linearly from 0 to the optimizer's initial lr during
    warmup, then lowers it linearly back to 0 by `num_training_steps`.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    schedule_kwargs = {"num_warmup_steps": num_warmup_steps, "num_training_steps": num_training_steps}
    lr_lambda = partial(_get_linear_schedule_with_warmup_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
| |
|
| |
|
| | def _get_cosine_schedule_with_warmup_lr_lambda( |
| | current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float |
| | ): |
| | if current_step < num_warmup_steps: |
| | return float(current_step) / float(max(1, num_warmup_steps)) |
| | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) |
| | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) |
| |
|
| |
|
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Build a scheduler that warms the learning rate up linearly from 0 to the optimizer's initial lr and then
    follows a cosine curve from that value down towards 0.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            Number of cosine waves after warmup; the default of 0.5 is a single half-cosine going from the
            maximum value down to 0.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "num_cycles": num_cycles,
    }
    lr_lambda = partial(_get_cosine_schedule_with_warmup_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
| |
|
| |
|
| | def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda( |
| | current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int |
| | ): |
| | if current_step < num_warmup_steps: |
| | return float(current_step) / float(max(1, num_warmup_steps)) |
| | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) |
| | if progress >= 1.0: |
| | return 0.0 |
| | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) |
| |
|
| |
|
def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Build a scheduler that warms the learning rate up linearly from 0 to the optimizer's initial lr and then
    repeats `num_cycles` cosine decays towards 0, jumping back to the initial lr at each restart.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            Number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "num_cycles": num_cycles,
    }
    lr_lambda = partial(_get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
| |
|
| |
|
| | def _get_polynomial_decay_schedule_with_warmup_lr_lambda( |
| | current_step: int, |
| | *, |
| | num_warmup_steps: int, |
| | num_training_steps: int, |
| | lr_end: float, |
| | power: float, |
| | lr_init: int, |
| | ): |
| | if current_step < num_warmup_steps: |
| | return float(current_step) / float(max(1, num_warmup_steps)) |
| | elif current_step > num_training_steps: |
| | return lr_end / lr_init |
| | else: |
| | lr_range = lr_init - lr_end |
| | decay_steps = num_training_steps - num_warmup_steps |
| | pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps |
| | decay = lr_range * pct_remaining**power + lr_end |
| | return decay / lr_init |
| |
|
| |
|
def get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        `ValueError`: if `lr_end` is not strictly smaller than the optimizer's default `lr`.
    """
    # The schedule is a multiplier of the optimizer's *default* lr, so the decay target must lie below it.
    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    lr_lambda = partial(
        _get_polynomial_decay_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        lr_end=lr_end,
        power=power,
        lr_init=lr_init,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)
| |
|
| |
|
| | def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int = None): |
| | if current_step < num_warmup_steps: |
| | return float(current_step) / float(max(1, num_warmup_steps)) |
| | shift = timescale - num_warmup_steps |
| | decay = 1.0 / math.sqrt((current_step + shift) / timescale) |
| | return decay |
| |
|
| |
|
def get_inverse_sqrt_schedule(
    optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1
):
    """
    Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
    warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        timescale (`int`, *optional*, defaults to `num_warmup_steps`, or 10_000 when `num_warmup_steps` is 0):
            Time scale of the inverse-square-root decay.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    if timescale is None:
        # `timescale` is the divisor of the decay formula. Defaulting it to a warmup length of 0 would make
        # every post-warmup step (including step 0, which `LambdaLR` evaluates on construction) divide by zero,
        # so fall back to a fixed scale in that case.
        timescale = num_warmup_steps or 10_000

    lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale)
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
| |
|
| |
|
# Maps every `SchedulerType` to the factory that builds it; `get_scheduler` dispatches through this table.
# The factories do not share one signature (CONSTANT and REDUCE_ON_PLATEAU take only the optimizer,
# CONSTANT_WITH_WARMUP and INVERSE_SQRT also need `num_warmup_steps`, the rest additionally need
# `num_training_steps`), so callers must pass the arguments each one expects.
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
    SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
}
| |
|
| |
|
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
):
    """
    Unified entry point that builds any supported scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            Name of the scheduler to build.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            Number of warmup steps. Not every scheduler needs it (hence optional); a `ValueError` is raised when
            it is missing but the chosen scheduler requires it.
        num_training_steps (`int`, *optional*):
            Number of training steps. Not every scheduler needs it (hence optional); a `ValueError` is raised
            when it is missing but the chosen scheduler requires it.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # Schedulers that need nothing but the optimizer.
    if name in (SchedulerType.CONSTANT, SchedulerType.REDUCE_ON_PLATEAU):
        return schedule_func(optimizer)

    # Every remaining scheduler needs a warmup length.
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    # Schedulers driven by the warmup length alone.
    if name in (SchedulerType.CONSTANT_WITH_WARMUP, SchedulerType.INVERSE_SQRT):
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    # The rest also need to know when training ends.
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
| |
|
| |
|
class AdamW(Optimizer):
    """
    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
    Regularization](https://arxiv.org/abs/1711.05101).

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 1e-3):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)):
            Adam's betas parameters (b1, b2).
        eps (`float`, *optional*, defaults to 1e-6):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0):
            Decoupled weight decay to apply.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
            A flag used to disable the deprecation warning (set to `True` to disable the warning).
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        no_deprecation_warning: bool = False,
    ):
        if not no_deprecation_warning:
            warnings.warn(
                "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
                " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
                " warning",
                FutureWarning,
            )
        require_version("torch>=1.5.0")
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.

        Return:
            The value returned by `closure`, or `None` when no closure is given.

        Raises:
            `RuntimeError`: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # `step` itself runs under `torch.no_grad()`. The closure has to build an autograd graph and call
            # `backward()`, so gradient tracking is re-enabled just for that call.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                # Lazily create the per-parameter state the first time this parameter receives a gradient.
                if len(state) == 0:
                    state["step"] = 0
                    # Running average of the gradient (first moment).
                    state["exp_avg"] = torch.zeros_like(p)
                    # Running average of the squared gradient (second moment).
                    state["exp_avg_sq"] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Update both moment estimates in place.
                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:
                    # Bias corrections are applied through the scalar step size (eps was added to the
                    # uncorrected denominator above). With `correct_bias=False` the raw moments are used.
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.addcdiv_(exp_avg, denom, value=-step_size)

                # Decoupled weight decay: shrink the weights directly (scaled by lr) instead of adding an L2 term to
                # the gradient, which would otherwise flow through the Adam moment estimates.
                # NOTE: applied after the Adam update above, i.e. to the already-updated parameter.
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

        return loss
| |
|
| |
|
class Adafactor(Optimizer):
    """
    AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py

    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
    this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
    `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
    `relative_step=False`.

    Arguments:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*):
            The external learning rate.
        eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)):
            Regularization constants for square gradient and parameter scale respectively
        clip_threshold (`float`, *optional*, defaults to 1.0):
            Threshold of root mean square of final gradient update
        decay_rate (`float`, *optional*, defaults to -0.8):
            Coefficient used to compute running averages of square
        beta1 (`float`, *optional*):
            Coefficient used for computing running averages of gradient
        weight_decay (`float`, *optional*, defaults to 0):
            Weight decay (L2 penalty)
        scale_parameter (`bool`, *optional*, defaults to `True`):
            If True, learning rate is scaled by root mean square
        relative_step (`bool`, *optional*, defaults to `True`):
            If True, time-dependent learning rate is computed instead of external learning rate
        warmup_init (`bool`, *optional*, defaults to `False`):
            Time-dependent learning rate computation depends on whether warm-up initialization is being used

    This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.

    Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):

        - Training without LR warmup or clip_threshold is not recommended.

           - use scheduled LR warm-up to fixed LR
           - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
        - Disable relative updates
        - Use scale_parameter=False
        - Additional optimizer operations like gradient clipping should not be used alongside Adafactor

    Example:

    ```python
    Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
    ```

    Others reported the following combination to work well:

    ```python
    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    ```

    When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
    scheduler as following:

    ```python
    from transformers.optimization import Adafactor, AdafactorSchedule

    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)
    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
    ```

    Usage:

    ```python
    # replace AdamW with Adafactor
    optimizer = Adafactor(
        model.parameters(),
        lr=1e-3,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    ```"""

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
    ):
        require_version("torch>=1.5.0")
        if lr is not None and relative_step:
            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
        if warmup_init and not relative_step:
            raise ValueError("`warmup_init=True` requires `relative_step=True`")

        defaults = {
            "lr": lr,
            "eps": eps,
            "clip_threshold": clip_threshold,
            "decay_rate": decay_rate,
            "beta1": beta1,
            "weight_decay": weight_decay,
            "scale_parameter": scale_parameter,
            "relative_step": relative_step,
            "warmup_init": warmup_init,
        }
        super().__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        """
        Step size for one parameter.

        Uses the group's external `lr`, or with `relative_step` the time-dependent value
        `min(min_step, 1 / sqrt(step))` where `min_step` is `1e-6 * step` under `warmup_init` and `1e-2`
        otherwise. With `scale_parameter` the result is multiplied by `max(eps[1], RMS of the parameter)`.
        """
        rel_step_sz = param_group["lr"]
        if param_group["relative_step"]:
            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps"][1], param_state["RMS"])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        """
        Return `(factored, use_first_moment)`.

        Parameters with two or more dimensions use factored (row/column) second-moment statistics. A first
        moment is kept only when `beta1` is set.
        """
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        """Root-mean-square of all elements of `tensor`."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        """
        Rebuild an approximate `rsqrt` of the full second-moment matrix from its factored statistics.

        Computes `rsqrt(row / mean(row))` as a column vector times `rsqrt(col)` as a row vector, broadcast
        over the last two dimensions.
        """
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Return:
            The value returned by `closure`, or `None` when no closure is given.

        Raises:
            `RuntimeError`: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # `step` itself runs under `torch.no_grad()`. The closure has to build an autograd graph and call
            # `backward()`, so gradient tracking is re-enabled just for that call.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                # Low-precision gradients are upcast so the optimizer statistics are computed in fp32.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State initialization on the first step that sees a gradient for this parameter.
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        # Row and column running averages of the squared gradient replace the full matrix.
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
                        state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0
                else:
                    # Bring the existing state to the gradient's device/dtype before it is updated in place.
                    if use_first_moment:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                # Low-precision parameters are updated through an fp32 copy that is written back at the end.
                p_data_fp32 = p
                if p.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)

                # Time-dependent second-moment decay: 1 - step**decay_rate.
                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Clip the update so its RMS does not exceed `clip_threshold`, then scale by the step size.
                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

                p_data_fp32.add_(-update)

                if p.dtype in {torch.float16, torch.bfloat16}:
                    p.copy_(p_data_fp32)

        return loss
| |
|
| |
|
class AdafactorSchedule(LambdaLR):
    """
    Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g.,
    for logging), this class creates a proxy object that retrieves the current lr values from the optimizer.

    It returns `initial_lr` during startup and the actual `lr` during stepping.
    """

    def __init__(self, optimizer, initial_lr=0.0):
        def lr_lambda(_):
            return initial_lr

        # Temporarily give every param group an "initial_lr" so the scheduler's base values (returned by `get_lr`
        # during startup) equal `initial_lr`; the key is removed again right after construction.
        for group in optimizer.param_groups:
            group["initial_lr"] = initial_lr
        super().__init__(optimizer, lr_lambda)
        for group in optimizer.param_groups:
            del group["initial_lr"]

    def get_lr(self):
        """
        Return the step size the optimizer computes for each param group whose lr can be queried.

        A group is queried when it has parameters, its first parameter has a gradient, and that parameter's
        optimizer state holds what `_get_lr` needs. If no group qualifies (for example before the first
        optimizer step), `self.base_lrs` is returned instead.
        """
        opt = self.optimizer
        lrs = []
        for group in opt.param_groups:
            params = group["params"]
            if not params or params[0].grad is None:
                continue
            # `.get` avoids inserting an empty state entry into the optimizer as a side effect.
            state = opt.state.get(params[0], {})
            try:
                lr = opt._get_lr(group, state)
            except KeyError:
                # No optimizer state yet (e.g. gradients exist but the optimizer step has not run or was skipped).
                continue
            lrs.append(lr)
        if len(lrs) == 0:
            lrs = self.base_lrs
        return lrs
| |
|
| |
|
def get_adafactor_schedule(optimizer, initial_lr=0.0):
    """
    Build a proxy scheduler for [`~optimization.Adafactor`] that reports the optimizer's own learning rates.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is reported.
        initial_lr (`float`, *optional*, defaults to 0.0):
            Value reported before the optimizer has produced a learning rate.

    Return:
        [`~optimization.Adafactor`] proxy schedule object.
    """
    proxy_schedule = AdafactorSchedule(optimizer, initial_lr)
    return proxy_schedule
| |
|