| from keras.src.api_export import keras_export |
| from keras.src.optimizers import adam |
| from keras.src.optimizers import optimizer |
|
|
|
|
@keras_export(["keras.optimizers.AdamW"])
class AdamW(adam.Adam):
    """Optimizer that implements the AdamW algorithm.

    AdamW optimization is a stochastic gradient descent method that is based on
    adaptive estimation of first-order and second-order moments with an added
    method to decay weights per the techniques discussed in the paper,
    'Decoupled Weight Decay Regularization' by
    [Loshchilov, Hutter et al., 2019](https://arxiv.org/abs/1711.05101).

    According to
    [Kingma et al., 2014](http://arxiv.org/abs/1412.6980),
    the underlying Adam method is "*computationally
    efficient, has little memory requirement, invariant to diagonal rescaling of
    gradients, and is well suited for problems that are large in terms of
    data/parameters*".

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimates.
            Defaults to `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 2nd moment estimates.
            Defaults to `0.999`.
        epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just
            before Section 2.1), not the epsilon in Algorithm 1 of the paper.
            Defaults to 1e-7.
        amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm
            from the paper "On the Convergence of Adam and beyond".
            Defaults to `False`.
        {{base_optimizer_keyword_args}}

    References:

    - [Loshchilov et al., 2019](https://arxiv.org/abs/1711.05101)
    - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) for `adam`
    - [Reddi et al., 2018](
        https://openreview.net/pdf?id=ryQu7f-RZ) for `amsgrad`.
    """

    def __init__(
        self,
        learning_rate=0.001,
        weight_decay=0.004,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="adamw",
        **kwargs,
    ):
        # Fail fast: without a weight decay this optimizer degenerates to
        # plain Adam, so `None` is rejected. Validating the argument *before*
        # calling `super().__init__` avoids building a fully-initialized base
        # optimizer only to discard it when the check raises.
        if weight_decay is None:
            raise ValueError(
                "Argument `weight_decay` must be a float. Received: "
                "weight_decay=None"
            )
        super().__init__(
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
|
|
|
|
# Splice the shared base-optimizer keyword-argument docs into the class
# docstring at import time, replacing the `{{...}}` placeholder above.
AdamW.__doc__ = AdamW.__doc__.replace(
    "{{base_optimizer_keyword_args}}",
    optimizer.base_optimizer_keyword_args,
)
|
|