# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py."""
from collections.abc import Callable
from typing import Any
import flax.linen as nn
import jax
import jax.numpy as jnp
from openpi.models import resnet as models_resnet
Array = Any
PRNGKey = Any
Shape = tuple[int, ...]
Dtype = Any
class IdentityLayer(nn.Module):
"""Identity layer, convenient for giving a name to an array."""
@nn.compact
def __call__(self, x):
return x
class AddPositionEmbs(nn.Module):
"""Adds learned positional embeddings to the inputs.
Attributes:
posemb_init: positional embedding initializer.
"""
posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
param_dtype: Dtype = jnp.float32
@nn.compact
def __call__(self, inputs):
"""Applies the AddPositionEmbs module.
Args:
inputs: Inputs to the layer.
Returns:
Output tensor with shape `(bs, timesteps, in_dim)`.
"""
# inputs.shape is (batch_size, seq_len, emb_dim).
assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}"
pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape, self.param_dtype)
return inputs + pe
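# Illustrative usage (a sketch, not part of the upstream API): the positional embedding
# is a learned parameter whose shape is tied to the sequence length and embedding size
# seen at initialization.
#
#   pos = AddPositionEmbs(posemb_init=nn.initializers.normal(stddev=0.02))
#   tokens = jnp.zeros((2, 16, 256))  # (batch, seq_len, emb_dim)
#   variables = pos.init(jax.random.PRNGKey(0), tokens)
#   out = pos.apply(variables, tokens)  # same shape as `tokens`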
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
mlp_dim: int
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
out_dim: int | None = None
dropout_rate: float = 0.1
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)
@nn.compact
def __call__(self, inputs, *, deterministic):
"""Applies Transformer MlpBlock module."""
actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
x = nn.Dense(
features=self.mlp_dim,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
)( # pytype: disable=wrong-arg-types
inputs
)
x = nn.gelu(x)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
output = nn.Dense(
features=actual_out_dim,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
)( # pytype: disable=wrong-arg-types
x
)
return nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic)
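# Illustrative usage (a sketch): the hidden width is `mlp_dim` and the output width
# defaults back to the input width, as in the standard Transformer feed-forward block.
#
#   mlp = MlpBlock(mlp_dim=1024, dropout_rate=0.0)
#   x = jnp.zeros((2, 16, 256))
#   variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
#   y = mlp.apply(variables, x, deterministic=True)  # shape (2, 16, 256)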
class Encoder1DBlock(nn.Module):
"""Transformer encoder layer.
Attributes:
inputs: input data.
mlp_dim: dimension of the mlp on top of attention block.
dtype: the dtype of the computation (default: float32).
dropout_rate: dropout rate.
attention_dropout_rate: dropout for attention heads.
deterministic: bool, deterministic or not (to apply dropout).
num_heads: Number of heads in nn.MultiHeadDotProductAttention
"""
mlp_dim: int
num_heads: int
dtype: Dtype = jnp.float32
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
@nn.compact
def __call__(self, inputs, deterministic):
"""Applies Encoder1DBlock module.
Args:
inputs: Inputs to the layer.
deterministic: Dropout will not be applied when set to true.
Returns:
output after transformer encoder block.
"""
# Attention block.
assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
x = nn.LayerNorm(dtype=self.dtype)(inputs)
x = nn.MultiHeadDotProductAttention(
dtype=self.dtype,
kernel_init=nn.initializers.xavier_uniform(),
broadcast_dropout=False,
deterministic=deterministic,
dropout_rate=self.attention_dropout_rate,
num_heads=self.num_heads,
# Compute the attention softmax in float32 even when `dtype` is lower precision,
# which avoids overflow/accuracy issues in mixed-precision training.
force_fp32_for_softmax=True,
)(x, x)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
x = x + inputs
# MLP block.
y = nn.LayerNorm(dtype=self.dtype)(x)
y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(
y, deterministic=deterministic
)
return x + y, None
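# Note: the `(activations, None)` return value follows the (carry, per-step output)
# contract of `nn.scan`, which the Encoder below uses to stack `num_layers` copies of
# this block. Standalone usage sketch (shapes are assumptions for illustration):
#
#   block = Encoder1DBlock(mlp_dim=1024, num_heads=8, dropout_rate=0.0)
#   x = jnp.zeros((2, 16, 256))
#   variables = block.init(jax.random.PRNGKey(0), x, True)  # deterministic=True
#   y, _ = block.apply(variables, x, True)                  # shape (2, 16, 256)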
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation.
Attributes:
num_layers: number of layers
mlp_dim: dimension of the mlp on top of attention block
num_heads: Number of heads in nn.MultiHeadDotProductAttention
dropout_rate: dropout rate.
attention_dropout_rate: dropout rate in self attention.
"""
dtype: jax.typing.DTypeLike
num_layers: int
mlp_dim: int
num_heads: int
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
add_position_embedding: bool = True
@nn.compact
def __call__(self, x, *, train):
"""Applies Transformer model on the inputs.
Args:
x: Inputs to the layer.
train: Set to `True` when training.
Returns:
output of a transformer encoder.
"""
assert x.ndim == 3 # (batch, len, emb)
if self.add_position_embedding:
x = AddPositionEmbs(
posemb_init=nn.initializers.normal(stddev=0.02), # from BERT.
name="posembed_input",
)(x)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
x = x.astype(self.dtype)
# Input Encoder
block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,))
x, _ = nn.scan(
block,
variable_axes={"params": 0},
split_rngs={"params": True, "dropout": True},
in_axes=nn.broadcast,
length=self.num_layers,
)(
name="encoderblock",
mlp_dim=self.mlp_dim,
dropout_rate=self.dropout_rate,
attention_dropout_rate=self.attention_dropout_rate,
dtype=self.dtype,
num_heads=self.num_heads,
)(x, not train)
return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x)
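# Illustrative usage (a sketch): a small standalone encoder over a token sequence.
# `dtype` controls activation precision; parameters are created in float32 by default.
#
#   enc = Encoder(dtype=jnp.float32, num_layers=2, mlp_dim=512, num_heads=4,
#                 dropout_rate=0.0, attention_dropout_rate=0.0)
#   tokens = jnp.zeros((2, 16, 128))
#   variables = enc.init(jax.random.PRNGKey(0), tokens, train=False)
#   out = enc.apply(variables, tokens, train=False)  # shape (2, 16, 128)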
class VisionTransformer(nn.Module):
"""VisionTransformer."""
dtype: jax.typing.DTypeLike
num_classes: int
patches: Any
transformer: Any
hidden_size: int
resnet: Any | None = None
representation_size: int | None = None
classifier: str = "token"
head_bias_init: float = 0.0
encoder: type[nn.Module] = Encoder
model_name: str | None = None
@nn.compact
def __call__(self, inputs, *, train):
x = inputs
# (Possibly partial) ResNet root.
if self.resnet is not None:
width = int(64 * self.resnet.width_factor)
# Root block.
x = models_resnet.StdConv(
features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root"
)(x)
x = nn.GroupNorm(name="gn_root")(x)
x = nn.relu(x)
x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")
# ResNet stages.
if self.resnet.num_layers:
x = models_resnet.ResNetStage(
block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1"
)(x)
for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
x = models_resnet.ResNetStage(
block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}"
)(x)
# We can merge the space-to-depth patch extraction and the patch embedding into a
# single strided conv; the result is the same.
x = nn.Conv(
features=self.hidden_size,
kernel_size=self.patches.size,
strides=self.patches.size,
padding="VALID",
name="embedding",
)(x)
# Here, x is a grid of embeddings.
# (Possibly partial) Transformer.
if self.transformer is not None:
n, h, w, c = x.shape
x = jnp.reshape(x, [n, h * w, c])
# If we want to add a class token, add it here.
if self.classifier in ["token", "token_unpooled"]:
cls = self.param("cls", nn.initializers.zeros, (1, 1, c))
cls = jnp.tile(cls, [n, 1, 1])
x = jnp.concatenate([cls, x], axis=1)
x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train)
if self.classifier == "token":
x = x[:, 0]
elif self.classifier == "gap":
x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2)
elif self.classifier in ["unpooled", "token_unpooled"]:
pass
else:
raise ValueError(f"Invalid classifier={self.classifier}")
if self.representation_size is not None:
x = nn.Dense(features=self.representation_size, name="pre_logits")(x)
x = nn.tanh(x)
else:
x = IdentityLayer(name="pre_logits")(x)
if self.num_classes:
x = nn.Dense(
features=self.num_classes,
name="head",
kernel_init=nn.initializers.zeros,
bias_init=nn.initializers.constant(self.head_bias_init),
)(x)
return x
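if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch; the tiny config below is an assumption
    # chosen for speed, not a released checkpoint configuration). `patches` only needs a
    # `.size` attribute, so a SimpleNamespace stands in for the usual config object.
    import types

    model = VisionTransformer(
        dtype=jnp.float32,
        num_classes=10,
        patches=types.SimpleNamespace(size=(16, 16)),
        transformer={
            "num_layers": 2,
            "mlp_dim": 128,
            "num_heads": 2,
            "dropout_rate": 0.0,
            "attention_dropout_rate": 0.0,
        },
        hidden_size=64,
        classifier="token",
    )
    images = jnp.zeros((1, 64, 64, 3))  # (batch, height, width, channels)
    variables = model.init(jax.random.PRNGKey(0), images, train=False)
    logits = model.apply(variables, images, train=False)
    print(logits.shape)  # expected: (1, 10)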