|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Customized optimizer to match paper results."""
|
|
|
| import dataclasses
|
| from typing import List, Optional
|
|
|
| from absl import logging
|
|
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.modeling import optimization
|
| from official.nlp import optimization as nlp_optimization
|
|
|
|
|
@dataclasses.dataclass
class ViTAdamWConfig(optimization.AdamWeightDecayConfig):
  """AdamWeightDecay config extended with layer-wise learning-rate decay.

  These fields are presumably forwarded to the optimizer registered under the
  'vit_adamw' key in this module; confirm against the optimizer factory.

  Attributes:
    layer_decay: Base of the per-layer learning-rate multiplier. The value
      1.0 means no layer-wise scaling.
    vars_substr: Variable-name substrings, paired positionally with the
      entries of `layers_idx`.
    layers_idx: Layer index assigned to variables whose name contains the
      corresponding `vars_substr` entry.
  """
  layer_decay: Optional[float] = dataclasses.field(default=1.0)
  vars_substr: Optional[List[str]] = dataclasses.field(default=None)
  layers_idx: Optional[List[int]] = dataclasses.field(default=None)
|
|
|
|
|
@dataclasses.dataclass
class OptimizerConfig(optimization.OptimizerConfig):
  """Optimizer oneof config that adds a `vit_adamw` choice.

  Attributes:
    vit_adamw: Hyperparameters for the AdamW variant with layer-wise
      learning-rate decay.
  """
  vit_adamw: ViTAdamWConfig = dataclasses.field(
      default_factory=ViTAdamWConfig,
  )
|
|
|
|
|
@dataclasses.dataclass
class OptimizationConfig(optimization.OptimizationConfig):
  """Optimizer and learning-rate schedule configuration.

  Same as the base `optimization.OptimizationConfig`, except that `optimizer`
  is this module's `OptimizerConfig`, which also offers the `vit_adamw`
  option.

  Attributes:
    optimizer: Oneof config selecting the optimizer.
    ema: Optional exponential-moving-average optimizer config; if specified,
      the EMA optimizer is used.
    learning_rate: Oneof config selecting the learning-rate schedule.
    warmup: Oneof config selecting the warmup schedule.
  """
  optimizer: OptimizerConfig = dataclasses.field(default_factory=OptimizerConfig)
|
|
|
|
|
|
|
|
|
class _ViTAdamW(nlp_optimization.AdamWeightDecay):
  """Custom AdamW with layer-wise learning-rate scaling for the backbone.

  The update code is adapted from `AdamWeightDecay` and Keras `Adam`. The
  scaling rule is:

  - For a variable whose name contains `vars_substr[i]` (first match wins),
    the learning rate is multiplied by
    `layer_decay ** (max(layers_idx) + 1 - layers_idx[i])`.
  - The scaled rate is used both for the weight-decay op
    (`_decay_weights_op`) and for the Adam step.
  - Variables that match no substring keep the unscaled rate.
  - All variables keep the unscaled rate when `layer_decay == 1.0` or
    `vars_substr` / `layers_idx` is None.
  """

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               amsgrad=False,
               weight_decay_rate=0.0,
               include_in_weight_decay=None,
               exclude_from_weight_decay=None,
               gradient_clip_norm=1.0,
               layer_decay=1.0,
               vars_substr=None,
               layers_idx=None,
               name='ViTAdamWeightDecay',
               **kwargs):
    """Creates the optimizer.

    Args:
      learning_rate: Passed through to `AdamWeightDecay`.
      beta_1: Passed through to `AdamWeightDecay`.
      beta_2: Passed through to `AdamWeightDecay`.
      epsilon: Passed through to `AdamWeightDecay`.
      amsgrad: Passed through to `AdamWeightDecay`.
      weight_decay_rate: Passed through to `AdamWeightDecay`.
      include_in_weight_decay: Passed through to `AdamWeightDecay`.
      exclude_from_weight_decay: Passed through to `AdamWeightDecay`.
      gradient_clip_norm: Passed through to `AdamWeightDecay`.
      layer_decay: Base of the layer-wise learning-rate multiplier. The value
        1.0 disables layer-wise scaling.
      vars_substr: Variable-name substrings, paired positionally with
        `layers_idx`.
      layers_idx: Layer index for each `vars_substr` entry. Smaller indices
        get a larger exponent `max(layers_idx) + 1 - idx`.
      name: Optimizer name.
      **kwargs: Forwarded to `AdamWeightDecay`.
    """
    super(_ViTAdamW, self).__init__(
        learning_rate, beta_1, beta_2, epsilon, amsgrad, weight_decay_rate,
        include_in_weight_decay, exclude_from_weight_decay,
        gradient_clip_norm, name, **kwargs)
    self._layer_decay = layer_decay
    self._vars_substr = vars_substr
    self._layers_idx = layers_idx
    # The exponent for a layer is `_max_idx - idx`, so even the highest listed
    # layer is scaled by one factor of `layer_decay`. `default=0` keeps an
    # empty `layers_idx` from raising. With no entries, no variable can match,
    # so the value is unused in that case.
    self._max_idx = (max(layers_idx, default=0) + 1
                     if layers_idx is not None else 1)

  def get_config(self):
    """Returns the optimizer config, including the layer-wise decay settings.

    Without these keys, `from_config(get_config())` would silently fall back
    to `layer_decay=1.0`.
    """
    config = super(_ViTAdamW, self).get_config()
    config.update({
        'layer_decay': self._layer_decay,
        'vars_substr': self._vars_substr,
        'layers_idx': self._layers_idx,
    })
    return config

  def _layer_wise_lr_factor(self, var_name):
    """Returns the layer-wise learning-rate multiplier for `var_name`.

    Args:
      var_name: Name of the variable being updated.

    Returns:
      `layer_decay ** (_max_idx - idx)` for the first `vars_substr` entry
      contained in `var_name`. Returns None when layer-wise decay is disabled
      or nothing matches; callers then leave the learning rate untouched.
    """
    if (self._layer_decay == 1.0 or self._vars_substr is None or
        self._layers_idx is None):
      return None
    for var_substr, idx in zip(self._vars_substr, self._layers_idx):
      if var_substr in var_name:
        factor = self._layer_decay ** (self._max_idx - idx)
        logging.debug(
            'Applying layer-wise lr decay: %s: %f', var_name, factor)
        return factor
    logging.debug('Ignore layer-wise lr decay: %s', var_name)
    return None

  def _resource_apply_dense(self, grad, var, apply_state=None):
    """Applies weight decay, then a dense Adam / AMSGrad update to `var`."""
    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
    apply_state = kwargs['apply_state']
    # Computed once and reused for both the weight-decay and the Adam rate.
    lr_factor = self._layer_wise_lr_factor(var.name)
    if lr_factor is not None:
      lr_t = lr_t * lr_factor
    decay = self._decay_weights_op(var, lr_t, apply_state)
    with tf.control_dependencies([decay]):
      var_device, var_dtype = var.device, var.dtype.base_dtype
      coefficients = ((apply_state or {}).get((var_device, var_dtype))
                      or self._fallback_apply_state(var_device, var_dtype))

      m = self.get_slot(var, 'm')
      v = self.get_slot(var, 'v')
      # The fused kernels bias-correct internally from beta1_power and
      # beta2_power, so they take the raw, uncorrected learning rate.
      lr = coefficients['lr_t']
      if lr_factor is not None:
        lr = lr * lr_factor

      if not self.amsgrad:
        return tf.raw_ops.ResourceApplyAdam(
            var=var.handle,
            m=m.handle,
            v=v.handle,
            beta1_power=coefficients['beta_1_power'],
            beta2_power=coefficients['beta_2_power'],
            lr=lr,
            beta1=coefficients['beta_1_t'],
            beta2=coefficients['beta_2_t'],
            epsilon=coefficients['epsilon'],
            grad=grad,
            use_locking=self._use_locking)
      vhat = self.get_slot(var, 'vhat')
      return tf.raw_ops.ResourceApplyAdamWithAmsgrad(
          var=var.handle,
          m=m.handle,
          v=v.handle,
          vhat=vhat.handle,
          beta1_power=coefficients['beta_1_power'],
          beta2_power=coefficients['beta_2_power'],
          lr=lr,
          beta1=coefficients['beta_1_t'],
          beta2=coefficients['beta_2_t'],
          epsilon=coefficients['epsilon'],
          grad=grad,
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Applies weight decay, then a sparse Adam / AMSGrad update to `var`."""
    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
    apply_state = kwargs['apply_state']
    # Computed once and reused for both the weight-decay and the Adam rate.
    lr_factor = self._layer_wise_lr_factor(var.name)
    if lr_factor is not None:
      lr_t = lr_t * lr_factor
    decay = self._decay_weights_op(var, lr_t, apply_state)
    with tf.control_dependencies([decay]):
      var_device, var_dtype = var.device, var.dtype.base_dtype
      coefficients = ((apply_state or {}).get((var_device, var_dtype))
                      or self._fallback_apply_state(var_device, var_dtype))

      # m_t = beta1 * m + (1 - beta1) * g_t, restricted to `indices`.
      m = self.get_slot(var, 'm')
      m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
      m_t = tf.compat.v1.assign(m, m * coefficients['beta_1_t'],
                                use_locking=self._use_locking)
      with tf.control_dependencies([m_t]):
        m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

      # v_t = beta2 * v + (1 - beta2) * g_t^2, restricted to `indices`.
      v = self.get_slot(var, 'v')
      v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
      v_t = tf.compat.v1.assign(v, v * coefficients['beta_2_t'],
                                use_locking=self._use_locking)
      with tf.control_dependencies([v_t]):
        v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

      # This hand-written update does not bias-correct m_t / v_t. It therefore
      # needs the bias-corrected rate `lr` that Keras Adam's `_prepare_local`
      # stores (lr_t * sqrt(1 - beta_2^t) / (1 - beta_1^t)), exactly as Keras
      # Adam's own sparse update does. Using the raw `lr_t` here would make
      # sparse steps disagree with the dense path.
      lr = coefficients['lr']
      if lr_factor is not None:
        lr = lr * lr_factor

      if not self.amsgrad:
        v_sqrt = tf.sqrt(v_t)
        var_update = tf.compat.v1.assign_sub(
            var, lr * m_t / (v_sqrt + coefficients['epsilon']),
            use_locking=self._use_locking)
        return tf.group(*[var_update, m_t, v_t])

      v_hat = self.get_slot(var, 'vhat')
      v_hat_t = tf.maximum(v_hat, v_t)
      with tf.control_dependencies([v_hat_t]):
        v_hat_t = tf.compat.v1.assign(
            v_hat, v_hat_t, use_locking=self._use_locking)
      v_hat_sqrt = tf.sqrt(v_hat_t)
      var_update = tf.compat.v1.assign_sub(
          var,
          lr * m_t / (v_hat_sqrt + coefficients['epsilon']),
          use_locking=self._use_locking)
      return tf.group(*[var_update, m_t, v_t, v_hat_t])
|
|
|
# Register `_ViTAdamW` under the 'vit_adamw' key. Presumably this key has to
# match the `vit_adamw` field of `OptimizerConfig` so the optimizer factory can
# resolve it; confirm against `optimization.register_optimizer_cls`.
optimization.register_optimizer_cls('vit_adamw', _ViTAdamW)
|
|
|