|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py.""" |
|
|
|
|
|
from collections.abc import Callable |
|
|
from typing import Any |
|
|
|
|
|
import flax.linen as nn |
|
|
import jax |
|
|
import jax.numpy as jnp |
|
|
|
|
|
from openpi.models import resnet as models_resnet |
|
|
|
|
|
# Loose type aliases used in initializer signatures below. The upstream ViT
# code keeps these as `Any` rather than pinning concrete JAX types.
Array = Any
PRNGKey = Any
# An arbitrary-rank shape. `tuple[int]` would mean a 1-tuple; `tuple[int, ...]`
# is the correct spelling for "tuple of ints of any length".
Shape = tuple[int, ...]
Dtype = Any
|
|
|
|
|
|
|
|
class IdentityLayer(nn.Module):
    """No-op module that returns its input unchanged.

    Convenient for attaching a readable name to an intermediate array in the
    module tree.
    """

    @nn.compact
    def __call__(self, x):
        return x
|
|
|
|
|
|
|
|
class AddPositionEmbs(nn.Module):
    """Adds a learned positional-embedding table to the input sequence.

    Attributes:
      posemb_init: initializer for the positional embedding table.
      param_dtype: dtype in which the embedding parameters are stored.
    """

    posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
    param_dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        """Applies the AddPositionEmbs module.

        Args:
          inputs: array of shape `(bs, timesteps, in_dim)`.

        Returns:
          `inputs` plus a positional-embedding table broadcast over the batch
          dimension; same shape as `inputs`.
        """
        assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}"
        _, timesteps, in_dim = inputs.shape
        # One embedding row per position, shared across the batch (leading 1).
        pos_embedding = self.param(
            "pos_embedding", self.posemb_init, (1, timesteps, in_dim), self.param_dtype
        )
        return inputs + pos_embedding
|
|
|
|
|
|
|
|
class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block: Dense -> GELU -> Dropout -> Dense -> Dropout."""

    mlp_dim: int
    dtype: Dtype = jnp.float32
    param_dtype: Dtype = jnp.float32
    out_dim: int | None = None
    dropout_rate: float = 0.1
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, *, deterministic):
        """Applies Transformer MlpBlock module."""
        # Project back to the input width unless an explicit output width is given.
        out_features = self.out_dim if self.out_dim is not None else inputs.shape[-1]
        # Both Dense layers share everything except their output width.
        dense_kwargs = dict(
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        hidden = nn.Dense(features=self.mlp_dim, **dense_kwargs)(inputs)
        hidden = nn.gelu(hidden)
        hidden = nn.Dropout(rate=self.dropout_rate)(hidden, deterministic=deterministic)
        projected = nn.Dense(features=out_features, **dense_kwargs)(hidden)
        return nn.Dropout(rate=self.dropout_rate)(projected, deterministic=deterministic)
|
|
|
|
|
|
|
|
class Encoder1DBlock(nn.Module):
    """Single pre-norm transformer encoder layer (self-attention + MLP, each with a residual).

    Attributes:
      mlp_dim: hidden width of the feed-forward sub-block.
      num_heads: number of heads in nn.MultiHeadDotProductAttention.
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout applied after attention and inside the MLP.
      attention_dropout_rate: dropout on the attention weights.
    """

    mlp_dim: int
    num_heads: int
    dtype: Dtype = jnp.float32
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, deterministic):
        """Applies Encoder1DBlock module.

        Args:
          inputs: `(batch, seq, hidden)` activations.
          deterministic: Dropout will not be applied when set to true.

        Returns:
          A `(output, None)` pair; the trailing `None` matches the
          carry/output convention expected by `nn.scan` in the Encoder.
        """
        assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"

        # Self-attention sub-block (pre-norm) with residual connection.
        attn_in = nn.LayerNorm(dtype=self.dtype)(inputs)
        attn_out = nn.MultiHeadDotProductAttention(
            dtype=self.dtype,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=self.attention_dropout_rate,
            num_heads=self.num_heads,
            # Run the attention softmax in float32 even under lower-precision compute.
            force_fp32_for_softmax=True,
        )(attn_in, attn_in)
        attn_out = nn.Dropout(rate=self.dropout_rate)(attn_out, deterministic=deterministic)
        residual = inputs + attn_out

        # Feed-forward sub-block (pre-norm) with residual connection.
        mlp_in = nn.LayerNorm(dtype=self.dtype)(residual)
        mlp_out = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(
            mlp_in, deterministic=deterministic
        )
        return residual + mlp_out, None
