| """Layer-wise Adaptive Moments (LAMB) optimizer.
|
|
|
| See paper [Large Batch Optimization for Deep Learning: Training BERT in
|
| 76 minutes](https://arxiv.org/abs/1904.00962).
|
| """

import re
from typing import Optional, Union, Callable, List

import numpy as np
import tensorflow as tf
import tf_keras

FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32]


@tf_keras.utils.register_keras_serializable(package="Addons")
class LAMB(tf_keras.optimizers.legacy.Optimizer):
    """Optimizer that implements the Layer-wise Adaptive Moments (LAMB).

    See paper [Large Batch Optimization for Deep Learning: Training BERT
    in 76 minutes](https://arxiv.org/abs/1904.00962).
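
    A minimal usage sketch (here `model` stands in for any `tf_keras` model; the
    exclusion entries are regex patterns matched against variable names):

        optimizer = LAMB(
            learning_rate=1e-3,
            weight_decay_rate=0.01,
            exclude_from_weight_decay=["bias", "LayerNorm"])
        model.compile(optimizer=optimizer, loss="mse")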
    """

    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable] = 0.001,
        beta_1: FloatTensorLike = 0.9,
        beta_2: FloatTensorLike = 0.999,
        epsilon: FloatTensorLike = 1e-6,
        weight_decay_rate: FloatTensorLike = 0.0,
        exclude_from_weight_decay: Optional[List[str]] = None,
        exclude_from_layer_adaptation: Optional[List[str]] = None,
        name: str = "LAMB",
        **kwargs,
    ):
        """Construct a new LAMB optimizer.

        Args:
            learning_rate: A `Tensor`, a floating point value, or a schedule that
                is a `tf_keras.optimizers.schedules.LearningRateSchedule`. The
                learning rate.
            beta_1: A `float` value or a constant `float` tensor. The exponential
                decay rate for the 1st moment estimates.
            beta_2: A `float` value or a constant `float` tensor. The exponential
                decay rate for the 2nd moment estimates.
            epsilon: A small constant for numerical stability.
            weight_decay_rate: Weight decay rate.
            exclude_from_weight_decay: List of regex patterns of variables excluded
                from weight decay. Variables whose names contain a substring
                matching the pattern will be excluded.
            exclude_from_layer_adaptation: List of regex patterns of variables
                excluded from layer adaptation. Variables whose names contain a
                substring matching the pattern will be excluded.
            name: Optional name for the operations created when applying
                gradients. Defaults to "LAMB".
            **kwargs: Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
                `lr`, `decay`}. `clipnorm` clips gradients by norm; `clipvalue`
                clips gradients by value; `decay` is included for backward
                compatibility to allow time inverse decay of learning rate; `lr`
                is included for backward compatibility, and it is recommended to
                use `learning_rate` instead.
        """
        super().__init__(name, **kwargs)

        # Weight decay is applied directly to the parameter update in
        # `_resource_apply_dense`/`_resource_apply_sparse` rather than being
        # folded into the loss, so it does not interact with the m/v moments.
        self._set_hyper("weight_decay_rate", weight_decay_rate)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))

        # Learning rate decay used by Keras learning rate schedules.
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("beta_1", beta_1)
        self._set_hyper("beta_2", beta_2)
        self.epsilon = epsilon or tf_keras.backend.epsilon()
        self.exclude_from_weight_decay = exclude_from_weight_decay
        # Layer adaptation falls back to the weight decay exclusion list when no
        # explicit list is given.
        if exclude_from_layer_adaptation:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def _create_slots(self, var_list):
        # Create slots for the first and second moments; all "m" slots are added
        # before any "v" slot.
        for var in var_list:
            self.add_slot(var, "m")
        for var in var_list:
            self.add_slot(var, "v")

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super()._prepare_local(var_device, var_dtype, apply_state)

        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_t = tf.identity(self._get_hyper("beta_1", var_dtype))
        beta_2_t = tf.identity(self._get_hyper("beta_2", var_dtype))
        weight_decay_rate = tf.identity(
            self._get_hyper("weight_decay_rate", var_dtype)
        )
        beta_1_power = tf.pow(beta_1_t, local_step)
        beta_2_power = tf.pow(beta_2_t, local_step)
        apply_state[(var_device, var_dtype)].update(
            dict(
                weight_decay_rate=weight_decay_rate,
                epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
                beta_1_t=beta_1_t,
                beta_1_power=beta_1_power,
                one_minus_beta_1_t=1 - beta_1_t,
                beta_2_t=beta_2_t,
                beta_2_power=beta_2_power,
                one_minus_beta_2_t=1 - beta_2_t,
            )
        )

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        # m_t = beta_1 * m + (1 - beta_1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
        m_t = m * coefficients["beta_1_t"] + m_scaled_g_values
        m_t = m.assign(m_t, use_locking=self._use_locking)

        # v_t = beta_2 * v + (1 - beta_2) * g_t^2
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
        v_t = v * coefficients["beta_2_t"] + v_scaled_g_values
        v_t = v.assign(v_t, use_locking=self._use_locking)

        # Bias-corrected moment estimates.
        m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
        v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])

        v_sqrt = tf.sqrt(v_t_hat)
        update = m_t_hat / (v_sqrt + coefficients["epsilon"])

        # Decoupled weight decay, applied directly to the update.
        var_name = self._get_variable_name(var.name)
        if self._do_use_weight_decay(var_name):
            update += coefficients["weight_decay_rate"] * var

        # Layer-wise trust ratio ||w|| / ||update||, falling back to 1 when
        # either norm is zero.
        ratio = 1.0
        if self._do_layer_adaptation(var_name):
            w_norm = tf.norm(var, ord=2)
            g_norm = tf.norm(update, ord=2)
            ratio = tf.where(
                tf.greater(w_norm, 0),
                tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
                1.0,
            )

        var_update = var - ratio * coefficients["lr_t"] * update
        return var.assign(var_update, use_locking=self._use_locking)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        # m_t = beta_1 * m + (1 - beta_1) * g_t, updated only at `indices`.
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
        m_t = m.assign(m * coefficients["beta_1_t"], use_locking=self._use_locking)
        with tf.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

        # v_t = beta_2 * v + (1 - beta_2) * g_t^2, updated only at `indices`.
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
        v_t = v.assign(v * coefficients["beta_2_t"], use_locking=self._use_locking)
        with tf.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

        # Bias-corrected moment estimates.
        m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
        v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])

        v_sqrt = tf.sqrt(v_t_hat)
        update = m_t_hat / (v_sqrt + coefficients["epsilon"])

        # Decoupled weight decay, applied directly to the update.
        var_name = self._get_variable_name(var.name)
        if self._do_use_weight_decay(var_name):
            update += coefficients["weight_decay_rate"] * var

        # Layer-wise trust ratio, as in the dense case.
        ratio = 1.0
        if self._do_layer_adaptation(var_name):
            w_norm = tf.norm(var, ord=2)
            g_norm = tf.norm(update, ord=2)
            ratio = tf.where(
                tf.greater(w_norm, 0),
                tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
                1.0,
            )

        var_update = var.assign_sub(
            ratio * coefficients["lr_t"] * update, use_locking=self._use_locking
        )
        return tf.group(*[var_update, m_t, v_t])

    def get_config(self):
        config = super().get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "weight_decay_rate": self._serialize_hyperparameter(
                "weight_decay_rate"
            ),
            "decay": self._serialize_hyperparameter("decay"),
            "beta_1": self._serialize_hyperparameter("beta_1"),
            "beta_2": self._serialize_hyperparameter("beta_2"),
            "epsilon": self.epsilon,
        })
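        # Note (an observation about the config built above): the exclusion lists
        # `exclude_from_weight_decay` and `exclude_from_layer_adaptation` are not
        # serialized, so they must be passed again when reconstructing the
        # optimizer from this config.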
        return config

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _do_layer_adaptation(self, param_name):
        """Whether to do layer-wise learning rate adaptation for `param_name`."""
        if self.exclude_from_layer_adaptation:
            for r in self.exclude_from_layer_adaptation:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _get_variable_name(self, param_name):
        """Get the variable name from the tensor name (e.g. `var:0` -> `var`)."""
        m = re.match("^(.*):\\d+$", param_name)
        if m is not None:
            param_name = m.group(1)
        return param_name
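

# Quick smoke-test sketch (illustrative only; assumes eager TF2 and the imports
# above, and is not part of the module's public surface):
#
#   var = tf.Variable([1.0, 2.0, 3.0])
#   opt = LAMB(learning_rate=0.1, weight_decay_rate=0.01)
#   with tf.GradientTape() as tape:
#       loss = tf.reduce_sum(tf.square(var))
#   grads = tape.gradient(loss, [var])
#   opt.apply_gradients(zip(grads, [var]))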