# fcxfcx's picture
# Upload 2446 files
# 1327f34 verified
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer."""
from typing import Any, Callable, Optional, Sequence
from absl import logging
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
from scenic.model_lib.base_models.multilabel_classification_model import MultiLabelClassificationModel
from scenic.model_lib.layers import attention_layers
from scenic.model_lib.layers import nn_layers
import scipy
from tensorflow.io import gfile
# Type alias for parameter initializers: a callable that receives an array, a
# shape and a dtype and returns an array. The first argument is presumably the
# PRNG key, as with the `nn.initializers` functions used below -- confirm.
Initializer = Callable[[jnp.ndarray, Sequence[int], jnp.dtype], jnp.ndarray]
class AddPositionEmbs(nn.Module):
  """Adds a learned positional embedding to a sequence of token embeddings.

  A single parameter named `pos_embedding` with shape `[1, timesteps, in_dim]`
  is created and added to the inputs, broadcasting over the batch dimension.

  Attributes:
    posemb_init: Initializer used for the positional embedding parameter.

  Returns:
    Output in shape `[bs, timesteps, in_dim]`.
  """
  posemb_init: Initializer = nn.initializers.normal(stddev=0.02)  # From BERT.

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    # The inputs must be rank 3: (batch_size, seq_len, emb_dim).
    assert inputs.ndim == 3, ('Number of dimensions should be 3,'
                              ' but it is: %d' % inputs.ndim)
    seq_len, emb_dim = inputs.shape[1], inputs.shape[2]
    pos_embedding = self.param(
        'pos_embedding', self.posemb_init, (1, seq_len, emb_dim), inputs.dtype)
    return inputs + pos_embedding
class MAPHead(nn.Module):
  """Multihead Attention Pooling.

  A single learned `probe` token is used as the attention query over the input
  sequence; the attended result goes through a residual MLP block and is
  returned as one vector per example.

  Attributes:
    mlp_dim: Hidden dimension handed to the MLP block.
    num_heads: Number of attention heads.
    dtype: Dtype handed to the MLP block.
  """
  mlp_dim: Optional[int] = None  # Defaults to 4x input dim
  num_heads: int = 12
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, x):
    batch_size, _, emb_dim = x.shape

    # One learned query, replicated for every example in the batch.
    probe = self.param(
        'probe', nn.initializers.xavier_uniform(), (1, 1, emb_dim), x.dtype)
    queries = jnp.tile(probe, [batch_size, 1, 1])

    attention = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform())
    pooled = attention(queries, x)

    # Residual MLP on top of the pooled token; dropout is disabled here.
    normed = nn.LayerNorm()(pooled)
    mlp = attention_layers.MlpBlock(
        mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=0.0)
    pooled = pooled + mlp(normed, deterministic=True)

    # Drop the length-1 probe axis.
    return pooled[:, 0]
class Encoder1DBlock(nn.Module):
  """A single pre-norm Transformer encoder layer.

  The layer is a self-attention sub-block followed by an MLP sub-block, each
  wrapped in a residual connection with stochastic depth applied to the branch.

  Attributes:
    mlp_dim: Dimension of the mlp on top of attention block.
    num_heads: Number of self-attention heads.
    dtype: The dtype of the computation (default: float32).
    dropout_rate: Dropout rate.
    attention_dropout_rate: Dropout for attention heads.
    stochastic_depth: Rate handed to the stochastic depth layers of this block.

  Returns:
    output after transformer encoder block.
  """
  mlp_dim: int
  num_heads: int
  dtype: Any = jnp.float32
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  stochastic_depth: float = 0.0

  @nn.compact
  def __call__(self, inputs: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
    """Runs the encoder layer on a rank-3 input.

    Args:
      inputs: Input data.
      deterministic: Deterministic or not (to apply dropout).

    Returns:
      Output after transformer encoder block.
    """
    assert inputs.ndim == 3

    # --- Self-attention sub-block ---
    attn_in = nn.LayerNorm(dtype=self.dtype)(inputs)
    self_attention = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        dtype=self.dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        broadcast_dropout=False,
        deterministic=deterministic,
        dropout_rate=self.attention_dropout_rate)
    attn_out = self_attention(attn_in, attn_in)
    attn_out = nn.Dropout(rate=self.dropout_rate)(attn_out, deterministic)
    attn_out = nn_layers.StochasticDepth(rate=self.stochastic_depth)(
        attn_out, deterministic)
    residual = attn_out + inputs

    # --- MLP sub-block ---
    mlp_in = nn.LayerNorm(dtype=self.dtype)(residual)
    mlp = attention_layers.MlpBlock(
        mlp_dim=self.mlp_dim,
        dtype=self.dtype,
        dropout_rate=self.dropout_rate,
        activation_fn=nn.gelu,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    mlp_out = mlp(mlp_in, deterministic=deterministic)
    mlp_out = nn_layers.StochasticDepth(rate=self.stochastic_depth)(
        mlp_out, deterministic)
    return mlp_out + residual
class Encoder(nn.Module):
  """Transformer Encoder.

  Adds positional information to the tokens, applies dropout, runs a stack of
  `Encoder1DBlock` layers and finishes with a layer norm.

  Attributes:
    num_layers: Number of layers.
    mlp_dim: Dimension of the mlp on top of attention block.
    num_heads: The number of heads for multi-head self-attention.
    positional_embedding: The type of positional embeddings to add to the
      input tokens. Options are {learned_1d, sinusoidal_1d, sinusoidal_2d,
      none}.
    dropout_rate: Dropout rate.
    attention_dropout_rate: Dropout for attention heads.
    stochastic_depth: probability of dropping a layer linearly grows
      from 0 to the provided value. Our implementation of stochastic depth
      follows timm library, which does per-example layer dropping and uses
      independent dropping patterns for each skip-connection.
    dtype: Dtype of activations.
    has_cls_token: Whether or not the sequence is prepended with a CLS token.
  """
  num_layers: int
  mlp_dim: int
  num_heads: int
  positional_embedding: str = 'learned_1d'
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  stochastic_depth: float = 0.0
  dtype: Any = jnp.float32
  has_cls_token: bool = False

  @nn.compact
  def __call__(self, inputs: jnp.ndarray, *, train: bool = False):
    """Encodes a sequence of tokens.

    Args:
      inputs: Input tokens of shape [batch, num_tokens, channels].
      train: If in training mode, dropout and stochastic depth is applied.

    Returns:
      Encoded tokens.

    Raises:
      ValueError: If `positional_embedding` is unknown, or if `sinusoidal_2d`
        is requested for a token count that is not a perfect square.
    """
    assert inputs.ndim == 3  # Shape is `[batch, len, emb]`.
    dtype = jax.dtypes.canonicalize_dtype(self.dtype)
    deterministic = not train

    # Add positional embeddings to tokens.
    posemb_kind = self.positional_embedding
    if posemb_kind == 'learned_1d':
      add_posemb = AddPositionEmbs(
          posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
          name='posembed_input')
      x = add_posemb(inputs)
    elif posemb_kind == 'sinusoidal_1d':
      x = attention_layers.Add1DPositionEmbedding(posemb_init=None)(inputs)
    elif posemb_kind == 'sinusoidal_2d':
      batch, num_tokens, hidden_dim = inputs.shape
      if self.has_cls_token:
        # The CLS token is not part of the spatial grid.
        num_tokens -= 1
      height = width = int(np.sqrt(num_tokens))
      if height * width != num_tokens:
        raise ValueError('Input is assumed to be square for sinusoidal init.')
      # Only the spatial tokens get a 2D embedding; a CLS token (if any) is
      # split off beforehand and re-attached unchanged afterwards.
      spatial_tokens = inputs[:, 1:] if self.has_cls_token else inputs
      grid = spatial_tokens.reshape([batch, height, width, hidden_dim])
      x = attention_layers.AddFixedSinCosPositionEmbedding()(grid)
      x = x.reshape([batch, num_tokens, hidden_dim])
      if self.has_cls_token:
        x = jnp.concatenate([inputs[:, :1], x], axis=1)
    elif posemb_kind == 'none':
      x = inputs
    else:
      raise ValueError('Unknown positional embedding: '
                       f'{self.positional_embedding}')

    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

    # Stack of encoder layers; the stochastic depth rate grows linearly from 0
    # at the first layer to `self.stochastic_depth` at the last one.
    depth_denominator = max(self.num_layers - 1, 1)
    for layer_idx in range(self.num_layers):
      block = Encoder1DBlock(
          mlp_dim=self.mlp_dim,
          num_heads=self.num_heads,
          dropout_rate=self.dropout_rate,
          attention_dropout_rate=self.attention_dropout_rate,
          stochastic_depth=(layer_idx / depth_denominator)
          * self.stochastic_depth,
          name=f'encoderblock_{layer_idx}',
          dtype=dtype,
      )
      x = block(x, deterministic=deterministic)

    return nn.LayerNorm(name='encoder_norm')(x)
class ViT(nn.Module):
  """Vision Transformer model.

  The image is split into patches by a strided convolution, the resulting
  tokens are encoded by a Transformer `Encoder`, pooled according to
  `classifier`, and optionally projected to class logits.

  Attributes:
    num_classes: Number of output classes. If not positive, no output
      projection is applied and the backbone features are returned.
    mlp_dim: Dimension of the mlp on top of attention block.
    num_layers: Number of layers.
    num_heads: Number of self-attention heads.
    patches: Configuration of the patches extracted in the stem of the model.
    hidden_size: Size of the hidden state of the output of model's stem.
    positional_embedding: The type of positional embeddings to add to the
      tokens at the beginning of the transformer encoder; forwarded to
      `Encoder`.
    representation_size: Size of the representation layer in the model's head.
      if None, we skip the extra projection + tanh activation at the end.
    dropout_rate: Dropout rate.
    attention_dropout_rate: Dropout for attention heads.
    stochastic_depth: Stochastic depth rate forwarded to `Encoder`.
    classifier: type of the classifier layer. Options are 'token', '0', 'gap',
      'gmp', 'gsp', 'map', 'none'.
    dtype: JAX data type for activations.
  """
  num_classes: int
  mlp_dim: int
  num_layers: int
  num_heads: int
  patches: ml_collections.ConfigDict
  hidden_size: int
  positional_embedding: str = 'learned_1d'
  representation_size: Optional[int] = None
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  stochastic_depth: float = 0.0
  classifier: str = 'gap'
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, x: jnp.ndarray, *, train: bool, debug: bool = False):
    """Applies the model to a batch of inputs.

    Args:
      x: Input batch; it is fed to a 2D convolution, so it is presumably of
        shape [batch, height, width, channels] -- confirm with callers.
      train: Whether dropout and stochastic depth are active in the encoder.
      debug: Not used by this implementation.

    Returns:
      Logits if `num_classes > 0`, otherwise the pre-logits features.

    Raises:
      ValueError: If `classifier` is not one of the supported values.
    """
    patch_h, patch_w = self.patches.size
    # Extracting patches and then embedding is in fact a single convolution.
    patch_embedding = nn.Conv(
        self.hidden_size, (patch_h, patch_w),
        strides=(patch_h, patch_w),
        padding='VALID',
        name='embedding')
    x = patch_embedding(x)
    batch, grid_h, grid_w, channels = x.shape
    x = jnp.reshape(x, [batch, grid_h * grid_w, channels])

    # If we want to add a class token, add it here.
    uses_cls_token = self.classifier == 'token'
    if uses_cls_token:
      cls = self.param('cls', nn.initializers.zeros, (1, 1, channels), x.dtype)
      cls = jnp.tile(cls, [batch, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)

    encoder = Encoder(
        mlp_dim=self.mlp_dim,
        num_layers=self.num_layers,
        num_heads=self.num_heads,
        positional_embedding=self.positional_embedding,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        stochastic_depth=self.stochastic_depth,
        dtype=self.dtype,
        has_cls_token=uses_cls_token,
        name='Transformer',
    )
    x = encoder(x, train=train)

    # Reduce the token sequence to a single representation (unless 'none').
    if self.classifier in ('token', '0'):
      x = x[:, 0]
    elif self.classifier == 'gap':
      x = jnp.mean(x, axis=1)
    elif self.classifier == 'gmp':
      x = jnp.max(x, axis=1)
    elif self.classifier == 'gsp':
      x = jnp.sum(x, axis=1)
    elif self.classifier == 'map':
      map_head = MAPHead(
          num_heads=self.num_heads, mlp_dim=self.mlp_dim, dtype=self.dtype)
      x = map_head(x)
    elif self.classifier == 'none':
      pass
    else:
      raise ValueError(f'Unknown classifier {self.classifier}')

    if self.representation_size is None:
      x = nn_layers.IdentityLayer(name='pre_logits')(x)
    else:
      x = nn.Dense(self.representation_size, name='pre_logits')(x)
      x = nn.tanh(x)

    if self.num_classes <= 0:
      # Without classes to predict, we just return the backbone features.
      return x
    output_projection = nn.Dense(
        self.num_classes,
        kernel_init=nn.initializers.zeros,
        name='output_projection')
    return output_projection(x)
class ViTMultiLabelClassificationModel(MultiLabelClassificationModel):
  """Vision Transformer model for multi-label classification task."""

  def build_flax_model(self) -> nn.Module:
    """Builds the `ViT` module described by `self.config.model`.

    Returns:
      A `ViT` module whose number of classes is taken from the dataset
      metadata.

    Raises:
      ValueError: If `model_dtype_str` in the config is not 'float32'.
    """
    dtype_str = self.config.get('model_dtype_str', 'float32')
    if dtype_str != 'float32':
      raise ValueError('`dtype` argument is not propagated properly '
                       'in the current implmentation, so only '
                       '`float32` is supported for now.')
    model_cfg = self.config.model
    return ViT(
        num_classes=self.dataset_meta_data['num_classes'],
        mlp_dim=model_cfg.mlp_dim,
        num_layers=model_cfg.num_layers,
        num_heads=model_cfg.num_heads,
        positional_embedding=model_cfg.get('positional_embedding',
                                           'learned_1d'),
        representation_size=model_cfg.representation_size,
        patches=model_cfg.patches,
        hidden_size=model_cfg.hidden_size,
        classifier=model_cfg.classifier,
        # NOTE(review): no fallback is given for the two dropout rates, so a
        # config without these keys presumably passes None -- confirm that is
        # intended.
        dropout_rate=model_cfg.get('dropout_rate'),
        attention_dropout_rate=model_cfg.get('attention_dropout_rate'),
        stochastic_depth=model_cfg.get('stochastic_depth', 0.0),
        dtype=getattr(jnp, dtype_str),
    )

  def default_flax_model_config(self) -> ml_collections.ConfigDict:
    """Returns a small default configuration for this model."""
    return ml_collections.ConfigDict({
        'model':
            dict(
                num_heads=2,
                num_layers=1,
                representation_size=16,
                mlp_dim=32,
                dropout_rate=0.,
                attention_dropout_rate=0.,
                hidden_size=16,
                patches={'size': (4, 4)},
                classifier='gap',
                data_dtype_str='float32')
    })

  def init_from_train_state(
      self, train_state: Any, restored_train_state: Any,
      restored_model_cfg: ml_collections.ConfigDict) -> Any:
    """Updates the train_state with data from restored_train_state.

    This function is written to be used for 'fine-tuning' experiments. Here, we
    do some surgery to support larger resolutions (longer sequence length) in
    the transformer block, with respect to the learned pos-embeddings.

    Args:
      train_state: A raw TrainState for the model.
      restored_train_state: A TrainState that is loaded with parameters/state of
        a pretrained model.
      restored_model_cfg: Configuration of the model from which the
        restored_train_state come from. Usually used for some asserts.

    Returns:
      Updated train_state.
    """
    return init_vit_from_train_state(train_state, restored_train_state,
                                     self.config, restored_model_cfg)

  def load_augreg_params(self, train_state: Any, params_path: str,
                         model_cfg: ml_collections.ConfigDict) -> Any:
    """Loads parameters from an AugReg checkpoint.

    See
      https://github.com/google-research/vision_transformer/
    and
      https://arxiv.org/abs/2106.10270
    for more information about these pre-trained models.

    Args:
      train_state: A raw TrainState for the model.
      params_path: Path to an Augreg checkpoint. The model config is read from
        the filename (e.g. a B/32 model starts with "B_32-").
      model_cfg: Configuration of the model (the flat model section, i.e. the
        object that directly holds `patches`, `hidden_size`, ...). It must
        agree with the checkpoint's architecture.

    Returns:
      Updated train_state with params replaced with the ones read from the
      AugReg checkpoint.
    """
    restored_model_cfg = _get_augreg_cfg(params_path)
    assert tuple(restored_model_cfg.patches.size) == tuple(
        model_cfg.patches.size)
    assert restored_model_cfg.hidden_size == model_cfg.hidden_size
    assert restored_model_cfg.mlp_dim == model_cfg.mlp_dim
    assert restored_model_cfg.num_layers == model_cfg.num_layers
    assert restored_model_cfg.num_heads == model_cfg.num_heads
    assert restored_model_cfg.classifier == model_cfg.classifier

    # Read every array while the file is open so that the handle can be closed
    # deterministically afterwards.
    with gfile.GFile(params_path, 'rb') as f:
      flattened = dict(np.load(f).items())
    restored_params = flax.traverse_util.unflatten_dict(
        {tuple(k.split('/')): v for k, v in flattened.items()})
    restored_params['output_projection'] = restored_params.pop('head')

    # `_merge_params` looks settings up as `cfg.model.<field>`, while both
    # configs handled here are flat model configs, so nest them under `model`.
    merge_cfg = ml_collections.ConfigDict({'model': model_cfg})
    merge_restored_cfg = ml_collections.ConfigDict(
        {'model': restored_model_cfg})

    # Attribute check (not a membership test): train states are objects, and
    # this matches the check done in `init_vit_from_train_state`.
    if hasattr(train_state, 'optimizer'):
      # TODO(dehghani): Remove support for flax optim.
      params = flax.core.unfreeze(train_state.optimizer.target)
      _merge_params(params, restored_params, merge_cfg, merge_restored_cfg)
      return train_state.replace(
          optimizer=train_state.optimizer.replace(
              target=flax.core.freeze(params)))

    params = flax.core.unfreeze(train_state.params)
    _merge_params(params, restored_params, merge_cfg, merge_restored_cfg)
    return train_state.replace(params=flax.core.freeze(params))
def _get_augreg_cfg(params_path):
  """Derives the model configuration of an AugReg checkpoint from its path.

  The model variant is the part of the file name (last '/'-separated path
  component) that precedes the first '-'.

  Args:
    params_path: Path to the checkpoint file.

  Returns:
    A flat `ml_collections.ConfigDict` with the architecture settings of the
    checkpoint; the patch size depends on the variant.
  """
  patch_size_by_variant = {'B_16': 16, 'B_32': 32}
  filename = params_path.split('/')[-1]
  variant = filename.split('-')[0]
  assert variant in patch_size_by_variant, (
      'Currently only B/16 and B/32 models are supported.')
  patch_size = patch_size_by_variant[variant]
  cfg = {
      'num_classes': 0,
      'mlp_dim': 3072,
      'num_layers': 12,
      'num_heads': 12,
      'hidden_size': 768,
      'classifier': 'token',
      'patches': {'size': (patch_size, patch_size)},
      'dropout_rate': 0.,
      'attention_dropout_rate': 0.,
  }
  return ml_collections.ConfigDict(cfg)
def _resize_posemb(posemb, restored_posemb, model_cfg, restored_model_cfg):
  """Adapts restored learned position embeddings to the target's shape.

  Both embeddings are of shape `(1, num_tokens, dim)`. The spatial part is
  treated as a square grid and resampled with linear interpolation when the
  grid sizes differ; a CLS-token embedding is carried over, taken from the
  target, or dropped depending on which of the two models uses one.

  Args:
    posemb: Position embedding of the target model (defines the output shape).
    restored_posemb: Position embedding of the restored model.
    model_cfg: Target config; `model_cfg.model.classifier` is read.
    restored_model_cfg: Restored config; `.model.classifier` is read.

  Returns:
    The adapted position embedding as a `jnp` array.
  """
  # Rescale the grid of pos, embeddings: param shape is (1, N, d).
  logging.info('Resized variant: %s to %s', restored_posemb.shape,
               posemb.shape)
  restored_has_cls = restored_model_cfg.model.classifier == 'token'
  target_has_cls = model_cfg.model.classifier == 'token'

  ntok = posemb.shape[1]
  # When the restored model has a CLS token, it is the first token.
  grid = restored_posemb[0, 1:] if restored_has_cls else restored_posemb[0]
  if target_has_cls:
    # Reuse the restored CLS embedding if there is one; otherwise keep the
    # target's own embedding for that position.
    cls_tok = restored_posemb[:, :1] if restored_has_cls else posemb[:, :1]
    ntok -= 1
  else:
    # No CLS token in the target: use an empty slice as placeholder.
    cls_tok = restored_posemb[:, :0]

  restored_gs = int(np.sqrt(len(grid)))
  gs = int(np.sqrt(ntok))
  if restored_gs != gs:  # We need resolution change.
    logging.info('Grid-size from %s to %s.', restored_gs, gs)
    grid = grid.reshape(restored_gs, restored_gs, -1)
    zoom = (gs / restored_gs, gs / restored_gs, 1)
    grid = scipy.ndimage.zoom(grid, zoom, order=1)
  # Attach the CLS token again.
  grid = grid.reshape(1, gs * gs, -1)
  return jnp.array(np.concatenate([cls_tok, grid], axis=1))


def _merge_params(params, restored_params, model_cfg, restored_model_cfg):
  """Merges `restored_params` into `params` (in place).

  Rules, per top-level key of `restored_params`:
    * 'output_projection': never copied; the target keeps its own head.
    * 'pre_logits': copied only if the target has a representation layer,
      otherwise removed from `params`.
    * 'Transformer': sub-entries present in the target are copied; learned
      position embeddings are resized first if their shape differs.
    * anything else: copied if the key exists in `params`, else ignored.

  Args:
    params: Mutable (unfrozen) parameters of the target model.
    restored_params: Parameters of the pretrained model.
    model_cfg: Target config, accessed as `model_cfg.model.<field>`. Only read
      when 'pre_logits' is present or position embeddings need resizing.
    restored_model_cfg: Config of the pretrained model, accessed the same way.
  """
  # Start moving parameters, one-by-one and apply changes if needed.
  for m_key, m_params in restored_params.items():
    if m_key == 'output_projection':
      # For the classifier head, we use the randomly initialized params and
      # ignore the one from pretrained model.
      continue

    if m_key == 'pre_logits':
      if model_cfg.model.representation_size is None:
        # We don't have representation_size in the new model, so let's ignore
        # it from the pretrained model, in case it has it.
        # Note, removing the key from the dictionary is necessary to prevent
        # obscure errors from the Flax optimizer.
        params.pop(m_key, None)
      else:
        assert restored_model_cfg.model.representation_size
        params[m_key] = m_params

    elif m_key == 'Transformer':
      target_encoder = params[m_key]
      for tm_key, tm_params in m_params.items():
        if tm_key not in target_encoder:
          # Also covers restored learned position embeddings when the target
          # does not have a 'posembed_input' entry.
          logging.info('Ignoring %s. In restored model\'s Transformer, '
                       'but not in target', tm_key)
          continue
        if tm_key != 'posembed_input':
          # Other parameters of the Transformer encoder are taken as they are.
          target_encoder[tm_key] = tm_params
          continue
        # Learned position embeddings might need a resolution change.
        posemb = target_encoder[tm_key]['pos_embedding']
        restored_posemb = tm_params['pos_embedding']
        if restored_posemb.shape != posemb.shape:
          restored_posemb = _resize_posemb(posemb, restored_posemb, model_cfg,
                                           restored_model_cfg)
        target_encoder[tm_key]['pos_embedding'] = restored_posemb

    elif m_key in params:
      # Use the rest if they are in the pretrained model.
      params[m_key] = m_params
    else:
      logging.info('Ignoring %s. In restored model, but not in target', m_key)
def init_vit_from_train_state(
    train_state: Any, restored_train_state: Any,
    model_cfg: ml_collections.ConfigDict,
    restored_model_cfg: ml_collections.ConfigDict) -> Any:
  """Initializes a train state from the parameters of a restored train state.

  This function is written to be used for 'fine-tuning' experiments. The
  restored parameters are merged into the target parameters with
  `_merge_params`, which does some surgery on the learned pos-embeddings to
  support a different resolution (sequence length) in the transformer block.

  Train states that expose an `optimizer` attribute (flax.optim, deprecated)
  keep their parameters in `optimizer.target`; all others keep them in
  `params`.

  Args:
    train_state: A raw TrainState for the model.
    restored_train_state: A TrainState that is loaded with parameters/state of a
      pretrained model.
    model_cfg: Configuration of the model. Usually used for some asserts.
    restored_model_cfg: Configuration of the model from which the
      restored_train_state come from. Usually used for some asserts.

  Returns:
    Updated train_state.
  """
  # TODO(dehghani): Remove support for flax optim.
  uses_flax_optim = hasattr(train_state, 'optimizer')
  if uses_flax_optim:
    target_tree = train_state.optimizer.target
    restored_tree = restored_train_state.optimizer.target
  else:
    target_tree = train_state.params
    restored_tree = restored_train_state.params

  params = flax.core.unfreeze(target_tree)
  restored_params = flax.core.unfreeze(restored_tree)
  _merge_params(params, restored_params, model_cfg, restored_model_cfg)
  merged = flax.core.freeze(params)

  if uses_flax_optim:
    return train_state.replace(
        optimizer=train_state.optimizer.replace(target=merged))
  return train_state.replace(params=merged)