# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A refactored and simplified ViT adaptation for Pi, taken from big_vision."""

from collections.abc import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np

import openpi.training.sharding as sharding


def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32):
    """Follows the MoCo v3 logic."""
    y, x = jnp.mgrid[:h, :w]

    assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
    omega = jnp.arange(width // 4) / (width // 4 - 1)
    omega = 1.0 / (temperature**omega)
    y = jnp.einsum("m,d->md", y.flatten(), omega)
    x = jnp.einsum("m,d->md", x.flatten(), omega)
    pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
    return jnp.asarray(pe, dtype)[None, :, :]


def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):
    if typ == "learn":
        return self.param(
            name,
            nn.initializers.normal(stddev=1 / np.sqrt(width)),
            (1, np.prod(seqshape), width),
            dtype,
        )
    if typ == "sincos2d":
        return posemb_sincos_2d(*seqshape, width, dtype=dtype)
    raise ValueError(f"Unknown posemb type: {typ}")


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block."""

    mlp_dim: int | None = None  # Defaults to 4x input dim
    dropout: float = 0.0
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, x, deterministic=True):  # noqa: FBT002
        """Applies Transformer MlpBlock module."""
        inits = {
            "kernel_init": nn.initializers.xavier_uniform(),
            "bias_init": nn.initializers.normal(stddev=1e-6),
        }

        _, _, d = x.shape  # n,l,d
        x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout)(x, deterministic)
        return nn.Dense(d, dtype=self.dtype_mm, **inits)(x)


class Encoder1DBlock(nn.Module):
    """Single transformer encoder block (MHSA + MLP)."""

    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    dropout: float = 0.0
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, x, deterministic=True):  # noqa: FBT002
        out = {}
        x = sharding.activation_sharding_constraint(x)
        y = nn.LayerNorm(dtype=self.dtype_mm)(x)
        y = out["sa"] = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            kernel_init=nn.initializers.xavier_uniform(),
            deterministic=deterministic,
            dtype=self.dtype_mm,
        )(y, y)
        y = sharding.activation_sharding_constraint(y)
        y = nn.Dropout(rate=self.dropout)(y, deterministic)
        x = out["+sa"] = x + y

        y = nn.LayerNorm(dtype=self.dtype_mm)(x)
        y = out["mlp"] = MlpBlock(
            mlp_dim=self.mlp_dim,
            dropout=self.dropout,
            dtype_mm=self.dtype_mm,
        )(y, deterministic)
        y = sharding.activation_sharding_constraint(y)
        y = nn.Dropout(rate=self.dropout)(y, deterministic)
        x = out["+mlp"] = x + y
        x = sharding.activation_sharding_constraint(x)
        return x, out


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    depth: int
    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    dropout: float = 0.0
    scan: bool = False
    remat_policy: str = "nothing_saveable"
    dtype_mm: str = "float32"
"float32" @nn.compact def __call__(self, x, deterministic=True): # noqa: FBT002 out = {} if self.scan: block = nn.remat( Encoder1DBlock, prevent_cse=False, static_argnums=(2,), # 0=self, 2=deterministic policy=getattr(jax.checkpoint_policies, self.remat_policy, None), ) x, scan_out = nn.scan( block, variable_axes={"params": 0}, split_rngs={"params": True, "dropout": True}, in_axes=nn.broadcast, length=self.depth, )( name="encoderblock", dtype_mm=self.dtype_mm, mlp_dim=self.mlp_dim, num_heads=self.num_heads, dropout=self.dropout, )(x, deterministic) for lyr in range(self.depth): out[f"block{lyr:02d}"] = jax.tree.map(lambda o, lyr=lyr: o[lyr], scan_out) else: # Input Encoder for lyr in range(self.depth): block_cur = Encoder1DBlock( name=f"encoderblock_{lyr}", dtype_mm=self.dtype_mm, mlp_dim=self.mlp_dim, num_heads=self.num_heads, dropout=self.dropout, ) x, out[f"block{lyr:02d}"] = block_cur(x, deterministic) out["pre_ln"] = x # Alias for last block, but without the number in it. return nn.LayerNorm(name="encoder_norm", dtype=self.dtype_mm)(x), out class MAPHead(nn.Module): """Multihead Attention Pooling.""" mlp_dim: int | None = None # Defaults to 4x input dim num_heads: int = 12 dtype_mm: str = "float32" @nn.compact def __call__(self, x): n, _, d = x.shape # n,l,d probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, d), x.dtype) probe = jnp.tile(probe, [n, 1, 1]) x = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, dtype=self.dtype_mm, kernel_init=nn.initializers.xavier_uniform(), )(probe, x) y = nn.LayerNorm(dtype=self.dtype_mm)(x) x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype_mm)(y) return x[:, 0] class _Module(nn.Module): """ViT model.""" num_classes: int | None = None patch_size: Sequence[int] = (16, 16) width: int = 768 depth: int = 12 mlp_dim: int | None = None # Defaults to 4x input dim num_heads: int = 12 posemb: str = "learn" # Can also be "sincos2d" rep_size: int | bool = False dropout: float = 0.0 pool_type: str = "gap" # Can also be "map" or "tok" head_zeroinit: bool = True scan: bool = False # or "dots_with_no_batch_dims_saveable" for more speed (memory costly) remat_policy: str = "nothing_saveable" dtype_mm: str = "float32" @nn.compact def __call__(self, image, *, train=False): out = {} # Kevin edit: do patch extraction and posemb in float32, # because I feel like it's a bit safer. image = jnp.asarray(image, jnp.float32) # Patch extraction x = out["stem"] = nn.Conv( self.width, self.patch_size, strides=self.patch_size, padding="VALID", name="embedding", dtype=jnp.float32, )(image) n, h, w, c = x.shape x = jnp.reshape(x, [n, h * w, c]) # Add posemb before adding extra token. 
x = out["with_posemb"] = x + get_posemb(self, self.posemb, (h, w), c, "pos_embedding", jnp.float32) if self.pool_type == "tok": cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype) x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1) n, _, c = x.shape # n,l,d x = nn.Dropout(rate=self.dropout)(x, not train) # Kevin edit: now cast back to dtype_mm (potentially half precision) x = x.astype(self.dtype_mm) x, out["encoder"] = Encoder( depth=self.depth, mlp_dim=self.mlp_dim, num_heads=self.num_heads, dropout=self.dropout, scan=self.scan, remat_policy=self.remat_policy, dtype_mm=self.dtype_mm, name="Transformer", )(x, deterministic=not train) encoded = out["encoded"] = x if self.pool_type == "map": x = out["head_input"] = MAPHead( num_heads=self.num_heads, mlp_dim=self.mlp_dim, dtype=self.dtype_mm, )(x) elif self.pool_type == "gap": x = out["head_input"] = jnp.mean(x, axis=1) elif self.pool_type == "0": x = out["head_input"] = x[:, 0] elif self.pool_type == "tok": x = out["head_input"] = x[:, 0] encoded = encoded[:, 1:] elif self.pool_type == "none": pass else: raise ValueError(f"Unknown pool type: '{self.pool_type}'") x_2d = jnp.reshape(encoded, [n, h, w, -1]) if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size hid = nn.Dense(rep_size, dtype=self.dtype_mm, name="pre_logits") # NOTE: In the past we did not include tanh in pre_logits. # For few-shot, it should not matter much, as it whitens anyways. x_2d = nn.tanh(hid(x_2d)) x = nn.tanh(hid(x)) out["pre_logits_2d"] = x_2d out["pre_logits"] = x if self.num_classes: kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {} head = nn.Dense(self.num_classes, dtype=self.dtype_mm, name="head", **kw) x_2d = out["logits_2d"] = head(x_2d) x = out["logits"] = head(x) return x, out def Module(num_classes=None, *, variant=None, **kw): # pylint: disable=invalid-name # noqa: N802 """Factory function, because linen really don't like what I'm doing!""" return _Module(num_classes, **{**decode_variant(variant), **kw}) def decode_variant(variant): """Converts a string like "B" or "B/32" into a params dict.""" if variant is None: return {} v, patch = variant, {} if "/" in variant: v, patch = variant.split("/") patch = {"patch_size": (int(patch), int(patch))} return { # pylint:disable=line-too-long # Reference: Table 2 of https://arxiv.org/abs/2106.04560. "width": { "mu": 32, "Ti": 192, "S": 384, "M": 512, "B": 768, "L": 1024, "So400m": 1152, "H": 1280, "g": 1408, "g-opt": 1536, "G": 1664, "G-opt": 1536, "e": 1792, }[v], "depth": { "mu": 1, "Ti": 12, "S": 12, "M": 12, "B": 12, "L": 24, "So400m": 27, "H": 32, "g": 40, "g-opt": 40, "G": 48, "G-opt": 48, "e": 56, }[v], "mlp_dim": { "mu": 128, "Ti": 768, "S": 1536, "M": 2048, "B": 3072, "L": 4096, "So400m": 4304, "H": 5120, "g": 6144, "g-opt": 6144, "G": 8192, "G-opt": 8192, "e": 15360, }[v], "num_heads": { "mu": 2, "Ti": 3, "S": 6, "M": 8, "B": 12, "L": 16, "So400m": 16, "H": 16, "g": 16, "g-opt": 16, "G": 16, "G-opt": 16, "e": 16, }[v], # pylint:enable=line-too-long **patch, }