| """Common attention modules. |
| |
| Conventions: |
| - Pass `deterministic` and `rng` as an argument. `rng` is optional and defaults |
| to `self.make_rng()`. |
| - `train` and `deterministic` should not have a default. |
| - Do not define `rng`, `deterministic` or `train` as attributes. |
| - `rng`, `deterministic`, `train` should always be keyword only arguments. |
| - Prefer `use_bias` over `bias`. |
| """ |
import functools
from typing import Callable, Optional, Sequence, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from scenic.model_lib.layers import nn_layers


Initializer = Callable[[jnp.ndarray, Sequence[int], jnp.dtype], jnp.ndarray]
Shape = Sequence[int]


def _attention_dropout(attn_weights: jnp.ndarray,
                       *,
                       rate: float,
                       broadcast: bool = True,
                       dropout_rng: jnp.ndarray) -> jnp.ndarray:
  """Applies dropout on attention weights.

  This always applies the dropout. There is no `deterministic` parameter.

  Args:
    attn_weights: Attention weights.
    rate: The dropout rate. (_not_ the keep rate!)
    broadcast: Whether to broadcast the dropout mask along the batch (first)
      and second-to-last axes.
    dropout_rng: RNG key.

  Returns:
    Weights after dropout.
  """
  keep_prob = 1.0 - rate
  if broadcast:
    # Use the same dropout mask across the batch and query dimensions.
    dropout_shape = list(attn_weights.shape)
    dropout_shape[0] = 1  # Broadcast along batch.
    dropout_shape[-2] = 1  # Broadcast along queries.
    keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
  else:
    keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
  multiplier = (
      keep.astype(attn_weights.dtype) /
      jnp.asarray(keep_prob, dtype=attn_weights.dtype))
  return attn_weights * multiplier


def dot_product_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    bias: Optional[jnp.ndarray] = None,
    bias_kv: Optional[jnp.ndarray] = None,
    broadcast_dropout: bool = True,
    dropout_rate: float = 0.1,
    dtype: jnp.dtype = jnp.float32,
    precision: Optional[jax.lax.Precision] = None,
    deterministic: bool,
    dropout_rng: Optional[jnp.ndarray] = None,
    capture_attention_weights: bool = True) -> jnp.ndarray:
  """Computes the dot-product attention given query, key and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights.

  Note: query, key, value needn't have any batch dimensions.

  Args:
    query: Queries for calculating attention with shape of `[batch...,
      q_length, num_heads, qk_depth_per_head]`.
    key: Keys for calculating attention with shape of `[batch..., kv_length,
      num_heads, qk_depth_per_head]`.
    value: Values to be used in attention with shape of `[batch..., kv_length,
      num_heads, v_depth_per_head]`.
    bias: Bias for the attention weights. This should be broadcastable to the
      shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
      incorporating causal masks, padding masks, proximity bias, etc.
    bias_kv: Attention bias defined for keys only, with shape
      `[batch..., kv_length]`. Can be used for masking elements in k/v.
    broadcast_dropout: Use a broadcasted dropout along batch dims.
    dropout_rate: Dropout rate.
    dtype: The dtype of the computation (default: float32).
    precision: Numerical precision of the computation, see `jax.lax.Precision`
      for details.
    deterministic: Deterministic or not (to apply dropout).
    dropout_rng: Optional JAX PRNGKey to be used for dropout.
    capture_attention_weights: Whether to add an identity layer to tag the
      attention weights so they can be captured with Flax's
      `capture_intermediates`, e.g. for visualization. Note that if this is
      set to True, this function can only be called within a Flax module.

  Returns:
    Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
  """
  assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
  assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
      'q, k, v batch dims must match.')
  assert query.shape[-2] == key.shape[-2] == value.shape[-2], (
      'q, k, v num_heads must match.')
  assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

  # Calculate the attention weights.
  depth = query.shape[-1]
  query = query / jnp.sqrt(depth).astype(dtype)
  # attn_weights has shape [batch..., num_heads, q_length, kv_length].
  attn_weights = jnp.einsum(
      '...qhd,...khd->...hqk', query, key, precision=precision)

  # Apply attention bias: masking, padding masks, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias
  if bias_kv is not None:
    bias_kv = bias_kv[..., jnp.newaxis, jnp.newaxis, :]
    attn_weights += bias_kv

  # Normalize the attention weights.
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

  if capture_attention_weights:
    # Tag the intermediate weights so they can be captured for visualization.
    attn_weights = nn_layers.IdentityLayer(name='attn_weights')(attn_weights)

  # Apply attention dropout.
  if not deterministic and dropout_rate > 0.:
    if dropout_rng is None:
      raise ValueError('Did not provide `rng` to dot_product_attention().')
    attn_weights = _attention_dropout(
        attn_weights,
        rate=dropout_rate,
        broadcast=broadcast_dropout,
        dropout_rng=dropout_rng)

  # Return the weighted sum over values for each query position.
  return jnp.einsum(
      '...hqk,...khd->...qhd', attn_weights, value, precision=precision)


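# Example (illustrative sketch, not part of the library API): calling
# `dot_product_attention` directly on dummy inputs. `capture_attention_weights`
# is disabled here because the tagging IdentityLayer only works inside a Flax
# module.
#
#   q = jnp.ones((2, 16, 8, 32))   # [batch, q_length, num_heads, head_dim]
#   k = jnp.ones((2, 24, 8, 32))   # [batch, kv_length, num_heads, head_dim]
#   v = jnp.ones((2, 24, 8, 32))
#   out = dot_product_attention(
#       q, k, v, deterministic=True, capture_attention_weights=False)
#   assert out.shape == (2, 16, 8, 32)

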
def axial_dot_product_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    bias: Optional[jnp.ndarray] = None,
    bias_kv: Optional[jnp.ndarray] = None,
    broadcast_dropout: bool = True,
    dropout_rate: float = 0.1,
    dtype: jnp.dtype = jnp.float32,
    precision: Optional[jax.lax.Precision] = None,
    deterministic: bool,
    dropout_rng: Optional[jnp.ndarray] = None) -> jnp.ndarray:
  """Applies masked, head-axial qkv dot-product attention.

  Assigns different heads to different axes, which is more efficient and
  allows for attention on all axes in every layer.

  Args:
    query: Queries for calculating attention with shape of `[batch, ...,
      num_heads, qk_depth_per_head]`.
    key: Keys for calculating attention with shape of `[batch, ..., num_heads,
      qk_depth_per_head]`.
    value: Values to be used in attention with shape of `[batch, ...,
      num_heads, v_depth_per_head]`.
    bias: Bias is not supported and will raise an error if passed.
    bias_kv: Bias for the attention weights. This should be broadcastable to
      the shape `[batch, ...]`. This can be used for incorporating causal
      masks, padding masks, proximity bias, etc. Default is None, which means
      no bias is applied on the attention matrix.
    broadcast_dropout: Use a broadcasted dropout along batch dims.
    dropout_rate: Dropout rate.
    dtype: The dtype of the computation (default: float32).
    precision: Numerical precision of the computation, see `jax.lax.Precision`
      for details.
    deterministic: Deterministic or not (to apply dropout).
    dropout_rng: Optional JAX PRNGKey to be used for dropout.

  Returns:
    Output of shape `[bs, ..., num_heads, features]`.
  """
  if query.shape != key.shape:
    raise ValueError('Axial dot product attention only supports '
                     'query and key with the same shape.')
  if bias is not None:
    raise NotImplementedError('Bias passed to axial attention.')
  if bias_kv is not None:
    # Expand to [batch, ..., 1, 1] so the bias broadcasts over heads.
    bias_kv = bias_kv[..., jnp.newaxis, jnp.newaxis]

  query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
  prefix_str = 'abcdefghijk'
  # All dimensions between the batch and the (heads, features) dimensions are
  # attention (axial) dimensions.
  num_attn_dimensions = query.ndim - 3
  if query.shape[-2] % num_attn_dimensions != 0:
    raise ValueError(f'In head-axial dot-product attention, number of '
                     f'heads ({query.shape[-2]}) should be divisible by number '
                     f'of attention dimensions ({num_attn_dimensions})!')

  # Assign an equal share of the heads to each attention axis.
  queries = jnp.split(query, num_attn_dimensions, axis=-2)
  keys = jnp.split(key, num_attn_dimensions, axis=-2)
  values = jnp.split(value, num_attn_dimensions, axis=-2)

  outputs = []
  for i, (query, key, value) in enumerate(zip(queries, keys, values)):
    # Attend along axis `i + 1` (axis 0 is the batch dimension).
    axis = i + 1
    batch_dims = prefix_str[:axis]
    einsum_str = f'{batch_dims}x...z,{batch_dims}y...z->{batch_dims}x...y'
    attn_logits = jnp.einsum(einsum_str, query, key, precision=precision)
    if bias_kv is not None:
      # Move the bias for the key/value axis into the last position.
      attn_logits += jnp.swapaxes(bias_kv, axis, -1)
    attn_weights = jax.nn.softmax(attn_logits, axis=-1)

    # Apply attention dropout.
    if not deterministic and dropout_rate > 0.:
      attn_weights = _attention_dropout(
          attn_weights,
          rate=dropout_rate,
          broadcast=broadcast_dropout,
          dropout_rng=dropout_rng)
    einsum_str = f'{batch_dims}x...y,{batch_dims}y...z->{batch_dims}x...z'
    outputs.append(
        jnp.einsum(einsum_str, attn_weights, value, precision=precision))

  # Concatenate the per-axis outputs back along the heads dimension.
  return jnp.concatenate(outputs, axis=-2)


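# Example (illustrative sketch): head-axial attention over a 2D grid. The two
# spatial axes each get half of the heads, so `num_heads` must be divisible by
# the number of attention dimensions (here 2).
#
#   q = jnp.ones((2, 8, 8, 4, 16))  # [batch, h, w, num_heads, head_dim]
#   out = axial_dot_product_attention(q, q, q, deterministic=True)
#   assert out.shape == (2, 8, 8, 4, 16)

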
class MultiHeadAttention(nn.Module):
  """Customized multi-head attention for scenic.

  Attributes:
    num_heads: Number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    qkv_features: Dimension of the key, query, and value.
    out_features: Dimension of the last projection.
    dropout_rate: Dropout rate.
    broadcast_dropout: Use a broadcasted dropout along batch dims.
    kernel_init: Initializer for the kernel of the Dense layers.
    bias_init: Initializer for the bias of the Dense layers.
    out_kernel_init: Initializer for the kernel of the output Dense layer. If
      None, kernel_init will be used.
    use_bias: Whether pointwise QKV dense transforms use bias.
    precision: Numerical precision of the computation, see `jax.lax.Precision`
      for details.
    attention_fn: Defaults to dot_product_attention. Other functions with the
      same signature are possible.
    dtype: The dtype of the computation (default: float32).
    enforce_hidden_size_divisible_by_heads: Whether to require the hidden size
      (qkv_features) to be divisible by the number of heads.
  """
  num_heads: int
  qkv_features: Optional[int] = None
  out_features: Optional[int] = None
  dropout_rate: float = 0.
  broadcast_dropout: bool = False
  kernel_init: Initializer = nn.linear.default_kernel_init
  bias_init: Initializer = nn.initializers.zeros
  out_kernel_init: Optional[Initializer] = None
  use_bias: bool = True
  attention_fn: Callable[..., jnp.ndarray] = dot_product_attention
  precision: Optional[jax.lax.Precision] = None
  dtype: jnp.dtype = jnp.float32
  enforce_hidden_size_divisible_by_heads: bool = True

  @nn.compact
  def __call__(self,
               inputs_q: jnp.ndarray,
               inputs_kv: Optional[jnp.ndarray],
               *,
               pos_emb_q: Optional[jnp.ndarray] = None,
               pos_emb_k: Optional[jnp.ndarray] = None,
               pos_emb_v: Optional[jnp.ndarray] = None,
               attention_bias: Optional[jnp.ndarray] = None,
               attention_bias_kv: Optional[jnp.ndarray] = None,
               deterministic: bool = False) -> jnp.ndarray:
    """Applies multi-head dot-product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and projects the results to an output
    vector.

    This can be used for encoder-decoder attention by specifying both
    `inputs_q` and `inputs_kv`, or for self-attention by only specifying
    `inputs_q` and setting `inputs_kv` to None.

    Args:
      inputs_q: Input queries of shape `[bs, ..., len_q, features]`.
      inputs_kv: Key/values of shape `[bs, ..., len_k, features]` or None for
        self-attention, in which case key/values will be derived from
        inputs_q.
      pos_emb_q: Positional embedding to be added to the query.
      pos_emb_k: Positional embedding to be added to the key.
      pos_emb_v: Positional embedding to be added to the value.
      attention_bias: Full attention bias. Should be broadcastable to:
        inputs_q.shape[:-2] + (num_heads, len_q, len_k).
      attention_bias_kv: Attention bias for keys independent of queries which
        has shape (bs, ..., len_k).
      deterministic: Run deterministically or with dropout.

    Returns:
      Output of shape `[bs, ..., features]`.
    """
    if inputs_kv is None:
      inputs_kv = inputs_q

    features = self.out_features or inputs_q.shape[-1]
    qkv_features = self.qkv_features or inputs_q.shape[-1]

    if self.enforce_hidden_size_divisible_by_heads:
      assert qkv_features % self.num_heads == 0, (
          'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // self.num_heads

    def add_positional_emb(x, pos):
      return x + pos if pos is not None else x

    query, key, value = (
        add_positional_emb(inputs_q, pos_emb_q),
        add_positional_emb(inputs_kv, pos_emb_k),
        add_positional_emb(inputs_kv, pos_emb_v))

    # Project inputs to multi-headed q/k/v with dimensions
    # [..., num_heads, head_dim].
    dense = functools.partial(
        nn.DenseGeneral,
        axis=-1,
        features=(self.num_heads, head_dim),
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        use_bias=self.use_bias,
        precision=self.precision)

    query, key, value = (dense(name='query')(query),
                         dense(name='key')(key),
                         dense(name='value')(value))

    # Only pass optional arguments that are actually needed, so that
    # `attention_fn` implementations without these parameters keep working.
    attn_kwargs = {}
    if attention_bias_kv is not None:
      attn_kwargs['bias_kv'] = attention_bias_kv
    if not deterministic and self.dropout_rate > 0:
      attn_kwargs['dropout_rng'] = self.make_rng('dropout')

    x = self.attention_fn(
        query,
        key,
        value,
        bias=attention_bias,
        dropout_rate=self.dropout_rate,
        broadcast_dropout=self.broadcast_dropout,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=self.precision,
        **attn_kwargs)

    # Project back to the original input dimensions.
    out_kernel_init = (self.out_kernel_init if self.out_kernel_init is not None
                       else self.kernel_init)
    out = nn.DenseGeneral(
        features=features,
        axis=(-2, -1),
        kernel_init=out_kernel_init,
        bias_init=self.bias_init,
        use_bias=True,
        dtype=self.dtype,
        precision=self.precision,
        name='out')(x)

    return out


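# Example (illustrative sketch): self-attention with MultiHeadAttention. With
# the default dropout_rate of 0 no 'dropout' RNG is needed; otherwise pass
# rngs={'dropout': ...} and deterministic=False to `apply`.
#
#   x = jnp.ones((2, 16, 64))   # [batch, len, features]
#   layer = MultiHeadAttention(num_heads=8)
#   variables = layer.init(jax.random.PRNGKey(0), x, None, deterministic=True)
#   y = layer.apply(variables, x, None, deterministic=True)
#   assert y.shape == (2, 16, 64)

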
class MlpBlock(nn.Module):
  """Transformer MLP / feed-forward block."""

  mlp_dim: int
  out_dim: Optional[int] = None
  dropout_rate: float = 0.1
  use_bias: bool = True
  kernel_init: Initializer = nn.initializers.xavier_uniform()
  bias_init: Initializer = nn.initializers.normal(stddev=1e-6)
  activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.gelu
  precision: Optional[jax.lax.Precision] = None
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, inputs: jnp.ndarray, *, deterministic: bool):
    """Applies Transformer MlpBlock module."""
    actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
    x = nn.Dense(
        self.mlp_dim,
        dtype=self.dtype,
        use_bias=self.use_bias,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        precision=self.precision)(
            inputs)
    x = nn_layers.IdentityLayer(name='mlp1')(self.activation_fn(x))
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
    output = nn.Dense(
        actual_out_dim,
        dtype=self.dtype,
        use_bias=self.use_bias,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        precision=self.precision)(
            x)
    output = nn_layers.IdentityLayer(name='mlp2')(output)
    output = nn.Dropout(rate=self.dropout_rate)(
        output, deterministic=deterministic)
    return output


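# Example (illustrative sketch): the feed-forward block of a Transformer
# layer. With deterministic=False, a 'dropout' RNG must be provided via
# rngs={'dropout': ...}.
#
#   x = jnp.ones((2, 16, 64))
#   block = MlpBlock(mlp_dim=256)
#   variables = block.init(jax.random.PRNGKey(0), x, deterministic=True)
#   y = block.apply(variables, x, deterministic=True)
#   assert y.shape == (2, 16, 64)

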
def sinusoidal_init(max_len: int, max_timescale: float = 1.0e4):
  """1D Sinusoidal Position Embedding Initializer.

  Args:
    max_len: Maximum possible length for the input.
    max_timescale: Maximum time scale.

  Returns:
    output: Init function returning an array of shape
      `(1, max_len, d_feature)`.
  """

  def init(key: jnp.ndarray,
           shape: Sequence[int],
           dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
    """Sinusoidal init.

    The API defined by JAX for a custom initializer is:
      `def init(key, shape, dtype)`

    Even though some of the args might not be used, the signature should
    follow this API, as JAX passes all three arguments (key, shape, dtype)
    to initializers.

    Args:
      key: JAX PRNG key.
      shape: Shape used for making the initialized values.
      dtype: JAX data type.

    Returns:
      Initialized values.
    """
    del key, dtype
    d_feature = shape[-1]
    pos_emb = np.zeros((max_len, d_feature), dtype=np.float32)
    position = np.arange(0, max_len)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, d_feature, 2) * -(np.log(max_timescale) / d_feature))
    pos_emb[:, 0::2] = np.sin(position * div_term)
    pos_emb[:, 1::2] = np.cos(position * div_term)
    pe = pos_emb[np.newaxis, :, :]
    return jnp.array(pe)

  return init


class Add1DPositionEmbedding(nn.Module):
  """Adds 1-dimensional positional embeddings to the inputs.

  Attributes:
    rescale_from: tuple; If not None, embeddings are rescaled from this shape.
    max_len: int; Maximum possible length for the input. If None, max_len is
      set to the input sequence length.
    posemb_init: Positional embedding initializer. If None, a fixed sinusoidal
      embedding is used instead of a learned parameter.
    param_name: The name of the parameter that stores the positional
      embedding.
  """

  rescale_from: Optional[Sequence[int]] = None
  max_len: Optional[int] = None
  posemb_init: Optional[Initializer] = None
  param_name: str = 'pos_embedding'

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Applies Add1DPositionEmbedding module.

    Args:
      inputs: nd-array; Input data.

    Returns:
      Output: `(bs, timesteps, in_dim)`.
    """
    assert inputs.ndim == 3, ('Number of dimensions should be 3,'
                              ' but it is: %d' % inputs.ndim)
    length = inputs.shape[1]
    max_len = self.max_len or length
    embedding_length = max_len

    if self.rescale_from:
      embedding_length = self.rescale_from[0]

    pos_emb_shape = (1, embedding_length, inputs.shape[-1])
    if self.posemb_init is None:
      # Use a fixed (non-learned) sinusoidal position embedding.
      pos_embedding = sinusoidal_init(max_len=embedding_length)(
          None, pos_emb_shape, None)
    else:
      pos_embedding = self.param(self.param_name, self.posemb_init,
                                 pos_emb_shape)
    pe = pos_embedding[:, :length, :]

    # Rescale the embeddings to the target length if they were created for a
    # different one.
    if max_len != embedding_length:
      pe = jax.image.resize(
          pe, (1, max_len, pe.shape[-1]), method='bilinear', antialias=False)
      pe = jnp.reshape(pe, (1, max_len, -1))
    return inputs + pe


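# Example (illustrative sketch): learned 1D position embeddings. Passing
# posemb_init=None instead adds fixed sinusoidal embeddings and creates no
# parameters.
#
#   x = jnp.ones((2, 16, 64))   # [batch, timesteps, features]
#   module = Add1DPositionEmbedding(
#       posemb_init=nn.initializers.normal(stddev=0.02))
#   variables = module.init(jax.random.PRNGKey(0), x)
#   y = module.apply(variables, x)
#   assert y.shape == x.shape

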
class Add2DPositionEmbedding(nn.Module):
  """Adds 2-dimensional positional embeddings to the inputs.

  Attributes:
    rescale_from: tuple; If not None, embeddings are rescaled from this shape.
    posemb_init: Positional embedding initializer.
  """

  rescale_from: Optional[Tuple[int, ...]] = None
  posemb_init: Initializer = nn.initializers.normal(stddev=0.02)

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Applies Add2DPositionEmbedding module.

    Args:
      inputs: nd-array; Input data.

    Returns:
      Output: `(bs, h, w, c)`.
    """
    assert inputs.ndim == 4, ('Number of dimensions should be 4,'
                              ' but it is: %d' % inputs.ndim)
    _, h, w, c = inputs.shape
    embedding_h, embedding_w = h, w
    if self.rescale_from:
      embedding_h, embedding_w = self.rescale_from[0], self.rescale_from[1]

    # Separate embeddings for rows and columns, each covering half of the
    # channels.
    row_pos_embed = self.param('row_pos_embedding', self.posemb_init,
                               (embedding_w, c // 2))
    col_pos_embed = self.param('col_pos_embedding', self.posemb_init,
                               (embedding_h, c // 2))

    # [embedding_h, embedding_w, c // 2].
    x_pos_emb = jnp.tile(
        jnp.expand_dims(row_pos_embed, axis=0), (embedding_h, 1, 1))
    # [embedding_h, embedding_w, c // 2].
    y_pos_emb = jnp.tile(
        jnp.expand_dims(col_pos_embed, axis=1), (1, embedding_w, 1))
    # [embedding_h, embedding_w, c].
    pos = jnp.concatenate((x_pos_emb, y_pos_emb), axis=-1)

    # Rescale to the input resolution if the embeddings were created for a
    # different one.
    if w != embedding_w or h != embedding_h:
      pos = jax.image.resize(pos, (h, w, c), method='bilinear', antialias=False)

    # Add the batch dimension.
    pos = jnp.expand_dims(pos, axis=0)

    return inputs + pos


def get_fixed_sincos_position_embedding(x_shape: Shape,
                                        temperature: float = 10_000,
                                        dtype: jnp.dtype = jnp.float32):
  """Provides a fixed position encoding for 2D and 3D coordinates.

  The embedding follows the initialisation method used in multiple papers such
  as "Attention is All You Need", https://arxiv.org/abs/1706.03762 and
  "Better plain ViT baselines for ImageNet-1k",
  https://arxiv.org/abs/2205.01580.

  Args:
    x_shape: The shape of the input for which a position embedding is needed.
    temperature: Temperature parameter.
    dtype: dtype of the position encoding.

  Returns:
    Matrix of position embeddings of shape [1, ...], where ... = x_shape[1:].
  """
  assert len(x_shape) in (4, 5), f'Unsupported input shape: {x_shape}'
  num_parts = 4 if len(x_shape) == 4 else 6
  channels = x_shape[-1]
  assert channels % num_parts == 0, f'Channels must be multiple of {num_parts}'
  omega = jnp.arange(
      channels // num_parts, dtype=jnp.float32) / (channels / num_parts)
  omega = 1. / (temperature**omega)

  if len(x_shape) == 4:  # 2D input.
    _, h, w, _ = x_shape
    y, x = jnp.mgrid[:h, :w]
    y = jnp.einsum('m,d->md', y.flatten(), omega)
    x = jnp.einsum('m,d->md', x.flatten(), omega)
    p = [jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)]
    shape = (1, h, w, channels)
  elif len(x_shape) == 5:  # 3D input.
    _, t, h, w, _ = x_shape
    z, y, x = jnp.mgrid[:t, :h, :w]
    z = jnp.einsum('m,d->md', z.flatten(), omega)
    y = jnp.einsum('m,d->md', y.flatten(), omega)
    x = jnp.einsum('m,d->md', x.flatten(), omega)
    p = [jnp.sin(z), jnp.cos(z),
         jnp.sin(x), jnp.cos(x),
         jnp.sin(y), jnp.cos(y)]
    shape = (1, t, h, w, channels)
  else:
    raise ValueError(f'Unsupported input shape: {x_shape}')

  assert (shape[0] == 1) and (shape[1:] == x_shape[1:])
  pe = jnp.concatenate(p, axis=1)
  return jnp.asarray(pe, dtype).reshape(*shape)


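# Example (illustrative sketch): fixed sin/cos embeddings for an image-shaped
# input; the channel count must be divisible by 4 (2D) or by 6 (3D).
#
#   x = jnp.ones((2, 8, 8, 64))   # [batch, h, w, channels]
#   pos = get_fixed_sincos_position_embedding(x.shape)
#   assert pos.shape == (1, 8, 8, 64)
#   x = x + pos   # Broadcasts over the batch dimension.

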
class AddFixedSinCosPositionEmbedding(nn.Module):
  """Provides a fixed position encoding for 2D and 3D coordinates.

  The embedding follows the initialisation method used in multiple papers such
  as "Attention is All You Need", https://arxiv.org/abs/1706.03762 and
  "Better plain ViT baselines for ImageNet-1k",
  https://arxiv.org/abs/2205.01580.

  Attributes:
    temperature: Temperature parameter.
    dtype: dtype of the position encoding.
  """
  temperature: float = 10_000
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Adds the fixed embedding to the inputs.

    Args:
      inputs: Either an [N, W, H, C] or [N, T, W, H, C] input array.

    Returns:
      Inputs with position encodings added to them.
    """
    return inputs + get_fixed_sincos_position_embedding(
        inputs.shape, self.temperature, self.dtype)


class RelativeAttentionBias(nn.Module):
  """Provides learnable NxN relative attention bias.

  Attributes:
    num_heads: Number of heads for which to provide relative attention.
    nd_shape: Shape for which to provide relative attention bias. For
      instance, for images we would provide a 2D shape. Note that batch and
      feature dimensions should be excluded here.
    initializer: Initializer for the bias.
  """

  num_heads: int
  nd_shape: Sequence[int]
  initializer: Initializer = nn.initializers.zeros

  @nn.compact
  def __call__(self) -> jnp.ndarray:
    """Creates relative attention bias that factorizes over dimensions.

    length = prod(nd_shape)

    Returns:
      Bias of shape `[num_heads, length, length]`.
    """
    length = np.prod(self.nd_shape)
    tile = 1
    biases = []
    for i, l in enumerate(self.nd_shape):
      # A per-axis bias is tiled so that it broadcasts over the flattened
      # positions of all other axes; axes of size 1 contribute no bias.
      if l > 1:
        new_bias = self.relative_attn_bias(l, self.num_heads, f'bias_{i}')
        repeat = length // (tile * l)
        if repeat > 1:
          new_bias = new_bias[:, :, jnp.newaxis, :, jnp.newaxis]
          new_bias = jnp.tile(new_bias, [1, tile, repeat, tile, repeat])
          new_bias = jnp.reshape(new_bias, [self.num_heads, length, length])
        elif tile > 1:
          new_bias = jnp.tile(new_bias, [1, tile, tile])
        tile *= l
        biases.append(new_bias)

    return sum(biases)

  def relative_attn_bias(self, length, num_heads, name):
    """Computes attention bias based on relative positions.

    Content-based relative position attention bias was used in:
      https://arxiv.org/pdf/1803.02155.
    Non-content-based relative position attention bias was used in:
      https://arxiv.org/abs/1606.01933.

    Args:
      length: Length of self-attention window for relative attention.
      num_heads: Number of attention heads.
      name: Name of the parameter to be created.

    Returns:
      A `[num_heads, length, length]` relative attention bias tensor.
    """
    # Only 2 * length - 1 relative positions exist, but an extra dummy entry
    # is added so that the reshaping ("skewing") trick below works.
    num_rel_pos = 2 * length

    rel_bias = self.param(
        name, self.initializer, (self.num_heads, num_rel_pos))

    # Convert the [num_heads, 2 * length] table of per-offset biases into a
    # [num_heads, length, length] matrix, where entry (h, q, k) holds the bias
    # for relative position k - q, without any explicit gather.

    # Repeat the table once per query position: [num_heads, 2 * length^2].
    rel_bias = jnp.tile(rel_bias, [1, length])

    # Drop the trailing entries so that, after the reshape below, each row is
    # shifted by one position relative to the previous one.
    num_rel_pos -= 1
    rel_bias = rel_bias[..., :length * num_rel_pos]

    # [num_heads, length, 2 * length - 1]; row q starts at offset -q
    # (modulo 2 * length) into the original table.
    rel_bias = rel_bias.reshape(num_heads, length, num_rel_pos)

    # Keep the last `length` columns: entry (h, q, k) is the bias stored at
    # index (k - q + length - 1) of the original table.
    rel_bias = rel_bias[..., num_rel_pos - length:]

    return rel_bias
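# Example (illustrative sketch): a learned relative attention bias for a 4x4
# grid of tokens with 8 heads. The result can be passed as `attention_bias`
# to MultiHeadAttention, where it broadcasts over the batch dimension.
#
#   bias_module = RelativeAttentionBias(num_heads=8, nd_shape=(4, 4))
#   variables = bias_module.init(jax.random.PRNGKey(0))
#   bias = bias_module.apply(variables)   # [8, 16, 16]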