# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""ViT implementation.
Pytorch reference: https://github.com/microsoft/GenerativeImage2Text/blob/\
main/generativeimage2text/layers/CLIP/model.py
Compare to a plain ViT, this implementation uses quick_gelu, supports
configurable normalizations before/ after the transformer blocks.
Currently the code also supports windows attention and relative positional
embedding. These are not used in the original GIT, but can be used for larger
input size in future developed.
"""
import functools
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp

KERNEL_INIT = {
    'normal': nn.initializers.normal(stddev=0.02),
}


class Attention(nn.Module):
  """Multi-head attention block.

  Attributes:
    dim (int): Number of input channels.
    num_heads (int): Number of attention heads.
    qkv_bias (bool): If True, add a learnable bias to query, key, value.
    beit_like_qkv_bias (bool): If True, use BEiT-style biases: learnable
      biases for q and v, and no bias for k.
    kernel_init (str): Name of the kernel initializer in KERNEL_INIT.
    with_grid_tokens (bool): If True, the input is a spatial grid of tokens
      rather than a flattened token sequence.
    dtype: jnp.dtype of the qkv projection.
  """
dim: int
num_heads: int = 8
qkv_bias: bool = True
beit_like_qkv_bias: bool = False
kernel_init: str = 'normal'
with_grid_tokens: bool = False
dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, x):
    """Forward a block.

    Args:
      x: if self.with_grid_tokens == False (default), x should be in shape
        (batch_size, num_tokens, dim);
        if self.with_grid_tokens == True, x should be in shape
        (batch_size, height, width, dim).

    Returns:
      x: the same shape as the input.
    """
    grid_hw = None
    if self.with_grid_tokens:
      # Flatten the spatial grid into a token sequence for attention; the
      # grid shape is restored before returning.
      grid_hw = x.shape[1:3]
      x = x.reshape(x.shape[0], -1, x.shape[-1])
    batch, num_tokens, _ = x.shape
    head_dim = self.dim // self.num_heads
    if self.beit_like_qkv_bias:
      # BEiT-style qkv bias: learnable biases for q and v, fixed zeros for k.
      q_bias = self.param(
          'q_bias', nn.initializers.zeros, (self.dim,))
      v_bias = self.param(
          'v_bias', nn.initializers.zeros, (self.dim,))
      k_bias = jnp.zeros((self.dim,), dtype=jnp.float32)
      qkv_bias = jnp.concatenate([q_bias, k_bias, v_bias], axis=0)
      qkv = nn.Dense(
          self.dim * 3, use_bias=False, dtype=self.dtype,
          kernel_init=KERNEL_INIT[self.kernel_init], name='qkv')(
              x)  # batch x num_tokens x 3dim
      qkv = qkv + qkv_bias[None, None, :]
    else:
      qkv = nn.Dense(self.dim * 3, use_bias=self.qkv_bias, name='qkv')(
          x)  # batch x num_tokens x 3dim
    qkv = qkv.reshape(batch, num_tokens, 3, self.num_heads, -1).transpose(
        2, 0, 3, 1, 4)  # 3 x batch x num_heads x num_tokens x D
    qkv = qkv.reshape(3, batch * self.num_heads, num_tokens, -1)
    q, k, v = qkv[0], qkv[1], qkv[2]  # [batch * num_heads, num_tokens, D]
    attn = (q * (head_dim ** -0.5)) @ k.transpose(
        0, 2, 1)  # [batch * num_heads, num_tokens, num_tokens]
    attn = jax.nn.softmax(attn)
    x = (attn @ v).reshape(
        batch, self.num_heads, num_tokens, -1).transpose(
            0, 2, 1, 3).reshape(batch, num_tokens, -1)
    x = nn.Dense(self.dim, name='proj')(x)
    if self.with_grid_tokens:
      # Restore the (batch, height, width, dim) grid layout.
      x = x.reshape(batch, grid_hw[0], grid_hw[1], -1)
    return x
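

# Example (a minimal sketch; the batch size and token count are illustrative):
#   attn = Attention(dim=768, num_heads=12)
#   tokens = jnp.zeros((2, 197, 768))
#   params = attn.init(jax.random.PRNGKey(0), tokens)
#   out = attn.apply(params, tokens)  # (2, 197, 768)
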
def quick_gelu(x: jnp.ndarray) -> jnp.ndarray:
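  """QuickGELU activation from CLIP: x * sigmoid(1.702 * x), a fast
  approximation of the exact GELU."""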
return x * jax.nn.sigmoid(1.702 * x)


class Mlp(nn.Module):
"""Multilayer perceptron."""
hidden_features: int
out_features: int
kernel_init: str = 'normal'
dtype: jnp.dtype = jnp.float32
activation: str = 'quick_gelu'

  @nn.compact
def __call__(self, x):
x = nn.Dense(
self.hidden_features, dtype=self.dtype,
kernel_init=KERNEL_INIT[self.kernel_init], name='fc1')(x)
if self.activation == 'quick_gelu':
x = quick_gelu(x)
elif self.activation == 'gelu':
x = nn.gelu(x, approximate=False)
else:
raise NotImplementedError(self.activation)
x = nn.Dense(
self.out_features, dtype=self.dtype,
kernel_init=KERNEL_INIT[self.kernel_init], name='fc2')(x)
return x


class Block(nn.Module):
  """Transformer block with residual connections and stochastic depth.

  Attributes:
    dim (int): Number of input channels.
    num_heads (int): Number of attention heads in each ViT block.
    mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
    qkv_bias (bool): If True, add a learnable bias to query, key, value.
    beit_like_qkv_bias (bool): If True, use BEiT-style biases: learnable
      biases for q and v, and no bias for k.
    mlp_activation (str): Activation of the MLP ('quick_gelu' or 'gelu').
    drop_path (float): Stochastic depth rate.
    layer_scale_init_value (float): Initial value of the LayerScale
      parameters; a negative value disables LayerScale.
  """
dim: int
num_heads: int
mlp_ratio: float = 4.0
qkv_bias: bool = True
beit_like_qkv_bias: bool = False
mlp_activation: str = 'quick_gelu'
drop_path: float = 0.0
layer_scale_init_value: float = -1.0
kernel_init: str = 'normal'
with_grid_tokens: bool = False
dtype: jnp.dtype = jnp.float32

  def get_keep_pattern(self,
                       x: jnp.ndarray,
                       deterministic: bool):
    """Computes the DropPath (stochastic depth) keep mask.

    Each sample in the batch is kept with probability (1 - drop_path); kept
    samples are rescaled by 1 / (1 - drop_path), so the expected value of the
    residual branch is unchanged.
    """
    if not deterministic and self.drop_path:
      shape = (x.shape[0],) + (1,) * (x.ndim - 1)
      drop_pattern = jax.random.bernoulli(
          self.make_rng('dropout'), self.drop_path, shape).astype(self.dtype)
      keep_pattern = 1. - drop_pattern
      if self.drop_path < 1.:
        keep_pattern = keep_pattern / (1. - self.drop_path)
      return keep_pattern
    else:
      return 1.0

@nn.compact
def __call__(self, x, train: bool = False):
shortcut = x
ln = functools.partial(nn.LayerNorm, epsilon=1e-6)
x = ln(name='norm1')(x)
x = Attention(
self.dim,
num_heads=self.num_heads,
qkv_bias=self.qkv_bias,
beit_like_qkv_bias=self.beit_like_qkv_bias,
with_grid_tokens=self.with_grid_tokens,
name='attn')(x)
    if self.layer_scale_init_value > 0:
      # LayerScale: learnable per-channel scaling of the attention branch.
      gamma_1 = self.param(
          'gamma_1',
          nn.initializers.constant(self.layer_scale_init_value),
          (self.dim,))
      x = x * gamma_1
x = shortcut + self.get_keep_pattern(x, not train) * x
y = ln(name='norm2')(x)
y = Mlp(
int(self.dim * self.mlp_ratio),
self.dim,
kernel_init=self.kernel_init,
activation=self.mlp_activation,
dtype=self.dtype,
name='mlp')(y)
    if self.layer_scale_init_value > 0:
      # LayerScale: learnable per-channel scaling of the MLP branch.
      gamma_2 = self.param(
          'gamma_2',
          nn.initializers.constant(self.layer_scale_init_value),
          (self.dim,))
      y = y * gamma_2
x = x + self.get_keep_pattern(y, not train) * y
return x


class ViT(nn.Module):
  """This module implements the Vision Transformer (ViT) backbone.

  Attributes:
    patch_size (int): Patch size.
    in_chans (int): Number of input image channels.
    embed_dim (int): Patch embedding dimension.
    depth (int): Depth of ViT.
    num_heads (int): Number of attention heads in each ViT block.
    mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
    qkv_bias (bool): If True, add a learnable bias to query, key, value.
    beit_like_qkv_bias (bool): If True, use BEiT-style biases: learnable
      biases for q and v, and no bias for k.
    mlp_activation (str): Activation of the MLP ('quick_gelu' or 'gelu').
    drop_path_rate (float): Stochastic depth rate.
    use_abs_pos (bool): If True, use absolute positional embeddings.
    pretrain_img_size (int): Input image size of the pretraining model.
    pretrain_use_cls_token (bool): If True, the pretraining model uses a
      class token.
    layer_scale_init_value (float): Initial value of the LayerScale
      parameters; a negative value disables LayerScale.
    kernel_init (str): Function used to initialize layers. Currently only
      'normal' is supported.
    freeze_vit_layer (int): Stop gradients after this many blocks, freezing
      the early layers; negative disables freezing.
    use_ln_pre (bool): Whether to apply a layer norm before the transformer
      blocks. Used in CLIP/GIT; not used in MAE/ViTDet.
    use_ln_post (bool): Whether to apply a layer norm after the transformer
      blocks. Used in CLIP/GIT; not used in MAE/ViTDet.
    pe_bias (bool): Whether the patch-embedding layer has a bias. Not used in
      CLIP/GIT; used in MAE/ViTDet.
    use_class_embedding (bool): Whether to use a cls_token in the attention.
      If True, the attention blocks take flattened tokens as input; if False,
      they take grid features as input.
    dtype: jnp.dtype.
    token_mask_probability (float): Fraction of patch tokens to drop
      (MAE-style) during training; non-positive disables token masking.
    token_mask_test (bool): If True, also subsample tokens (on a regular
      stride) at test time.
    window_block_indexes: Never used. Kept to make legacy configs runnable.
    use_rel_pos: Never used. Kept to make legacy configs runnable.
  """
patch_size: int = 16
in_chans: int = 3
embed_dim: int = 768
depth: int = 12
num_heads: int = 12
mlp_ratio: float = 4.0
qkv_bias: bool = True
beit_like_qkv_bias: bool = False
mlp_activation: str = 'quick_gelu'
drop_path_rate: float = 0.1
use_abs_pos: bool = True
pretrain_img_size: int = 224
pretrain_use_cls_token: bool = True
layer_scale_init_value: float = -1.0
kernel_init: str = 'normal'
freeze_vit_layer: int = -1
use_ln_pre: bool = False
use_ln_post: bool = False
pe_bias: bool = True
use_class_embedding: bool = True
dtype: jnp.dtype = jnp.float32
token_mask_probability: float = -1.0
token_mask_test: bool = False
window_block_indexes: Any = None
use_rel_pos: Any = None

  def _get_abs_pos(self, abs_pos, hw):
    """Calculates absolute positional embeddings for the given grid size.

    If needed, resizes the grid part of the embeddings to the input token
    grid, keeping (or dropping) the cls_token embedding as configured.

    Args:
      abs_pos (array): absolute positional embeddings with shape
        (1, num_positions, C).
      hw (Tuple): size (height, width) of the input image token grid.

    Returns:
      Absolute positional embeddings with shape (1, 1 + H * W, C) if
      self.use_class_embedding, otherwise (1, H, W, C).
    """
h, w = hw
if self.pretrain_use_cls_token:
abs_pos_no_cls = abs_pos[:, 1:]
else:
abs_pos_no_cls = abs_pos
xy_num = abs_pos_no_cls.shape[1]
size = int(xy_num ** 0.5)
assert size * size == xy_num
abs_pos_no_cls = abs_pos_no_cls.reshape(
abs_pos_no_cls.shape[0], size, size, -1)
if size != h or size != w:
abs_pos_no_cls = jax.image.resize(
abs_pos_no_cls,
(abs_pos_no_cls.shape[0], h, w, abs_pos_no_cls.shape[3]),
method='bicubic',
)
if self.use_class_embedding:
abs_pos_no_cls = abs_pos_no_cls.reshape(
abs_pos_no_cls.shape[0], h * w, -1)
new_abs_pos = jnp.concatenate([abs_pos[:, :1], abs_pos_no_cls], axis=1)
else:
new_abs_pos = abs_pos_no_cls
else:
if self.use_class_embedding:
new_abs_pos = abs_pos
else:
new_abs_pos = abs_pos_no_cls
return new_abs_pos

  @nn.compact
  def __call__(self, x: jnp.ndarray, train: bool = False):
    """Forward ViT backbone.

    Args:
      x: (batch_size, height, width, 3) the input image.
      train: bool; whether to run in training mode (enables stochastic depth
        and random token masking).

    Returns:
      x: the features after the backbone, with shape (batch_size, seq_length,
        embed_dim) if self.use_class_embedding, otherwise
        (batch_size, height // patch_size, width // patch_size, embed_dim).
    """
x = nn.Conv(
self.embed_dim, (self.patch_size, self.patch_size),
strides=(self.patch_size, self.patch_size),
padding='VALID',
use_bias=self.pe_bias,
name='patch_embed.proj')(x)
if self.use_class_embedding:
class_embedding = self.param(
'class_embedding', nn.initializers.zeros, (1, 1, self.embed_dim))
class_embedding = jnp.broadcast_to(
class_embedding, (x.shape[0], 1, self.embed_dim))
x = x.reshape(x.shape[0], -1, x.shape[-1]) # (B, hw, C)
x = jnp.concatenate([class_embedding, x], axis=1)
if self.use_abs_pos:
num_patches = (self.pretrain_img_size // self.patch_size) ** 2
num_positions = (
num_patches + 1) if self.pretrain_use_cls_token else num_patches
pos_embed = self.param(
'pos_embed', nn.initializers.zeros,
(1, num_positions, self.embed_dim))
if self.use_class_embedding:
input_size = int((x.shape[1] - 1) ** 0.5)
x = x + self._get_abs_pos(pos_embed, (input_size, input_size))
else:
x = x + self._get_abs_pos(pos_embed, (x.shape[1], x.shape[2]))
    # TODO(zhouxy): The current MAE-style masking is not optimal. We sample a
    # single set of indices for all images in the batch; we should use
    # different indices for each image.
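    # MAE-style token masking: keep a random subset of the patch tokens during
    # training. The cls token (index 0) is always kept.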
if self.token_mask_probability > 0:
assert self.use_class_embedding
num_pixel_tokens = x.shape[1] - 1
num_remaining_tokens = int(
(1.0 - self.token_mask_probability) * num_pixel_tokens)
if train:
inds = jax.random.permutation(
self.make_rng('dropout'),
jnp.arange(num_pixel_tokens, dtype=jnp.int32),
independent=True,
)[:num_remaining_tokens]
else:
if self.token_mask_test:
inds = jnp.linspace(
0, num_pixel_tokens, num_remaining_tokens,
endpoint=False, dtype=jnp.int32)
else:
inds = jnp.arange(num_pixel_tokens, dtype=jnp.int32)
unmasked_pixel_tokens = jnp.take_along_axis(
x[:, 1:], inds[None, :, None], axis=1)
x = jnp.concatenate([x[:, :1], unmasked_pixel_tokens], axis=1)
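    # Per-block stochastic-depth rates, increasing linearly from 0 at the
    # first block to drop_path_rate at the last block.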
dp_rates = [
self.drop_path_rate * i / (self.depth - 1) for i in range(self.depth)]
if self.use_ln_pre:
x = nn.LayerNorm(name='ln_pre')(x)
for i in range(self.depth):
x = Block(
dim=self.embed_dim,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
qkv_bias=self.qkv_bias,
beit_like_qkv_bias=self.beit_like_qkv_bias,
mlp_activation=self.mlp_activation,
drop_path=dp_rates[i],
with_grid_tokens=not self.use_class_embedding,
layer_scale_init_value=self.layer_scale_init_value,
name=f'blocks.{i}',
)(x, train=train)
if i + 1 == self.freeze_vit_layer:
x = jax.lax.stop_gradient(x)
if self.use_ln_post:
x = nn.LayerNorm(name='ln_post')(x)
return x
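

# Example usage (a minimal sketch, not part of the original module; the
# GIT-style settings below are assumptions based on the attribute docstrings):
#   model = ViT(
#       mlp_activation='quick_gelu', use_ln_pre=True, use_ln_post=True,
#       pe_bias=False)
#   images = jnp.zeros((1, 224, 224, 3))
#   params = model.init(jax.random.PRNGKey(0), images)
#   features = model.apply(params, images)  # (1, 14 * 14 + 1, 768)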