# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Layer-wise Adaptive Moments (LAMB) optimizer. | |
| See paper [Large Batch Optimization for Deep Learning: Training BERT in | |
| 76 minutes](https://arxiv.org/abs/1904.00962). | |
| """ | |
import re
from typing import Optional, Union, Callable, List

import numpy as np
import tensorflow as tf, tf_keras

FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32]

class LAMB(tf_keras.optimizers.legacy.Optimizer):
  """Optimizer that implements the Layer-wise Adaptive Moments (LAMB).

  See paper [Large Batch Optimization for Deep Learning: Training BERT
  in 76 minutes](https://arxiv.org/abs/1904.00962).
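
  Example (a minimal, illustrative sketch; the variable-name patterns passed
  to `exclude_from_weight_decay` are assumptions, not values defined by this
  module):

  >>> optimizer = LAMB(
  ...     learning_rate=1e-3,
  ...     weight_decay_rate=0.01,
  ...     exclude_from_weight_decay=["bias", "LayerNorm"])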
| """ | |
| def __init__( | |
| self, | |
| learning_rate: Union[FloatTensorLike, Callable] = 0.001, | |
| beta_1: FloatTensorLike = 0.9, | |
| beta_2: FloatTensorLike = 0.999, | |
| epsilon: FloatTensorLike = 1e-6, | |
| weight_decay_rate: FloatTensorLike = 0.0, | |
| exclude_from_weight_decay: Optional[List[str]] = None, | |
| exclude_from_layer_adaptation: Optional[List[str]] = None, | |
| name: str = "LAMB", | |
| **kwargs, | |
| ): | |
| """Construct a new LAMB optimizer. | |
| Args: | |
| learning_rate: A `Tensor` or a floating point value. or a schedule that | |
| is a `tf_keras.optimizers.schedules.LearningRateSchedule` The learning | |
| rate. | |
| beta_1: A `float` value or a constant `float` tensor. The exponential | |
| decay rate for the 1st moment estimates. | |
| beta_2: A `float` value or a constant `float` tensor. The exponential | |
| decay rate for the 2nd moment estimates. | |
| epsilon: A small constant for numerical stability. | |
| weight_decay_rate: weight decay rate. | |
| exclude_from_weight_decay: List of regex patterns of variables excluded | |
| from weight decay. Variables whose name contain a substring matching | |
| the pattern will be excluded. | |
| exclude_from_layer_adaptation: List of regex patterns of variables | |
| excluded from layer adaptation. Variables whose name contain a | |
| substring matching the pattern will be excluded. | |
| name: Optional name for the operations created when applying gradients. | |
| Defaults to "LAMB". | |
| **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, | |
| `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is | |
| clip gradients by value, `decay` is included for backward | |
| compatibility to allow time inverse decay of learning rate. `lr` is | |
| included for backward compatibility, recommended to use | |
| `learning_rate` instead. | |
| """ | |
    super().__init__(name, **kwargs)

    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters.
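    # (Here the decay is applied directly to the variable inside
    # `_resource_apply_dense`/`_resource_apply_sparse` below, scaled by
    # `weight_decay_rate`, rather than through the loss.)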
| self._set_hyper("weight_decay_rate", weight_decay_rate) | |
| self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) | |
| # This is learning rate decay for using keras learning rate schedule. | |
| self._set_hyper("decay", self._initial_decay) | |
| self._set_hyper("beta_1", beta_1) | |
| self._set_hyper("beta_2", beta_2) | |
    self.epsilon = epsilon or tf_keras.backend.epsilon()
    self.exclude_from_weight_decay = exclude_from_weight_decay
    # exclude_from_layer_adaptation is set to exclude_from_weight_decay if
    # the arg is None.
    if exclude_from_layer_adaptation:
      self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
    else:
      self.exclude_from_layer_adaptation = exclude_from_weight_decay

  def _create_slots(self, var_list):
    # Create slots for the first and second moments.
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      self.add_slot(var, "m")
    for var in var_list:
      self.add_slot(var, "v")

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super()._prepare_local(var_device, var_dtype, apply_state)

    local_step = tf.cast(self.iterations + 1, var_dtype)
    beta_1_t = tf.identity(self._get_hyper("beta_1", var_dtype))
    beta_2_t = tf.identity(self._get_hyper("beta_2", var_dtype))
    weight_decay_rate = tf.identity(
        self._get_hyper("weight_decay_rate", var_dtype)
    )
    beta_1_power = tf.pow(beta_1_t, local_step)
    beta_2_power = tf.pow(beta_2_t, local_step)
    apply_state[(var_device, var_dtype)].update(
        dict(
            weight_decay_rate=weight_decay_rate,
            epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
            beta_1_t=beta_1_t,
            beta_1_power=beta_1_power,
            one_minus_beta_1_t=1 - beta_1_t,
            beta_2_t=beta_2_t,
            beta_2_power=beta_2_power,
            one_minus_beta_2_t=1 - beta_2_t,
        )
    )
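
  # The dense and sparse apply routines below implement the LAMB update
  # (notation follows the code):
  #   m_t = beta_1 * m + (1 - beta_1) * g
  #   v_t = beta_2 * v + (1 - beta_2) * g^2
  #   update = (m_t / (1 - beta_1^t)) / (sqrt(v_t / (1 - beta_2^t)) + epsilon)
  #            [+ weight_decay_rate * var, when weight decay applies]
  #   ratio = ||var|| / ||update||  (1.0 if either norm is zero, or when the
  #           variable is excluded from layer adaptation)
  #   var  <- var - ratio * lr * update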
  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = (apply_state or {}).get(
        (var_device, var_dtype)
    ) or self._fallback_apply_state(var_device, var_dtype)

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
    m_t = m * coefficients["beta_1_t"] + m_scaled_g_values
    m_t = m.assign(m_t, use_locking=self._use_locking)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
    v_t = v * coefficients["beta_2_t"] + v_scaled_g_values
    v_t = v.assign(v_t, use_locking=self._use_locking)

    m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
    v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])

    v_sqrt = tf.sqrt(v_t_hat)
    update = m_t_hat / (v_sqrt + coefficients["epsilon"])

    var_name = self._get_variable_name(var.name)
    if self._do_use_weight_decay(var_name):
      update += coefficients["weight_decay_rate"] * var
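    # Layer-wise trust ratio: scale the step by ||var|| / ||update||, falling
    # back to 1.0 when either norm is zero or the variable is excluded from
    # layer adaptation.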
    ratio = 1.0
    if self._do_layer_adaptation(var_name):
      w_norm = tf.norm(var, ord=2)
      g_norm = tf.norm(update, ord=2)
      ratio = tf.where(
          tf.greater(w_norm, 0),
          tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
          1.0,
      )
    var_update = var - ratio * coefficients["lr_t"] * update
    return var.assign(var_update, use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = (apply_state or {}).get(
        (var_device, var_dtype)
    ) or self._fallback_apply_state(var_device, var_dtype)
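    # Sparse path: first decay the full moment slots, then scatter-add the new
    # gradient contribution only at the updated indices; the control
    # dependencies keep the two steps ordered.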
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
    m_t = m.assign(m * coefficients["beta_1_t"], use_locking=self._use_locking)
    with tf.control_dependencies([m_t]):
      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
    v_t = v.assign(v * coefficients["beta_2_t"], use_locking=self._use_locking)
    with tf.control_dependencies([v_t]):
      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

    m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
    v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])

    v_sqrt = tf.sqrt(v_t_hat)
    update = m_t_hat / (v_sqrt + coefficients["epsilon"])

    var_name = self._get_variable_name(var.name)
    if self._do_use_weight_decay(var_name):
      update += coefficients["weight_decay_rate"] * var

    ratio = 1.0
    if self._do_layer_adaptation(var_name):
      w_norm = tf.norm(var, ord=2)
      g_norm = tf.norm(update, ord=2)
      ratio = tf.where(
          tf.greater(w_norm, 0),
          tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
          1.0,
      )
    var_update = var.assign_sub(
        ratio * coefficients["lr_t"] * update, use_locking=self._use_locking
    )
    return tf.group(*[var_update, m_t, v_t])

  def get_config(self):
    config = super().get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "weight_decay_rate": self._serialize_hyperparameter(
            "weight_decay_rate"
        ),
        "decay": self._serialize_hyperparameter("decay"),
        "beta_1": self._serialize_hyperparameter("beta_1"),
        "beta_2": self._serialize_hyperparameter("beta_2"),
        "epsilon": self.epsilon,
    })
    return config

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    if self.exclude_from_layer_adaptation:
      for r in self.exclude_from_layer_adaptation:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _get_variable_name(self, param_name):
    """Get the variable name from the tensor name."""
    m = re.match("^(.*):\\d+$", param_name)
    if m is not None:
      param_name = m.group(1)
    return param_name