| """VisionTransformer models."""
|
|
|
| import math
|
| from typing import Optional, Tuple
|
|
|
| from absl import logging
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.modeling import activations
|
| from official.vision.modeling.backbones import factory
|
| from official.vision.modeling.backbones.vit_specs import VIT_SPECS
|
| from official.vision.modeling.layers import nn_blocks
|
| from official.vision.modeling.layers import nn_layers
|
|
|
|
|
| layers = tf_keras.layers
|
|
|
|
|


class AddPositionEmbs(layers.Layer):
  """Adds (optionally learned) positional embeddings to the inputs."""

  def __init__(self,
               posemb_init: Optional[tf_keras.initializers.Initializer] = None,
               posemb_origin_shape: Optional[Tuple[int, int]] = None,
               posemb_target_shape: Optional[Tuple[int, int]] = None,
               **kwargs):
    """Constructs the positional embedding module.

    The length of the learnable positional embeddings is fixed when the layer
    is built, from `posemb_origin_shape` if provided, otherwise from the input
    shape. If `posemb_target_shape` is provided and implies a different
    length, the embeddings are interpolated to that shape during the forward
    call.

    Args:
      posemb_init: The positional embedding initializer.
      posemb_origin_shape: The intended positional embedding shape.
      posemb_target_shape: The potential target shape positional embedding may
        be interpolated to.
      **kwargs: other args.
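
    Example (the shapes below are purely illustrative):
      # Minimal sketch: the embedding table is built for a 14x14 token grid
      # and, because the incoming sequence is longer, is resized to 16x16
      # inside call() via tf.image.resize (bilinear by default).
      layer = AddPositionEmbs(
          posemb_init=tf_keras.initializers.RandomNormal(stddev=0.02),
          posemb_origin_shape=(14, 14),
          posemb_target_shape=(16, 16))
      outputs = layer(tf.zeros([1, 16 * 16, 768]))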
| """
|
| super().__init__(**kwargs)
|
| self.posemb_init = posemb_init
|
| self.posemb_origin_shape = posemb_origin_shape
|
| self.posemb_target_shape = posemb_target_shape
|
|
|
| def build(self, inputs_shape):
|
| if self.posemb_origin_shape is not None:
|
| pos_emb_length = self.posemb_origin_shape[0] * self.posemb_origin_shape[1]
|
| else:
|
| pos_emb_length = inputs_shape[1]
|
| pos_emb_shape = (1, pos_emb_length, inputs_shape[2])
|
| self.pos_embedding = self.add_weight(
|
| 'pos_embedding', pos_emb_shape, initializer=self.posemb_init)
|
|
|
| def _interpolate(self, pos_embedding: tf.Tensor, from_shape: Tuple[int, int],
|
| to_shape: Tuple[int, int]) -> tf.Tensor:
|
| """Interpolates the positional embeddings."""
|
| logging.info('Interpolating postional embedding from length: %s to %s',
|
| from_shape, to_shape)
|
| grid_emb = tf.reshape(pos_embedding, [1] + list(from_shape) + [-1])
|
|
|
| grid_emb = tf.image.resize(grid_emb, to_shape)
|
| return tf.reshape(grid_emb, [1, to_shape[0] * to_shape[1], -1])
|
|
|
| def call(self, inputs, inputs_positions=None):
|
| del inputs_positions
|
| pos_embedding = self.pos_embedding
|
|
|
| if inputs.shape[1] != pos_embedding.shape[1]:
|
| pos_embedding = self._interpolate(
|
| pos_embedding,
|
| from_shape=self.posemb_origin_shape,
|
| to_shape=self.posemb_target_shape)
|
| pos_embedding = tf.cast(pos_embedding, inputs.dtype)
|
|
|
| return inputs + pos_embedding
|
|
|
|
|


class TokenLayer(layers.Layer):
  """A simple layer to wrap token parameters."""

  def build(self, inputs_shape):
    self.cls = self.add_weight(
        'cls', (1, 1, inputs_shape[-1]), initializer='zeros')

  def call(self, inputs):
    cls = tf.cast(self.cls, inputs.dtype)
    # Broadcast the single class token across the batch dimension before
    # prepending it to the token sequence.
    cls = cls + tf.zeros_like(inputs[:, 0:1])
    x = tf.concat([cls, inputs], axis=1)
    return x


class Encoder(layers.Layer):
  """Transformer Encoder."""

  def __init__(self,
               num_layers,
               mlp_dim,
               num_heads,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               init_stochastic_depth_rate=0.0,
               kernel_initializer='glorot_uniform',
               add_pos_embed=True,
               pos_embed_origin_shape=None,
               pos_embed_target_shape=None,
               layer_scale_init_value=0.0,
               transformer_partition_dims=None,
               **kwargs):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._kernel_initializer = kernel_initializer
    self._add_pos_embed = add_pos_embed
    self._pos_embed_origin_shape = pos_embed_origin_shape
    self._pos_embed_target_shape = pos_embed_target_shape
    self._layer_scale_init_value = layer_scale_init_value
    self._transformer_partition_dims = transformer_partition_dims

  def build(self, input_shape):
    if self._add_pos_embed:
      self._pos_embed = AddPositionEmbs(
          posemb_init=tf_keras.initializers.RandomNormal(stddev=0.02),
          posemb_origin_shape=self._pos_embed_origin_shape,
          posemb_target_shape=self._pos_embed_target_shape,
          name='posembed_input')
    self._dropout = layers.Dropout(rate=self._dropout_rate)

    self._encoder_layers = []
    for i in range(self._num_layers):
      encoder_layer = nn_blocks.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 1, self._num_layers),
          norm_epsilon=1e-6,
          layer_scale_init_value=self._layer_scale_init_value,
          transformer_partition_dims=self._transformer_partition_dims)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
    super().build(input_shape)

  def call(self, inputs, training=None):
    x = inputs
    if self._add_pos_embed:
      x = self._pos_embed(x, inputs_positions=self._inputs_positions)
    x = self._dropout(x, training=training)

    for encoder_layer in self._encoder_layers:
      x = encoder_layer(x, training=training)
    x = self._norm(x)
    return x

  def get_config(self):
    config = super().get_config()
    updates = {
        'num_layers': self._num_layers,
        'mlp_dim': self._mlp_dim,
        'num_heads': self._num_heads,
        'dropout_rate': self._dropout_rate,
        'attention_dropout_rate': self._attention_dropout_rate,
        'kernel_regularizer': self._kernel_regularizer,
        'inputs_positions': self._inputs_positions,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'kernel_initializer': self._kernel_initializer,
        'add_pos_embed': self._add_pos_embed,
        'pos_embed_origin_shape': self._pos_embed_origin_shape,
        'pos_embed_target_shape': self._pos_embed_target_shape,
        'layer_scale_init_value': self._layer_scale_init_value,
        'transformer_partition_dims': self._transformer_partition_dims,
    }
    config.update(updates)
    return config


class VisionTransformer(tf_keras.Model):
  """Class to build VisionTransformer family model."""

  def __init__(
      self,
      mlp_dim=3072,
      num_heads=12,
      num_layers=12,
      attention_dropout_rate=0.0,
      dropout_rate=0.1,
      init_stochastic_depth_rate=0.0,
      input_specs=layers.InputSpec(shape=[None, None, None, 3]),
      patch_size=16,
      hidden_size=768,
      representation_size=0,
      pooler='token',
      kernel_regularizer=None,
      original_init: bool = True,
      output_encoded_tokens: bool = True,
      output_2d_feature_maps: bool = False,
      pos_embed_shape: Optional[Tuple[int, int]] = None,
      layer_scale_init_value: float = 0.0,
      transformer_partition_dims: Optional[Tuple[int, int, int, int]] = None,
  ):
    """VisionTransformer initialization function."""
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._num_layers = num_layers
    self._hidden_size = hidden_size
    self._patch_size = patch_size

    inputs = tf_keras.Input(shape=input_specs.shape[1:])

    # Patchify: a strided convolution turns the image into a grid of patch
    # embeddings of size `hidden_size`.
    x = layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
            inputs)
    if tf_keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
    else:
      rows_axis, cols_axis = (2, 3)
      # The reshape below assumes a 'channels_last' layout, so transpose to
      # it. Once the data is flattened by the reshape, the data format no
      # longer matters.
      x = tf.transpose(x, perm=[0, 2, 3, 1])

    pos_embed_target_shape = (x.shape[rows_axis], x.shape[cols_axis])
    feat_h = input_specs.shape[rows_axis] // patch_size
    feat_w = input_specs.shape[cols_axis] // patch_size
    seq_len = feat_h * feat_w
    x = tf.reshape(x, [-1, seq_len, hidden_size])

    if pooler == 'token':
      x = TokenLayer(name='cls')(x)

    x = Encoder(
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='glorot_uniform' if original_init else dict(
            class_name='TruncatedNormal', config=dict(stddev=.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate,
        pos_embed_origin_shape=pos_embed_shape,
        pos_embed_target_shape=pos_embed_target_shape,
        layer_scale_init_value=layer_scale_init_value)(
            x)

    if pooler == 'token':
      output_feature = x[:, 1:]
      x = x[:, 0]
    elif pooler == 'gap':
      output_feature = x
      x = tf.reduce_mean(x, axis=1)
    elif pooler == 'none':
      output_feature = x
      x = tf.identity(x, name='encoded_tokens')
    else:
      raise ValueError(f'unrecognized pooler type: {pooler}')

    endpoints = {}
    if output_2d_feature_maps:
      # Use the feature level that corresponds to the patch stride, i.e.
      # log2(patch_size).
      feat_level = round(math.log2(patch_size))
      logging.info(
          'VisionTransformer patch size %d and feature level: %d',
          patch_size,
          feat_level,
      )
      endpoints[str(feat_level)] = tf.reshape(
          output_feature, [-1, feat_h, feat_w, x.shape.as_list()[-1]])

    self._output_specs = {k: v.shape for k, v in endpoints.items()}

    if representation_size:
      x = layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits',
          kernel_initializer='lecun_normal' if original_init else 'he_uniform',
      )(x)
      x = tf.nn.tanh(x)
    else:
      x = tf.identity(x, name='pre_logits')

    if pooler == 'none':
      if output_encoded_tokens:
        endpoints['encoded_tokens'] = x
    else:
      endpoints['pre_logits'] = tf.reshape(
          x, [-1, 1, 1, representation_size or hidden_size])

    super().__init__(inputs=inputs, outputs=endpoints)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_backbone_builder('vit')
def build_vit(input_specs,
              backbone_config,
              norm_activation_config,
              l2_regularizer=None):
  """Builds a ViT model."""
  del norm_activation_config
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'vit', (f'Inconsistent backbone type '
                                  f'{backbone_type}')
  backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])
  logging.info(
      (
          'ViT specs: mlp_dim=%d, num_heads=%d, num_layers=%d, '
          'patch_size=%d, hidden_size=%d, representation_size=%d.'
      ),
      backbone_cfg.transformer.mlp_dim,
      backbone_cfg.transformer.num_heads,
      backbone_cfg.transformer.num_layers,
      backbone_cfg.patch_size,
      backbone_cfg.hidden_size,
      backbone_cfg.representation_size,
  )

  return VisionTransformer(
      mlp_dim=backbone_cfg.transformer.mlp_dim,
      num_heads=backbone_cfg.transformer.num_heads,
      num_layers=backbone_cfg.transformer.num_layers,
      attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
      dropout_rate=backbone_cfg.transformer.dropout_rate,
      init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
      input_specs=input_specs,
      patch_size=backbone_cfg.patch_size,
      hidden_size=backbone_cfg.hidden_size,
      representation_size=backbone_cfg.representation_size,
      pooler=backbone_cfg.pooler,
      kernel_regularizer=l2_regularizer,
      original_init=backbone_cfg.original_init,
      output_encoded_tokens=backbone_cfg.output_encoded_tokens,
      output_2d_feature_maps=backbone_cfg.output_2d_feature_maps,
      layer_scale_init_value=backbone_cfg.layer_scale_init_value,
      pos_embed_shape=backbone_cfg.pos_embed_shape,
      transformer_partition_dims=backbone_cfg.transformer_partition_dims)
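

if __name__ == '__main__':
  # Minimal usage sketch (illustrative only): build a ViT-B/16-style backbone
  # directly, without the config system, and inspect its endpoints. The input
  # resolution (224x224) and batch size are arbitrary choices here.
  vit = VisionTransformer(
      input_specs=layers.InputSpec(shape=[None, 224, 224, 3]),
      patch_size=16,
      hidden_size=768,
      mlp_dim=3072,
      num_heads=12,
      num_layers=12,
  )
  endpoints = vit(tf.zeros([2, 224, 224, 3]))
  for name, tensor in endpoints.items():
    # With the default 'token' pooler this prints: pre_logits (2, 1, 1, 768).
    print(name, tensor.shape)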