# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py."""

from collections.abc import Callable
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp

from openpi.models import resnet as models_resnet

Array = Any
PRNGKey = Any
Shape = tuple[int]
Dtype = Any


class IdentityLayer(nn.Module):
    """Identity layer, convenient for giving a name to an array."""

    @nn.compact
    def __call__(self, x):
        return x


class AddPositionEmbs(nn.Module):
    """Adds learned positional embeddings to the inputs.

    Attributes:
      posemb_init: positional embedding initializer.
    """

    posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
    param_dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        """Applies the AddPositionEmbs module.

        Args:
          inputs: Inputs to the layer.

        Returns:
          Output tensor with shape `(bs, timesteps, in_dim)`.
        """
        # inputs.shape is (batch_size, seq_len, emb_dim).
        assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}"
        pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
        pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape, self.param_dtype)
        return inputs + pe


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block."""

    mlp_dim: int
    dtype: Dtype = jnp.float32
    param_dtype: Dtype = jnp.float32
    out_dim: int | None = None
    dropout_rate: float = 0.1
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, *, deterministic):
        """Applies Transformer MlpBlock module."""
        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        x = nn.Dense(
            features=self.mlp_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(  # pytype: disable=wrong-arg-types
            inputs
        )
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        output = nn.Dense(
            features=actual_out_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(  # pytype: disable=wrong-arg-types
            x
        )
        return nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic)
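

# Illustrative usage sketch (not part of the adapted ViT code): shows how MlpBlock is
# initialized and applied as a standalone Flax module. The batch/sequence/feature sizes
# below are arbitrary example values.
def _example_mlp_block():
    rng = jax.random.PRNGKey(0)
    x = jnp.ones((2, 16, 32))  # (batch, seq_len, emb_dim)
    block = MlpBlock(mlp_dim=128, dropout_rate=0.1)
    # With deterministic=True dropout is disabled, so only a "params" RNG is needed.
    variables = block.init(rng, x, deterministic=True)
    y = block.apply(variables, x, deterministic=True)
    assert y.shape == x.shape  # out_dim defaults to the input feature dimension.
    return y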


class Encoder1DBlock(nn.Module):
    """Transformer encoder layer.

    Attributes:
      inputs: input data.
      mlp_dim: dimension of the mlp on top of attention block.
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout for attention heads.
      deterministic: bool, deterministic or not (to apply dropout).
      num_heads: Number of heads in nn.MultiHeadDotProductAttention
    """

    mlp_dim: int
    num_heads: int
    dtype: Dtype = jnp.float32
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, deterministic):
        """Applies Encoder1DBlock module.

        Args:
          inputs: Inputs to the layer.
          deterministic: Dropout will not be applied when set to true.

        Returns:
          output after transformer encoder block.
        """
        # Attention block.
        assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = nn.MultiHeadDotProductAttention(
            dtype=self.dtype,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=self.attention_dropout_rate,
            num_heads=self.num_heads,
            # why isn't this true by default???
            force_fp32_for_softmax=True,
        )(x, x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(
            y, deterministic=deterministic
        )

        return x + y, None


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.

    Attributes:
      num_layers: number of layers
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: Number of heads in nn.MultiHeadDotProductAttention
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout rate in self attention.
    """

    dtype: jax.typing.DTypeLike
    num_layers: int
    mlp_dim: int
    num_heads: int
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1
    add_position_embedding: bool = True

    @nn.compact
    def __call__(self, x, *, train):
        """Applies Transformer model on the inputs.

        Args:
          x: Inputs to the layer.
          train: Set to `True` when training.

        Returns:
          output of a transformer encoder.
        """
        assert x.ndim == 3  # (batch, len, emb)

        if self.add_position_embedding:
            x = AddPositionEmbs(
                posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
                name="posembed_input",
            )(x)
            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

        x = x.astype(self.dtype)

        # Input Encoder
        block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,))
        x, _ = nn.scan(
            block,
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            in_axes=nn.broadcast,
            length=self.num_layers,
        )(
            name="encoderblock",
            mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
            dtype=self.dtype,
            num_heads=self.num_heads,
        )(x, not train)

        return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x)
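

# Illustrative usage sketch (not part of the adapted ViT code): runs the scanned Encoder
# stack on a dummy token sequence. Layer count, widths, and shapes are arbitrary example
# values.
def _example_encoder():
    rng = jax.random.PRNGKey(0)
    tokens = jnp.ones((2, 16, 64))  # (batch, seq_len, hidden)
    encoder = Encoder(dtype=jnp.float32, num_layers=2, mlp_dim=256, num_heads=4)
    variables = encoder.init(rng, tokens, train=False)
    # train=True enables dropout, which requires a "dropout" RNG at apply time.
    out = encoder.apply(variables, tokens, train=True, rngs={"dropout": jax.random.PRNGKey(1)})
    assert out.shape == tokens.shape
    return out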


class VisionTransformer(nn.Module):
    """VisionTransformer."""

    dtype: jax.typing.DTypeLike
    num_classes: int
    patches: Any
    transformer: Any
    hidden_size: int
    resnet: Any | None = None
    representation_size: int | None = None
    classifier: str = "token"
    head_bias_init: float = 0.0
    encoder: type[nn.Module] = Encoder
    model_name: str | None = None

    @nn.compact
    def __call__(self, inputs, *, train):
        x = inputs

        # (Possibly partial) ResNet root.
        if self.resnet is not None:
            width = int(64 * self.resnet.width_factor)

            # Root block.
            x = models_resnet.StdConv(
                features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root"
            )(x)
            x = nn.GroupNorm(name="gn_root")(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")

            # ResNet stages.
            if self.resnet.num_layers:
                x = models_resnet.ResNetStage(
                    block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1"
                )(x)
                for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
                    x = models_resnet.ResNetStage(
                        block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}"
                    )(x)

        n, h, w, c = x.shape

        # We can merge s2d+emb into a single conv; it's the same.
        x = nn.Conv(
            features=self.hidden_size,
            kernel_size=self.patches.size,
            strides=self.patches.size,
            padding="VALID",
            name="embedding",
        )(x)

        # Here, x is a grid of embeddings.

        # (Possibly partial) Transformer.
        if self.transformer is not None:
            n, h, w, c = x.shape
            x = jnp.reshape(x, [n, h * w, c])

            # If we want to add a class token, add it here.
            if self.classifier in ["token", "token_unpooled"]:
                cls = self.param("cls", nn.initializers.zeros, (1, 1, c))
                cls = jnp.tile(cls, [n, 1, 1])
                x = jnp.concatenate([cls, x], axis=1)

            x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train)

        if self.classifier == "token":
            x = x[:, 0]
        elif self.classifier == "gap":
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
        elif self.classifier in ["unpooled", "token_unpooled"]:
            pass
        else:
            raise ValueError(f"Invalid classifier={self.classifier}")

        if self.representation_size is not None:
            x = nn.Dense(features=self.representation_size, name="pre_logits")(x)
            x = nn.tanh(x)
        else:
            x = IdentityLayer(name="pre_logits")(x)

        if self.num_classes:
            x = nn.Dense(
                features=self.num_classes,
                name="head",
                kernel_init=nn.initializers.zeros,
                bias_init=nn.initializers.constant(self.head_bias_init),
            )(x)

        return x
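

# Illustrative usage sketch (not part of the adapted ViT code): builds a tiny
# VisionTransformer and runs a forward pass on a dummy image batch. `patches` only needs
# a `.size` attribute and `transformer` only needs to unpack into Encoder kwargs, so a
# SimpleNamespace and a plain dict stand in for the upstream ml_collections configs.
# All sizes are arbitrary example values.
def _example_vision_transformer():
    import types

    rng = jax.random.PRNGKey(0)
    images = jnp.ones((2, 32, 32, 3))  # (batch, height, width, channels)
    model = VisionTransformer(
        dtype=jnp.float32,
        num_classes=10,
        patches=types.SimpleNamespace(size=(8, 8)),
        transformer={"num_layers": 2, "mlp_dim": 256, "num_heads": 4},
        hidden_size=64,
        classifier="token",
    )
    variables = model.init(rng, images, train=False)
    logits = model.apply(variables, images, train=False)
    assert logits.shape == (2, 10)
    return logits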