|
|
|
|
|
|
|
|
class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.

    Attributes:
      dtype: dtype the activations are cast to before the encoder stack.
      num_layers: number of layers
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: Number of heads in nn.MultiHeadDotProductAttention
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout rate in self attention.
      add_position_embedding: whether to add learned positional embeddings
        (and input dropout) before the encoder stack.
    """

    dtype: jax.typing.DTypeLike
    num_layers: int
    mlp_dim: int
    num_heads: int
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1
    add_position_embedding: bool = True

    @nn.compact
    def __call__(self, x, *, train):
        """Applies Transformer model on the inputs.

        Args:
          x: Inputs to the layer, shape `(batch, seq, hidden)`.
          train: Set to `True` when training; enables dropout.

        Returns:
          output of a transformer encoder.
        """
        assert x.ndim == 3  # (batch, len, emb)

        if self.add_position_embedding:
            x = AddPositionEmbs(
                posemb_init=nn.initializers.normal(stddev=0.02),
                name="posembed_input",
            )(x)
            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

        x = x.astype(self.dtype)

        # Rematerialize (gradient-checkpoint) each layer to save memory.
        # `deterministic` (positional arg index 2 of __call__, counting self)
        # must be marked static for remat.
        block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,))
        # Stack num_layers identical blocks via nn.scan: parameters gain a
        # leading per-layer axis, params/dropout RNGs are split per layer, and
        # the second call argument (`not train`) is broadcast to every layer.
        # Encoder1DBlock returns (carry, None), so the per-layer output is
        # discarded here.
        x, _ = nn.scan(
            block,
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            in_axes=nn.broadcast,
            length=self.num_layers,
        )(
            name="encoderblock",
            mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
            dtype=self.dtype,
            num_heads=self.num_heads,
        )(x, not train)
        return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x)
|
|
|
|
|
|
|
|
class VisionTransformer(nn.Module):
    """VisionTransformer.

    Attributes:
      dtype: dtype used for the transformer computation.
      num_classes: width of the classification head; a falsy value disables it.
      patches: config object with a `size` field giving the patch (h, w) shape.
      transformer: keyword-argument dict passed to `encoder`; None skips the
        transformer stack entirely.
      hidden_size: embedding width produced by the patch-projection conv.
      resnet: optional config (`width_factor`, `num_layers`) enabling a hybrid
        ResNet stem before patch embedding; None means pure patch embedding.
      representation_size: if set, adds a tanh-activated "pre_logits" Dense.
      classifier: one of "token", "gap", "unpooled", "token_unpooled".
      head_bias_init: constant used to initialize the head bias.
      encoder: encoder module class to instantiate for the transformer stack.
      model_name: optional identifier; not used in the forward pass.
    """

    dtype: jax.typing.DTypeLike
    num_classes: int
    patches: Any
    transformer: Any
    hidden_size: int
    resnet: Any | None = None
    representation_size: int | None = None
    classifier: str = "token"
    head_bias_init: float = 0.0
    encoder: type[nn.Module] = Encoder
    model_name: str | None = None

    @nn.compact
    def __call__(self, inputs, *, train):
        """Runs the ViT forward pass.

        Args:
          inputs: image batch, shape `(batch, height, width, channels)`.
          train: enables dropout inside the encoder when True.

        Returns:
          Head logits for the pooled classifiers, or (unpooled) token features
          when `classifier` is "unpooled"/"token_unpooled" or `num_classes`
          is falsy.
        """
        x = inputs

        # Optional hybrid ResNet stem before patch embedding.
        if self.resnet is not None:
            width = int(64 * self.resnet.width_factor)

            # ResNet root: 7x7 stride-2 conv + GroupNorm + ReLU + max-pool.
            x = models_resnet.StdConv(
                features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root"
            )(x)
            x = nn.GroupNorm(name="gn_root")(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")

            # ResNet stages: the first keeps stride 1; each later stage
            # downsamples by 2 and doubles the channel width.
            if self.resnet.num_layers:
                x = models_resnet.ResNetStage(
                    block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1"
                )(x)
                for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
                    x = models_resnet.ResNetStage(
                        block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}"
                    )(x)

        # Patch embedding: a patch-sized strided conv is equivalent to
        # splitting the image into patches and linearly projecting each one.
        # (A dead `n, h, w, c = x.shape` that previously sat here was removed:
        # the names were never read before being reassigned below.)
        x = nn.Conv(
            features=self.hidden_size,
            kernel_size=self.patches.size,
            strides=self.patches.size,
            padding="VALID",
            name="embedding",
        )(x)

        if self.transformer is not None:
            # Flatten the spatial grid into a token sequence.
            n, h, w, c = x.shape
            x = jnp.reshape(x, [n, h * w, c])

            # Prepend a learnable class token when a token classifier is used.
            if self.classifier in ["token", "token_unpooled"]:
                cls = self.param("cls", nn.initializers.zeros, (1, 1, c))
                cls = jnp.tile(cls, [n, 1, 1])
                x = jnp.concatenate([cls, x], axis=1)

            x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train)

        # Pooling strategy.
        if self.classifier == "token":
            x = x[:, 0]
        elif self.classifier == "gap":
            # Global average pool over all non-batch, non-feature axes.
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))
        elif self.classifier in ["unpooled", "token_unpooled"]:
            pass
        else:
            raise ValueError(f"Invalid classifier={self.classifier}")

        if self.representation_size is not None:
            x = nn.Dense(features=self.representation_size, name="pre_logits")(x)
            x = nn.tanh(x)
        else:
            x = IdentityLayer(name="pre_logits")(x)

        if self.num_classes:
            # Head with zero-initialized kernel and constant bias.
            x = nn.Dense(
                features=self.num_classes,
                name="head",
                kernel_init=nn.initializers.zeros,
                bias_init=nn.initializers.constant(self.head_bias_init),
            )(x)
        return x
|
|
|