# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized optimizer to match paper results."""

import dataclasses
from typing import List, Optional
from absl import logging
import tensorflow as tf, tf_keras
from official.modeling import optimization
from official.nlp import optimization as nlp_optimization


@dataclasses.dataclass
class ViTAdamWConfig(optimization.AdamWeightDecayConfig):
layer_decay: Optional[float] = 1.0
vars_substr: Optional[List[str]] = None
layers_idx: Optional[List[int]] = None
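  # Note: `vars_substr` and `layers_idx` are parallel lists. A variable whose
  # name contains vars_substr[i] is assigned layer index layers_idx[i] when
  # the layer-wise lr decay factor is computed in `_ViTAdamW` below.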


@dataclasses.dataclass
class OptimizerConfig(optimization.OptimizerConfig):
vit_adamw: ViTAdamWConfig = dataclasses.field(default_factory=ViTAdamWConfig)


@dataclasses.dataclass
class OptimizationConfig(optimization.OptimizationConfig):
"""Configuration for optimizer and learning rate schedule.
Attributes:
optimizer: optimizer oneof config.
    ema: optional exponential moving average optimizer config. If specified,
      the EMA optimizer will be used.
learning_rate: learning rate oneof config.
warmup: warmup oneof config.
"""
optimizer: OptimizerConfig = dataclasses.field(
default_factory=OptimizerConfig
)
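
# Illustrative sketch (not in the original module): because `_ViTAdamW` is
# registered under the name 'vit_adamw' at the bottom of this file, a config
# can select it by type. The override values here are hypothetical.
#
#   config = OptimizationConfig({
#       'optimizer': {
#           'type': 'vit_adamw',
#           'vit_adamw': {
#               'weight_decay_rate': 0.05,
#               'layer_decay': 0.75,
#           },
#       },
#       'learning_rate': {'type': 'constant',
#                         'constant': {'learning_rate': 1e-3}},
#   })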


# TODO(frederickliu): figure out how to make this configurable.
# TODO(frederickliu): Study if this is needed.
class _ViTAdamW(nlp_optimization.AdamWeightDecay):
"""Custom AdamW to support different lr scaling for backbone.
The code is copied from AdamWeightDecay and Adam with learning scaling.
"""
def __init__(self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
amsgrad=False,
weight_decay_rate=0.0,
include_in_weight_decay=None,
exclude_from_weight_decay=None,
gradient_clip_norm=1.0,
layer_decay=1.0,
vars_substr=None,
layers_idx=None,
name='ViTAdamWeightDecay',
**kwargs):
super(_ViTAdamW,
self).__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad,
weight_decay_rate, include_in_weight_decay,
exclude_from_weight_decay, gradient_clip_norm, name,
**kwargs)
self._layer_decay = layer_decay
self._vars_substr = vars_substr
self._layers_idx = layers_idx
self._max_idx = max(layers_idx) + 1 if layers_idx is not None else 1
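    # Worked example of the scaling scheme (illustrative numbers): with
    # layer_decay=0.9 and layers_idx covering 0..11, _max_idx is 12, so a
    # variable at layer idx is scaled by 0.9**(12 - idx); the deepest layer
    # (idx=11) gets one factor of 0.9 and the shallowest (idx=0) gets 0.9**12.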

  def _resource_apply_dense(self, grad, var, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
apply_state = kwargs['apply_state']
if (
self._layer_decay != 1.0
and self._vars_substr is not None
and self._layers_idx is not None
):
is_decayed = False
for var_substr, idx in zip(self._vars_substr, self._layers_idx):
if var_substr in var.name:
decay_factor = self._layer_decay ** (self._max_idx - idx)
lr_t = lr_t * decay_factor
is_decayed = True
logging.debug(
'Applying layer-wise lr decay: %s: %f', var.name, decay_factor)
break
if not is_decayed:
logging.debug('Ignore layer-wise lr decay: %s', var.name)
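    # Decoupled weight decay: shrink the variable with the (possibly
    # layer-scaled) lr_t first, then run the Adam update under a control
    # dependency so the two steps are ordered.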
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype))
or self._fallback_apply_state(var_device, var_dtype))
m = self.get_slot(var, 'm')
v = self.get_slot(var, 'v')
lr = coefficients['lr_t']
if (
self._layer_decay != 1.0
and self._vars_substr is not None
and self._layers_idx is not None
):
for var_substr, idx in zip(self._vars_substr, self._layers_idx):
if var_substr in var.name:
lr = lr * (self._layer_decay ** (self._max_idx - idx))
break
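      # Delegate the Adam step to the fused kernels, passing the layer-scaled
      # lr explicitly so it overrides the lr baked into `coefficients`.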
if not self.amsgrad:
return tf.raw_ops.ResourceApplyAdam(
var=var.handle,
m=m.handle,
v=v.handle,
beta1_power=coefficients['beta_1_power'],
beta2_power=coefficients['beta_2_power'],
lr=lr,
beta1=coefficients['beta_1_t'],
beta2=coefficients['beta_2_t'],
epsilon=coefficients['epsilon'],
grad=grad,
use_locking=self._use_locking)
else:
vhat = self.get_slot(var, 'vhat')
return tf.raw_ops.ResourceApplyAdamWithAmsgrad(
var=var.handle,
m=m.handle,
v=v.handle,
vhat=vhat.handle,
beta1_power=coefficients['beta_1_power'],
beta2_power=coefficients['beta_2_power'],
lr=lr,
beta1=coefficients['beta_1_t'],
beta2=coefficients['beta_2_t'],
epsilon=coefficients['epsilon'],
grad=grad,
use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
apply_state = kwargs['apply_state']
if (
self._layer_decay != 1.0
and self._vars_substr is not None
and self._layers_idx is not None
):
is_decayed = False
for var_substr, idx in zip(self._vars_substr, self._layers_idx):
if var_substr in var.name:
decay_factor = self._layer_decay ** (self._max_idx - idx)
lr_t = lr_t * decay_factor
is_decayed = True
logging.debug(
'Applying layer-wise lr decay: %s: %f', var.name, decay_factor)
break
if not is_decayed:
logging.debug('Ignore layer-wise lr decay: %s', var.name)
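    # As in the dense path: apply decoupled weight decay first, then run the
    # Adam update under a control dependency.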
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype))
or self._fallback_apply_state(var_device, var_dtype))
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, 'm')
m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
m_t = tf.compat.v1.assign(m, m * coefficients['beta_1_t'],
use_locking=self._use_locking)
with tf.control_dependencies([m_t]):
m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, 'v')
v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
v_t = tf.compat.v1.assign(v, v * coefficients['beta_2_t'],
use_locking=self._use_locking)
with tf.control_dependencies([v_t]):
v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
lr = coefficients['lr_t']
if (
self._layer_decay != 1.0
and self._vars_substr is not None
and self._layers_idx is not None
):
for var_substr, idx in zip(self._vars_substr, self._layers_idx):
if var_substr in var.name:
lr = lr * (self._layer_decay ** (self._max_idx - idx))
break
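      # Finish the update manually: var -= lr * m_t / (sqrt(v_t) + epsilon),
      # using the elementwise max accumulator `vhat` when amsgrad is enabled.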
if not self.amsgrad:
v_sqrt = tf.sqrt(v_t)
var_update = tf.compat.v1.assign_sub(
var, lr * m_t / (v_sqrt + coefficients['epsilon']),
use_locking=self._use_locking)
return tf.group(*[var_update, m_t, v_t])
else:
v_hat = self.get_slot(var, 'vhat')
v_hat_t = tf.maximum(v_hat, v_t)
with tf.control_dependencies([v_hat_t]):
v_hat_t = tf.compat.v1.assign(
v_hat, v_hat_t, use_locking=self._use_locking)
v_hat_sqrt = tf.sqrt(v_hat_t)
var_update = tf.compat.v1.assign_sub(
var,
          lr * m_t / (v_hat_sqrt + coefficients['epsilon']),
use_locking=self._use_locking)
return tf.group(*[var_update, m_t, v_t, v_hat_t])


optimization.register_optimizer_cls('vit_adamw', _ViTAdamW)
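
# Minimal construction sketch (illustrative, not part of the original module).
# The substring/index pairs are hypothetical; in practice they mirror the
# backbone's variable names, with deeper layers given larger indices:
#
#   optimizer = _ViTAdamW(
#       learning_rate=1e-3,
#       weight_decay_rate=0.05,
#       layer_decay=0.75,
#       vars_substr=['encoder/layer_0/', 'encoder/layer_1/'],
#       layers_idx=[0, 1],
#   )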