# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layers and Modules for Knowledge-FID."""

import functools
from typing import Optional, Sequence

from flax import linen as nn
import jax
import jax.numpy as jnp
from scenic.projects.knowledge_visual_language.models import constants
from t5x.examples.t5 import layers as t5_layers
from t5x.examples.t5 import network as t5_network


@jax.vmap
def batch_index_select(data, idx):
  return jnp.take(data, idx, axis=0)


def _mask_select(data, mask):
  return jax.lax.select(
      mask > 0, data, jnp.full(data.shape, 0).astype(data.dtype)
  )


def l2_norm(x):
  """Computes the L2 norm of a vector."""
  return jnp.sqrt((x * x).sum(axis=-1))


def l2_normalize(x, axis=-1, eps=1e-10):
  """Normalizes along dimension `axis` using an L2 norm.

  This specialized function exists for numerical stability reasons.

  Args:
    x: An input ndarray.
    axis: Dimension along which to normalize, e.g. `1` to separately normalize
      vectors in a batch. Passing `None` views `x` as a flattened vector when
      calculating the norm (equivalent to Frobenius norm).
    eps: Epsilon to avoid dividing by zero.

  Returns:
    An array of the same shape as `x`, L2-normalized along `axis`.
  """
  denorm = (x * x).sum(axis=axis, keepdims=True) + eps
  return (x * jax.lax.rsqrt(denorm)).astype(x.dtype)
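# Illustrative shape note for the helpers above (toy values are assumptions,
# not taken from the original code):
#
#   data = jnp.arange(2 * 4 * 3).reshape(2, 4, 3)   # [batch=2, n=4, d=3]
#   idx = jnp.array([[0, 2], [1, 3]])               # [batch=2, k=2]
#   out = batch_index_select(data, idx)             # out.shape == (2, 2, 3)
#
# `batch_index_select` is vmapped over the leading batch axis, so it gathers
# data[b, idx[b]] for every batch element b. `l2_normalize(x)` rescales the
# last axis to (approximately) unit L2 norm, so l2_norm(l2_normalize(x)) is
# close to 1 wherever x is nonzero.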
class AffineTransform(nn.Module):
  """Applies an affine transform for modulating attention scores."""

  @nn.compact
  def __call__(self, x):
    scale = self.param('scale', nn.initializers.ones, (1,), jnp.float32)
    bias = self.param('bias', nn.initializers.zeros, (1,), jnp.float32)
    return x * nn.sigmoid(scale) * 5 + bias


class TransformerHead(nn.Module):
  """A stack of encoder layers."""

  num_head_layers: int
  key_dim: int
  vocab_size: int
  emb_dim: int
  num_heads: int
  num_encoder_layers: int
  num_decoder_layers: int
  head_dim: int
  mlp_dim: int
  dropout_rate: float
  out_head: nn.Module
  dtype: str = 'bfloat16'
  mlp_activations: Sequence[str] = ('gelu', 'linear')
  logits_via_embedding: bool = False

  def setup(self):
    self.t5_config = t5_network.T5Config(
        vocab_size=self.vocab_size,
        emb_dim=self.emb_dim,
        num_heads=self.num_heads,
        num_encoder_layers=self.num_encoder_layers,
        num_decoder_layers=self.num_decoder_layers,
        head_dim=self.head_dim,
        mlp_dim=self.mlp_dim,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
        mlp_activations=self.mlp_activations,
        logits_via_embedding=self.logits_via_embedding,
    )

  @nn.compact
  def __call__(self, encoded_emb, encoder_mask=None, use_dropout=True):
    """Transforms the encoded representation."""
    cfg = self.t5_config
    assert encoded_emb.ndim == 3  # [batch, length, emb_dim]
    x = encoded_emb
    if encoder_mask is not None:
      encoder_mask = t5_layers.make_attention_mask(
          encoder_mask, encoder_mask, dtype=cfg.dtype
      )
    rel_emb = t5_layers.RelativePositionBiases(
        num_buckets=32,
        max_distance=128,
        num_heads=self.num_heads,
        dtype=self.dtype,
        embedding_init=nn.initializers.variance_scaling(
            1.0, 'fan_avg', 'uniform'
        ),
    )
    for _ in range(
        cfg.num_encoder_layers - self.num_head_layers, cfg.num_encoder_layers
    ):
      # [batch, length, emb_dim] -> [batch, length, emb_dim]
      x = t5_network.EncoderLayer(config=cfg, relative_embedding=rel_emb)(
          x, encoder_mask, deterministic=not use_dropout
      )
    x = t5_layers.LayerNorm(dtype=cfg.dtype)(x[:, 0, :])
    return l2_normalize(self.out_head(x), axis=-1)
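# Illustrative usage sketch for TransformerHead (hyper-parameters and the
# Dense projection head are assumptions for the example): the module runs
# `num_head_layers` additional T5 encoder layers, pools the first token, and
# returns an L2-normalized projection from `out_head`.
#
#   head = TransformerHead(
#       num_head_layers=2, key_dim=64, vocab_size=32128, emb_dim=512,
#       num_heads=8, num_encoder_layers=12, num_decoder_layers=12,
#       head_dim=64, mlp_dim=1024, dropout_rate=0.1,
#       out_head=nn.Dense(features=64))
#   # encoded_emb: [batch, length, emb_dim] -> output: [batch, 64] unit
#   # vectors.
#   params = head.init(rng, encoded_emb, encoder_mask, use_dropout=False)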
class LowerT5Encoder(nn.Module):
  """Lower layers of a T5 encoder, applied to text before multi-modal fusion.

  This module contains the encoder part of a pretrained T5. It is useful when
  adopting the pretrained T5 encoder as a part of a larger network. Note that
  the embedding layer should be created outside the module and provided as a
  parameter `shared_embedding` to share it in other parts of the network
  (e.g., text encoder). If `shared_embedding` is not provided, the embedding
  layer is created within the module.

  Attributes:
    vocab_size: Size of the vocabulary.
    emb_dim: Size of the embeddings.
    num_heads: Number of attention heads.
    num_encoder_layers: Number of encoder layers.
    num_decoder_layers: Number of decoder layers.
    num_fusion_layers: Number of top encoder layers reserved for fusion.
    head_dim: Size of the embeddings in each head.
    mlp_dim: Size of the MLP output embeddings.
    dropout_rate: Dropout rate.
    dtype: Data type.
    mlp_activations: Sequence of activations in MLP.
    logits_via_embedding: Use the embedding weights for computing logits.
    shared_embedding: Optional. Embedding layer that is shared outside this
      module. If not given, a non-shared embedding layer will be created
      within the module.
  """

  vocab_size: int
  emb_dim: int
  num_heads: int
  num_encoder_layers: int
  num_decoder_layers: int
  num_fusion_layers: int
  head_dim: int
  mlp_dim: int
  dropout_rate: float
  dtype: str = 'bfloat16'
  mlp_activations: Sequence[str] = ('gelu', 'linear')
  logits_via_embedding: bool = False
  shared_embedding: Optional[nn.Module] = None

  def setup(self):
    self.t5_config = t5_network.T5Config(
        vocab_size=self.vocab_size,
        emb_dim=self.emb_dim,
        num_heads=self.num_heads,
        num_encoder_layers=self.num_encoder_layers,
        num_decoder_layers=self.num_decoder_layers,
        head_dim=self.head_dim,
        mlp_dim=self.mlp_dim,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
        mlp_activations=self.mlp_activations,
        logits_via_embedding=self.logits_via_embedding,
    )
    if self.shared_embedding is None:
      self.shared_embedding = t5_layers.Embed(
          num_embeddings=self.vocab_size,
          features=self.emb_dim,
          dtype=self.dtype,
          attend_dtype=jnp.float32,  # For logit training stability.
          embedding_init=nn.initializers.normal(stddev=1.0),
          one_hot=True,
      )

  @nn.compact
  def __call__(
      self,
      encoder_input_tokens,
      encoder_segment_ids=None,
      use_dropout=True,
      frozen_base=True,
  ):
    """Encodes the text sentence only.

    Args:
      encoder_input_tokens: Input text tokens.
      encoder_segment_ids: Segment IDs in packing mode.
      use_dropout: Whether to use dropout during training.
      frozen_base: Whether to freeze the lower layers of the text encoder.

    Returns:
      Tuple of the token embedding sequence and the padding mask.
    """
    cfg = self.t5_config
    assert encoder_input_tokens.ndim == 2  # (batch, len)

    # Make padding attention mask.
    encoder_mask = encoder_input_tokens > 0
    mask_matrix = t5_layers.make_attention_mask(
        encoder_input_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype
    )
    # Add segmentation block-diagonal attention mask if using segmented data.
    if encoder_segment_ids is not None:
      mask_matrix = t5_layers.combine_masks(
          mask_matrix,
          t5_layers.make_attention_mask(
              encoder_segment_ids,
              encoder_segment_ids,
              jnp.equal,
              dtype=cfg.dtype,
          ),
      )
    rel_emb = t5_layers.RelativePositionBiases(
        num_buckets=32,
        max_distance=128,
        num_heads=self.t5_config.num_heads,
        dtype=self.t5_config.dtype,
        embedding_init=nn.initializers.variance_scaling(
            1.0, 'fan_avg', 'uniform'
        ),
    )
    # [batch, length] -> [batch, length, emb_dim]
    x = self.shared_embedding(encoder_input_tokens.astype('int32'))
    x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(
        x, deterministic=not use_dropout
    )
    x = x.astype(cfg.dtype)
    n_layer = cfg.num_encoder_layers - self.num_fusion_layers
    frozen_layer_id = int(n_layer * 0.8) - 1
    for lyr in range(n_layer):
      # [batch, length, emb_dim] -> [batch, length, emb_dim]
      x = t5_network.EncoderLayer(config=cfg, relative_embedding=rel_emb)(
          x, mask_matrix, deterministic=not use_dropout
      )
      if frozen_base and lyr == frozen_layer_id:
        x = jax.lax.stop_gradient(x)
    return x, encoder_mask
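# Worked example for the freezing logic above (layer counts are assumed for
# illustration): with num_encoder_layers=24 and num_fusion_layers=4, the
# lower encoder runs n_layer = 24 - 4 = 20 layers and
# frozen_layer_id = int(20 * 0.8) - 1 = 15. With frozen_base=True a
# stop_gradient is inserted after layer index 15, so only the top four lower
# layers (indices 16..19) receive gradient updates.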
class FusedT5Encoder(nn.Module):
  """T5 encoder as a separate model which fuses multi-modal input.

  This module contains the upper (fusion) encoder layers of a pretrained T5.
  It is useful when adopting the pretrained T5 encoder as a part of a larger
  network.

  Attributes:
    vocab_size: Size of the vocabulary.
    emb_dim: Size of the embeddings.
    num_heads: Number of attention heads.
    num_encoder_layers: Number of encoder layers.
    num_decoder_layers: Number of decoder layers.
    num_fusion_layers: Number of top encoder layers used for fusion.
    head_dim: Size of the embeddings in each head.
    mlp_dim: Size of the MLP output embeddings.
    dropout_rate: Dropout rate.
    dtype: Data type.
    mlp_activations: Sequence of activations in MLP.
    logits_via_embedding: Use the embedding weights for computing logits.
  """

  vocab_size: int
  emb_dim: int
  num_heads: int
  num_encoder_layers: int
  num_decoder_layers: int
  num_fusion_layers: int
  head_dim: int
  mlp_dim: int
  dropout_rate: float
  dtype: str = 'bfloat16'
  mlp_activations: Sequence[str] = ('gelu', 'linear')
  logits_via_embedding: bool = False

  def setup(self):
    self.t5_config = t5_network.T5Config(
        vocab_size=self.vocab_size,
        emb_dim=self.emb_dim,
        num_heads=self.num_heads,
        num_encoder_layers=self.num_encoder_layers,
        num_decoder_layers=self.num_decoder_layers,
        head_dim=self.head_dim,
        mlp_dim=self.mlp_dim,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
        mlp_activations=self.mlp_activations,
        logits_via_embedding=self.logits_via_embedding,
    )

  @nn.compact
  def __call__(
      self,
      fused_input_embs,
      encoder_input_embs=None,
      encoder_mask=None,
      fused_mask=None,
      att_mask=None,
      use_dropout=True,
      output=False,
  ):
    """Fuses text and image embeddings.

    Encodes both the pre-encoded text embedding (`encoder_input_embs`) and the
    pre-encoded image embedding (`fused_input_embs`) together using a
    self-attentive Transformer.

    Args:
      fused_input_embs: Pre-encoded embeddings of other modalities.
      encoder_input_embs: Encoded text embedding sequence.
      encoder_mask: Mask for the encoding part.
      fused_mask: Mask for the fusion part.
      att_mask: Pre-computed attention weighting multiplied into each layer's
        output.
      use_dropout: Whether to use dropout.
      output: Whether this is the output (final) fusion block.

    Returns:
      Sequence of token embeddings after fusion, the fused mask, and the
      per-layer attention weights.
    """
    cfg = self.t5_config
    if encoder_input_embs is not None:
      x = jnp.concatenate([encoder_input_embs, fused_input_embs], axis=1)
    else:
      x = fused_input_embs
    rel_emb = t5_layers.RelativePositionBiases(
        num_buckets=32,
        max_distance=128,
        num_heads=self.t5_config.num_heads,
        dtype=self.t5_config.dtype,
        embedding_init=nn.initializers.variance_scaling(
            1.0, 'fan_avg', 'uniform'
        ),
    )
    if encoder_mask is not None:
      if fused_mask is None:
        pad_width = fused_input_embs.shape[1]
        fused_mask = jnp.pad(
            array=encoder_mask,
            pad_width=((0, 0), (0, pad_width)),
            mode='constant',
            constant_values=1.0,
        )
      else:
        fused_mask = jnp.concatenate([encoder_mask, fused_mask], axis=1)
    mask_matrix = t5_layers.make_attention_mask(
        fused_mask, fused_mask, dtype=cfg.dtype
    )
    attn_weights_all_layers = []
    for _ in range(
        cfg.num_encoder_layers - self.num_fusion_layers, cfg.num_encoder_layers
    ):
      # [batch, length, emb_dim] -> [batch, length, emb_dim]
      x, attn_weights = FusionEncoderLayer(
          config=cfg, relative_embedding=rel_emb
      )(
          x,
          encoder_mask=mask_matrix,
          att_mask=att_mask,
          deterministic=not use_dropout,
      )
      attn_weights_all_layers += [attn_weights]
    if output:
      x = t5_layers.LayerNorm(dtype=cfg.dtype)(x)
      if att_mask is not None:
        x = x * att_mask
      x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=not use_dropout)
    return x, fused_mask, attn_weights_all_layers
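# Illustrative note on the mask handling in FusedT5Encoder (shapes are
# assumptions): when `fused_mask` is None, the fused (e.g. image) tokens are
# treated as all valid, so the text padding mask is right-padded with ones.
#
#   encoder_mask:     [batch, text_len]            e.g. (2, 16)
#   fused_input_embs: [batch, img_len, emb_dim]    e.g. (2, 8, 512)
#   fused_mask = jnp.pad(encoder_mask, ((0, 0), (0, 8)), constant_values=1.0)
#   mask_matrix:      [batch, 1, 24, 24]
#
# Every fused token can therefore attend to, and be attended by, all
# non-padding positions.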
class FusionEncoderLayer(nn.Module):
  """Transformer encoder layer."""

  config: t5_network.T5Config
  relative_embedding: nn.Module

  @nn.compact
  def __call__(
      self, inputs, att_mask=None, encoder_mask=None, deterministic=False
  ):
    cfg = self.config

    # Relative position embedding as attention biases.
    encoder_bias = self.relative_embedding(
        inputs.shape[-2], inputs.shape[-2], True
    )

    # Attention block.
    assert inputs.ndim == 3
    x = t5_layers.LayerNorm(dtype=cfg.dtype)(inputs)
    if att_mask is not None:
      x = x * att_mask
    # [batch, length, emb_dim] -> [batch, length, emb_dim]
    x, attn_weights = MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        head_dim=cfg.head_dim,
        dropout_rate=cfg.dropout_rate,
        float32_logits=cfg.float32_attention_logits,
    )(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
    x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(
        x, deterministic=deterministic
    )
    x = x + inputs

    # MLP block.
    y = t5_layers.LayerNorm(dtype=cfg.dtype)(x)
    # [batch, length, emb_dim] -> [batch, length, emb_dim]
    y = t5_layers.MlpBlock(
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
    )(y, deterministic=deterministic)
    y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(
        y, deterministic=deterministic
    )
    y = y + x
    return y, attn_weights


class PerceiverEncoder(nn.Module):
  """Reimplementation of Perceiver.

  Perceiver: General Perception with Iterative Attention
  (https://arxiv.org/abs/2103.03206)
  """

  perceiver_output_dim: int
  vocab_size: int
  emb_dim: int
  num_heads: int
  num_encoder_layers: int
  num_decoder_layers: int
  num_fusion_layers: int
  head_dim: int
  mlp_dim: int
  dropout_rate: float
  dtype: str = 'bfloat16'
  mlp_activations: Sequence[str] = ('gelu', 'linear')
  logits_via_embedding: bool = False

  def setup(self):
    self.t5_config = t5_network.T5Config(
        vocab_size=self.vocab_size,
        emb_dim=self.emb_dim,
        num_heads=self.num_heads,
        num_encoder_layers=self.num_encoder_layers,
        num_decoder_layers=self.num_decoder_layers,
        head_dim=self.head_dim,
        mlp_dim=self.mlp_dim,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
        mlp_activations=self.mlp_activations,
        logits_via_embedding=self.logits_via_embedding,
    )
    self.perceive_embedding = self.param(
        'perceive_embedding',
        nn.initializers.normal(stddev=1.0),
        (1, self.perceiver_output_dim, self.emb_dim),
        jnp.float32,
    )
    v = jnp.arange(self.perceiver_output_dim)
    self.batch_triangle_select = jax.vmap(
        functools.partial(_mask_select, mask=v < v.reshape([-1, 1]))
    )

  def linear_disentangle(self, y):
    mean = y.mean(axis=-2, keepdims=True)
    norm_y = l2_normalize(y - mean)
    pairwise_mat = jnp.square(jnp.einsum('bqd,btd->bqt', norm_y, norm_y))
    masked_mat = self.batch_triangle_select(pairwise_mat)
    return jnp.mean(masked_mat)

  @nn.compact
  def __call__(self, encoded, encoded_mask, use_dropout=False):
    cfg = self.t5_config
    rel_emb = t5_layers.RelativePositionBiases(
        num_buckets=32,
        max_distance=128,
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        embedding_init=nn.initializers.variance_scaling(
            1.0, 'fan_avg', 'uniform'
        ),
    )
    # [batch, length] -> [batch, length, emb_dim]
    encoded = t5_layers.LayerNorm(dtype=cfg.dtype)(encoded)
    bsz = encoded.shape[0]
    y = jnp.asarray(self.perceive_embedding, dtype=cfg.dtype)
    y = jnp.repeat(y, bsz, axis=0)
    y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(
        y, deterministic=not use_dropout
    )
    y = y.astype(cfg.dtype)
    mask = jnp.ones([bsz, self.perceiver_output_dim]).astype(bool)
    encoder_decoder_mask = t5_layers.make_attention_mask(
        mask, encoded_mask, dtype=self.dtype
    )
    for _ in range(self.num_fusion_layers):
      # [batch, length, emb_dim] -> [batch, length, emb_dim]
      y = t5_network.DecoderLayer(config=cfg, relative_embedding=rel_emb)(
          y,
          encoded,
          deterministic=not use_dropout,
          encoder_decoder_mask=encoder_decoder_mask,
          decode=False,
      )
    return y * 4, mask, self.linear_disentangle(y)
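# Illustrative note on `linear_disentangle`: the perceiver latents are
# mean-centered and L2-normalized, so pairwise_mat[b, q, t] is the squared
# cosine similarity between latents q and t. `batch_triangle_select` zeroes
# everything except the strict lower triangle (entries with t < q, from
# `v < v.reshape([-1, 1])`), so the returned scalar is proportional to the
# sum of squared cosine similarities over distinct latent pairs, which can
# serve as a penalty encouraging the latent outputs to stay decorrelated.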
def dot_product_attention(
    query: constants.JTensor,
    key: constants.JTensor,
    value: constants.JTensor,
    bias: Optional[constants.JTensor] = None,
    dropout_rng: Optional[constants.JTensor] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: constants.DType = jnp.float32,
    float32_logits: bool = False,
):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights.

  Args:
    query: Queries for calculating attention with shape of `[batch, q_length,
      num_heads, qk_depth_per_head]`.
    key: Keys for calculating attention with shape of `[batch, kv_length,
      num_heads, qk_depth_per_head]`.
    value: Values to be used in attention with shape of `[batch, kv_length,
      num_heads, v_depth_per_head]`.
    bias: Bias for the attention weights. This should be broadcastable to the
      shape `[batch, num_heads, q_length, kv_length]`. This can be used for
      incorporating causal masks, padding masks, proximity bias, etc.
    dropout_rng: JAX PRNGKey to be used for dropout.
    dropout_rate: Dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    dtype: The dtype of the computation (default: float32).
    float32_logits: bool, if True then compute logits in float32 to avoid
      numerical issues with bfloat16.

  Returns:
    Output of shape `[batch, length, num_heads, v_depth_per_head]` and the
    attention weights of shape `[batch, num_heads, q_length, kv_length]`.
  """
  assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
  assert (
      query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
  ), 'q, k, v batch dims must match.'
  assert (
      query.shape[-2] == key.shape[-2] == value.shape[-2]
  ), 'q, k, v num_heads must match.'
  assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

  # Casting logits and softmax computation for float32 for model stability.
  if float32_logits:
    query = query.astype(jnp.float32)
    key = key.astype(jnp.float32)

  # `attn_weights`: [batch, num_heads, q_length, kv_length]
  attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)

  # Apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias.astype(attn_weights.dtype)

  # Normalize the attention weights across the `kv_length` dimension.
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

  # Apply attention dropout.
  if not deterministic and dropout_rate > 0.0:
    keep_prob = 1.0 - dropout_rate
    # T5 broadcasts along the "length" dim, but unclear which one that
    # corresponds to in positional dimensions here, assuming query dim.
    dropout_shape = list(attn_weights.shape)
    dropout_shape[-2] = 1
    keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    keep = jnp.broadcast_to(keep, attn_weights.shape)
    multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(
        keep_prob, dtype=dtype
    )
    attn_weights = attn_weights * multiplier

  # Take the linear combination of `value`.
  return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value), attn_weights
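# Minimal usage sketch for `dot_product_attention` (toy shapes are
# assumptions):
#
#   rng = jax.random.PRNGKey(0)
#   q = jax.random.normal(rng, (2, 5, 8, 64))  # [batch, q_len, heads, depth]
#   k = jax.random.normal(rng, (2, 7, 8, 64))  # [batch, kv_len, heads, depth]
#   v = jax.random.normal(rng, (2, 7, 8, 64))
#   out, w = dot_product_attention(q, k, v, deterministic=True)
#   # out.shape == (2, 5, 8, 64); w.shape == (2, 8, 5, 7)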
class MultiHeadDotProductAttention(nn.Module):
  """Multi-head dot-product attention.

  Attributes:
    num_heads: Number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    head_dim: Dimension of each head.
    dtype: The dtype of the computation.
    dropout_rate: Dropout rate.
    kernel_init: Initializer for the kernel of the Dense layers.
    float32_logits: bool, if True then compute logits in float32 to avoid
      numerical issues with bfloat16.
  """

  num_heads: int
  head_dim: int
  dtype: constants.DType = jnp.float32
  dropout_rate: float = 0.0
  kernel_init: constants.Initializer = nn.initializers.variance_scaling(
      1.0, 'fan_in', 'normal'
  )
  float32_logits: bool = False  # Computes logits in float32 for stability.

  @nn.compact
  def __call__(
      self,
      inputs_q: constants.JTensor,
      inputs_kv: constants.JTensor,
      mask: Optional[constants.JTensor] = None,
      bias: Optional[constants.JTensor] = None,
      *,
      decode: bool = False,
      deterministic: bool = False,
  ) -> constants.JTensor:
    """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output
    vector.

    There are two modes: decoding and non-decoding (e.g., training). The mode
    is determined by the `decode` argument. For decoding, this method is
    called twice, first to initialize the cache and then for the actual
    decoding process. The two calls are differentiated by the presence of
    'cached_key' in the variable dict. In the cache initialization stage, the
    cache variables are initialized as zeros and will be filled in the
    subsequent decoding process.

    In the cache initialization call, `inputs_q` has a shape [batch, length,
    q_features] and `inputs_kv` has a shape [batch, length, kv_features].
    During the incremental decoding stage, query, key and value all have the
    shape [batch, 1, qkv_features] corresponding to a single step.

    Args:
      inputs_q: Input queries of shape `[batch, q_length, q_features]`.
      inputs_kv: Key/values of shape `[batch, kv_length, kv_features]`.
      mask: Attention mask of shape `[batch, num_heads, q_length, kv_length]`.
      bias: Attention bias of shape `[batch, num_heads, q_length, kv_length]`.
      decode: Whether to prepare and use an autoregressive cache.
      deterministic: Disables dropout if set to True.

    Returns:
      Output of shape `[batch, length, q_features]` and the attention weights.
    """
    projection = functools.partial(
        t5_layers.DenseGeneral,
        axis=-1,
        features=(self.num_heads, self.head_dim),
        kernel_axes=('embed', 'joined_kv'),
        dtype=self.dtype,
    )

    # NOTE: T5 does not explicitly rescale the attention logits by
    # 1/sqrt(depth_kq)! This is folded into the initializers of the
    # linear transformations, which is equivalent under Adafactor.
    depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
    query_init = lambda *args: self.kernel_init(*args) / depth_scaling

    # Project inputs_q to multi-headed q/k/v.
    # Dimensions are then [batch, length, num_heads, head_dim].
    query = projection(kernel_init=query_init)(inputs_q)
    key = projection(kernel_init=self.kernel_init)(inputs_kv)
    value = projection(kernel_init=self.kernel_init)(inputs_kv)

    query = t5_layers.with_sharding_constraint(
        query, ('batch', 'length', 'heads', 'kv')
    )
    key = t5_layers.with_sharding_constraint(
        key, ('batch', 'length', 'heads', 'kv')
    )
    value = t5_layers.with_sharding_constraint(
        value, ('batch', 'length', 'heads', 'kv')
    )

    if decode:
      # Detect if we're initializing by absence of existing cache data.
      is_initialized = self.has_variable('cache', 'cached_key')
      # The key and value have dimension [batch, length, num_heads, head_dim],
      # but we cache them as [batch, num_heads, head_dim, length] as a TPU
      # fusion optimization. This also enables the "scatter via one-hot
      # broadcast" trick, which means we do a one-hot broadcast instead of a
      # scatter/gather operation, resulting in a 3-4x speedup in practice.
      swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
      cached_key = self.variable(
          'cache', 'cached_key', jnp.zeros, swap_dims(key.shape), key.dtype
      )
      cached_value = self.variable(
          'cache',
          'cached_value',
          jnp.zeros,
          swap_dims(value.shape),
          value.dtype,
      )
      cache_index = self.variable(
          'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32)
      )
      if is_initialized:
        batch, num_heads, head_dim, length = cached_key.value.shape
        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        # Sanity shape check of cached key against input query.
        expected_shape = (batch, 1, num_heads, head_dim)
        if expected_shape != query.shape:
          raise ValueError(
              'Autoregressive cache shape error, '
              'expected query shape %s instead got %s.'
              % (expected_shape, query.shape)
          )

        # Create a one-hot encoding (OHE) of the current index.
        # NOTE: the index is increased below.
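        # Illustrative walkthrough of the one-hot cache update below (toy
        # shapes are assumptions): with batch=2, num_heads=8, head_dim=64,
        # length=10 and cur_index == 3, `one_hot_indices` is a length-10
        # vector that is 1 only at position 3. Multiplying the single-step
        # key/value (moved to shape [batch, heads, head_dim, 1]) by it
        # broadcasts to [batch, heads, head_dim, length] with zeros outside
        # column 3, so the addition writes exactly one new cache column
        # without an explicit scatter.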
        cur_index = cache_index.value
        one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype)
        # In order to update the key, value caches with the current key and
        # value, we move the length axis to the back, similar to what we did
        # for the cached ones above.
        # Note these are currently the key and value of a single position,
        # since we feed one position at a time.
        one_token_key = jnp.moveaxis(key, -3, -1)
        one_token_value = jnp.moveaxis(value, -3, -1)
        # Update key, value caches with our new 1d spatial slices.
        # We implement an efficient scatter into the cache via one-hot
        # broadcast and addition.
        key = cached_key.value + one_token_key * one_hot_indices
        value = cached_value.value + one_token_value * one_hot_indices
        cached_key.value = key
        cached_value.value = value
        cache_index.value = cache_index.value + 1
        # Move the keys and values back to their original shapes.
        key = jnp.moveaxis(key, -1, -3)
        value = jnp.moveaxis(value, -1, -3)

        # Causal mask for cached decoder self-attention: our single query
        # position should only attend to those key positions that have
        # already been generated and cached, not the remaining zero elements.
        mask = t5_layers.combine_masks(
            mask,
            jnp.broadcast_to(
                jnp.arange(length) <= cur_index,
                # (1, 1, length) represent (head dim, query length, key
                # length). Query length is 1 because during decoding we deal
                # with one index. The same mask is applied to all batch
                # elements and heads.
                (batch, 1, 1, length),
            ),
        )

        # Grab the correct relative attention bias during decoding. This is
        # only required during single step decoding.
        if bias is not None:
          # The bias is a full attention matrix, but during decoding we only
          # have to take a slice of it.
          # This is equivalent to bias[..., cur_index:cur_index+1, :].
          bias = t5_layers.dynamic_vector_slice_in_dim(
              jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
          )

    # Convert the boolean attention mask to an attention bias.
    if mask is not None:
      # Attention mask in the form of attention bias.
      attention_bias = jax.lax.select(
          mask > 0,
          jnp.full(mask.shape, 0.0).astype(self.dtype),
          jnp.full(mask.shape, -1e10).astype(self.dtype),
      )
    else:
      attention_bias = None

    # Add provided bias term (e.g. relative position embedding).
    if bias is not None:
      attention_bias = t5_layers.combine_biases(attention_bias, bias)

    dropout_rng = None
    if not deterministic and self.dropout_rate > 0.0:
      dropout_rng = self.make_rng('dropout')

    # Apply attention.
    x, attn_weights = dot_product_attention(
        query,
        key,
        value,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout_rate,
        deterministic=deterministic,
        dtype=self.dtype,
        float32_logits=self.float32_logits,
    )

    # Back to the original inputs dimensions.
    out = t5_layers.DenseGeneral(
        features=inputs_q.shape[-1],  # Output dim is set to the input dim.
        axis=(-2, -1),
        kernel_init=self.kernel_init,
        kernel_axes=('joined_kv', 'embed'),
        dtype=self.dtype,
    )(x)
    return out, attn_weights
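# Illustrative end-to-end sketch (toy shapes and hyper-parameters are
# assumptions): MultiHeadDotProductAttention returns both the attended
# output and the attention weights, which FusionEncoderLayer relies on to
# expose per-layer attention maps.
#
#   attn = MultiHeadDotProductAttention(num_heads=8, head_dim=64)
#   x = jnp.zeros((2, 16, 512))  # [batch, length, features]
#   params = attn.init(jax.random.PRNGKey(0), x, x, deterministic=True)
#   out, weights = attn.apply(params, x, x, deterministic=True)
#   # out.shape == (2, 16, 512); weights.shape == (2, 8, 16, 16)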