# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for transformer layers."""
from typing import Any, Callable, Optional, Tuple

from absl import logging
from flax import linen as nn
import gin
import jax
import jax.numpy as jnp

from transformer import nn_components


Array = Any
# Tuple of attention scale factors: (self-attention scale, cross-attention scale).
AttnScaleTuple = Tuple[Optional[Array], Optional[Array]]

# Tuple of (keys, values, queries, queries2); queries and queries2 may be None.
KVQTuple = Tuple[Array, Array, Optional[Array], Optional[Array]]

class KVQLayer(nn.Module):
  """Generate keys, values, and queries for attention."""

  embedding_size: int
  num_heads: int
  head_size: int
  has_queries: bool = True
  has_queries2: bool = False   # For cross-attention, e.g. decoder or recurrence.
  normalize_keys: bool = True  # Normalize keys and queries.
  num_position_embeddings: int = 0  # Learned absolute position embeddings.
  pre_attn_dropout: bool = True
  dropout_rate: float = 0.0
  dtype: Any = jnp.float32

  def setup(self):
    kernel_init = nn.initializers.variance_scaling(
        scale=1.0, mode="fan_in", distribution="truncated_normal")

    # Project to keys, values, queries.
    # Disable bias.  This prevents a failure mode whereby the attention matrix
    # can become filled with very large uniform values, due to high bias.
    self.keys_layer = nn.Dense(
        features=self.num_heads * self.head_size,
        use_bias=False,  # No bias for keys.
        kernel_init=kernel_init,
        dtype=self.dtype)
    self.values_layer = nn.Dense(
        features=self.num_heads * self.head_size,
        use_bias=False,  # No bias for values.
        kernel_init=kernel_init,
        dtype=self.dtype)
    if self.has_queries:
      self.queries_layer = nn.Dense(
          features=self.num_heads * self.head_size,
          use_bias=False,  # No bias for queries.
          kernel_init=kernel_init,
          dtype=self.dtype)
    if self.has_queries2:
      self.queries2_layer = nn.Dense(
          features=self.num_heads * self.head_size,
          use_bias=False,  # No bias for queries.
          kernel_init=kernel_init,
          dtype=self.dtype)

    # When normalizing keys and queries, attention must be scaled with
    # learned parameters.
    if self.normalize_keys:
      self.attention_scale = self.param("attention_scale",
                                        jax.nn.initializers.ones,
                                        (self.num_heads,), jnp.float32)

    # Learned position embeddings for absolute positions.
    if self.num_position_embeddings > 0:
      # Embeddings for query elements.
      self.position_embeddings = self.param(
          "position_embeddings",
          jax.nn.initializers.normal(stddev=1.0),
          (self.num_position_embeddings, self.embedding_size),
          jnp.float32)

    # Layernorm.
    self.pre_attn_layernorm = nn_components.LayerNorm()

  def attention_scale_factor(self) -> Optional[Array]:
    """Returns the attention scale, when keys and queries are normalized."""
    if self.normalize_keys:
      return jnp.asarray(self.attention_scale, dtype=self.dtype)
    else:
      return None

  def _get_dropout_rng(self):
    return self.make_rng("dropout")

  def _normalize_kq(self, kq: Array) -> Array:
    """Normalize function for keys and queries."""
    epsilon = jnp.array(1.0e-6, dtype=self.dtype)
    kq_sum_sqr = jnp.sum(jnp.square(kq), axis=-1, keepdims=True)
    norm_kq = kq * jax.lax.rsqrt(kq_sum_sqr + epsilon)
    return jnp.asarray(norm_kq, dtype=self.dtype)
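
  # The normalization above computes kq * rsqrt(sum(kq**2, axis=-1) + 1e-6),
  # i.e. it rescales each per-head key/query vector to (approximately) unit L2
  # norm; the epsilon keeps the rsqrt finite for all-zero vectors.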

  def __call__(self, xs: Array, deterministic: bool = False) -> KVQTuple:
    """Takes a sequence of embeddings as input, and returns keys, values, queries.

    First apply pre_attn layernorm, and pre_attn dropout.
    Then add learned positional embeddings, if any.
    Return (keys, values, queries, queries2).

    Args:
      xs: input sequence of shape (batch_size, sequence_length, embedding_size)
      deterministic: if False, apply dropout.

    Returns:
      (keys, values, queries, queries2) of shape
          (batch_size, sequence_length, num_heads, head_size)
    """
    # Project inputs to (keys, values, queries).
    (batch_size, num_keys, _) = xs.shape
    drop_tile_shape = (1, 128, self.embedding_size)

    # Apply layernorm to the input, rather than the output.
    # This provides better gradients through the resnet, and also avoids
    # the need for a prolonged warmup phase (https://arxiv.org/abs/2002.04745)

    # Layernorm for self-attention.
    logging.info("kvq: pre_attn xs = %r", xs)
    xs = jnp.asarray(xs, dtype=self.dtype)
    xs = self.pre_attn_layernorm(xs)

    # Add (optional) learned position embeddings.
    if self.num_position_embeddings > 0:
      assert xs.ndim == 3  # (b, sequence_length, embedding_size)
      assert xs.shape[-2] == self.num_position_embeddings
      logging.info("kvq: learned positions.")
      xs_pos = jnp.asarray(self.position_embeddings, dtype=self.dtype)
      xs_pos = jnp.expand_dims(xs_pos, 0)  # Add batch dimension.
      xs = xs + xs_pos

    # Pre-attention dropout.
    if self.pre_attn_dropout:
      logging.info("kvq: pre_attn dropout.")
      xs = nn_components.tiled_dropout(xs, drop_tile_shape, self.dropout_rate,
                                       rng_function=self._get_dropout_rng,
                                       deterministic=deterministic)
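
    # A note on tiled_dropout (defined in nn_components, not in this file): the
    # (1, 128, embedding_size) drop_tile_shape suggests that a single dropout
    # mask is drawn per tile and reused across the sequence, rather than
    # sampling an independent mask for every position; treat this as an
    # assumption to be checked against nn_components.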

    # Compute keys and values.
    keys = self.keys_layer(xs)  # (b, num_keys, num_heads * head_size)
    values = self.values_layer(xs)

    # Compute queries and cross-attention queries if necessary.
    if self.has_queries:
      queries = self.queries_layer(xs)  # (b, num_keys, n_heads * head_size)
      logging.info("kvq: queries = %r", queries)
    else:
      queries = None
    if self.has_queries2:
      queries2 = self.queries2_layer(xs)  # (b, num_keys, n_heads * head_size)
      logging.info("kvq: queries2 = %r", queries2)
    else:
      queries2 = None

    # Reshape to split num_heads, head_size into separate dimensions.
    kv_shape = (batch_size, num_keys, self.num_heads, self.head_size)
    keys = jnp.reshape(keys, kv_shape)
    values = jnp.reshape(values, kv_shape)
    if queries is not None:
      queries = jnp.reshape(queries, kv_shape)
    if queries2 is not None:
      queries2 = jnp.reshape(queries2, kv_shape)

    if self.normalize_keys:
      # Normalize both keys and queries.
      # The learned attention_scale_factor() will return non-None.
      logging.info("kvq: normalize keys, queries.")
      keys = self._normalize_kq(keys)
      if queries is not None:
        queries = self._normalize_kq(queries)
      if queries2 is not None:
        queries2 = self._normalize_kq(queries2)
    else:
      # Scale queries by 1 / sqrt(d) when using unnormalized keys and queries.
      d_scale = jax.lax.rsqrt(float(self.head_size)).astype(self.dtype)
      logging.info("kvq: scale queries by 1/sqrt(d).")
      if queries is not None:
        queries = queries * d_scale
      if queries2 is not None:
        queries2 = queries2 * d_scale
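
    # Scaling summary: with normalize_keys=True the queries and keys are unit
    # vectors, and the learned per-head attention_scale (exposed through
    # attention_scale_factor()) is expected to scale the logits downstream;
    # with normalize_keys=False the usual 1/sqrt(head_size) factor has already
    # been folded into the queries here.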

    # Return keys, values, and queries.
    return (keys, values, queries, queries2)
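

# A minimal usage sketch for KVQLayer (illustrative only; the shapes, sizes,
# and rng keys below are assumptions, not part of this module's API):
#
#   layer = KVQLayer(embedding_size=512, num_heads=8, head_size=64)
#   xs = jnp.zeros((2, 128, 512))  # (batch_size, sequence_length, embedding)
#   variables = layer.init(
#       {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)},
#       xs, deterministic=True)
#   (keys, values, queries, queries2) = layer.apply(
#       variables, xs, deterministic=True)
#   # keys.shape == (2, 128, 8, 64); queries2 is None unless has_queries2=True.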


class TransformerBase(nn.Module):
  """TransformerBase implements everything except attention.

  It handles:
    - Projection to (keys, values, queries) before attention.
    - Projection MLP back to embedding_size after attention.
    - Final FFN layer.
    - layernorm, dropout, and normalization of keys and queries.

  This functionality is encapsulated here so that it can be reused with more
  complicated attention mechanisms.
  """

  # Options set by parent module.
  mode: str
  embedding_size: int
  num_heads: int
  head_size: int
  cross_attention_q: bool = False   # Additional q for cross-attention.
  cross_attention_kv: bool = False  # Additional kv for cross-attention.
  num_position_embeddings: int = 0  # Learned position embeddings.
  num_cross_position_embeddings: int = 0  # Learned cross-attention position embeddings.

  # Configurable hyperparameters.
  attn_mlp_factory: Callable[[int], nn.Module] = gin.REQUIRED
  ffn_factory: Callable[[int], nn.Module] = gin.REQUIRED
  gate_type: str = "residual"
  single_gate: bool = False
  skip_ffn: bool = False
  normalize_keys: bool = True
  dropout_rate: float = 0.0
  pre_attn_dropout: bool = True
  post_attn_dropout: bool = False
  pre_ffn_dropout: bool = False
  post_ffn_dropout: bool = True
  dtype: Any = jnp.float32
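
  # A note on the gin-configured factories above, inferred from how they are
  # used in setup() and post_attn_ffn() below (not from their definitions):
  # attn_mlp_factory and ffn_factory are expected to build flax modules that
  # accept the number of output features plus gate_type, final_activation, and
  # dtype keyword arguments, and whose __call__ takes (ys, residual) along with
  # apply_dropout, dropout_rate, drop_tile_shape, and rng_function keywords,
  # returning an array of shape (batch_size, sequence_length, embedding_size).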

  def is_training(self) -> bool:
    return self.mode == "train"

  def _get_dropout_rng(self):
    return self.make_rng("dropout")

  def _normalize_kq(self, kq: Array) -> Array:
    """Normalize function for keys and queries."""
    epsilon = jnp.array(1.0e-6, dtype=self.dtype)
    kq_sum_sqr = jnp.sum(jnp.square(kq), axis=-1, keepdims=True)
    norm_kq = kq * jax.lax.rsqrt(kq_sum_sqr + epsilon)
    return jnp.asarray(norm_kq, dtype=self.dtype)

  def setup(self):
    # Keys, values, queries for self-attention; queries for cross-attention.
    self._kvq = KVQLayer(self.embedding_size, self.num_heads, self.head_size,
                         has_queries=True,
                         has_queries2=self.cross_attention_q,
                         num_position_embeddings=self.num_position_embeddings,
                         normalize_keys=self.normalize_keys,
                         pre_attn_dropout=self.pre_attn_dropout,
                         dropout_rate=self.dropout_rate,
                         dtype=self.dtype)

    # Keys, values, and attention_scale for cross-attention.
    if self.cross_attention_kv:
      # Use a full KVQ layer, with layernorm and attention scale.
      self._cross_kv = KVQLayer(
          self.embedding_size, self.num_heads, self.head_size,
          has_queries=False,
          has_queries2=False,
          num_position_embeddings=self.num_cross_position_embeddings,
          normalize_keys=self.normalize_keys,
          pre_attn_dropout=self.pre_attn_dropout,
          dropout_rate=self.dropout_rate,
          dtype=self.dtype)
    elif self.cross_attention_q:
      # No separate keys, values for cross-attention, but we may still need
      # a cross-attention scale, so we create our own.
      assert self.num_cross_position_embeddings == 0
      if self.normalize_keys:
        self.attention_scale2 = self.param("attention_scale2",
                                           jax.nn.initializers.ones,
                                           (self.num_heads,), jnp.float32)

    # Post-attention linear projection.
    if not self.single_gate:
      self.post_attn_mlp = self.attn_mlp_factory(
          self.embedding_size,
          gate_type=self.gate_type,
          final_activation=None,
          dtype=self.dtype)  # pytype: disable=wrong-keyword-args  # trace-all-classes

    # Final FFN.
    if not self.skip_ffn:
      self.ffn = self.ffn_factory(
          self.embedding_size,
          gate_type=self.gate_type,
          final_activation=("tanh" if self.single_gate else None),
          dtype=self.dtype)  # pytype: disable=wrong-keyword-args  # trace-all-classes

    # Layernorm.
    self.pre_ffn_layernorm = nn_components.LayerNorm()

  def force_init(self, xs: Array):
    """Force flax initialization of self, prior to use with lax.scan.

    Args:
      xs: The input sequence that the module will be called with.
    """
    logging.info("tbase: Begin forced initialization.")
    _ = self.kvq(xs)
    batch_size = xs.shape[0]
    seq_len = xs.shape[1]
    attn_ys_shape = (batch_size, seq_len, self.num_heads, self.head_size)
    dummy_attn_ys = jnp.zeros(attn_ys_shape, dtype=self.dtype)
    if self.cross_attention_kv or self.cross_attention_q:
      dummy_cross_attn_ys = dummy_attn_ys
    else:
      dummy_cross_attn_ys = None
    _ = self.post_attn_ffn(xs, dummy_attn_ys, dummy_cross_attn_ys)
    logging.info("tbase: End forced initialization.")

  def attention_scale_factors(self) -> AttnScaleTuple:
    """Returns the attention scales, when keys and queries are normalized.

    Returns: (scale for kv (i.e. queries), scale for cross_kv (i.e. queries2))
    """
    sfactor = self._kvq.attention_scale_factor()
    if self.cross_attention_kv:
      cross_sfactor = self._cross_kv.attention_scale_factor()
    elif self.cross_attention_q and self.normalize_keys:
      cross_sfactor = jnp.asarray(self.attention_scale2, dtype=self.dtype)
    else:
      cross_sfactor = None
    return (sfactor, cross_sfactor)

  def kvq(self, xs: Array) -> KVQTuple:
    enable_dropout = self.pre_attn_dropout and self.is_training()
    return self._kvq(xs, deterministic=not enable_dropout)

  def cross_kv(self, xs: Array) -> Tuple[Array, Array]:
    assert self.cross_attention_kv
    enable_dropout = self.pre_attn_dropout and self.is_training()
    (k, v, _, _) = self._cross_kv(xs, deterministic=not enable_dropout)
    return (k, v)

  def post_attn_ffn(self, xs: Array, attn_ys: Array,
                    cross_attn_ys: Optional[Array]) -> Array:
    """Combines the output of attention with the original input sequence.

    Post-attn MLP on attn_ys, followed by resnet/gate.
    Pre-FFN layernorm and dropout, then the FFN layer, followed by resnet/gate.

    Args:
      xs: Original input sequence of shape
          (batch_size, sequence_length, embedding_size)
      attn_ys: Output of the self-attention module, of shape
          (batch_size, sequence_length, num_heads, head_size)
      cross_attn_ys: Output of the cross-attention module, of shape
          (batch_size, sequence_length, num_heads, head_size)

    Returns:
      Array of shape (batch_size, sequence_length, embedding_size)
    """
    (batch_size, sequence_length, _) = xs.shape
    assert attn_ys.shape == (batch_size, sequence_length,
                             self.num_heads, self.head_size)
    no_dropout = not self.is_training()
    drop_tile_shape = (1, 128, self.embedding_size)

    # Concatenate cross-attention and self-attention results.
    if cross_attn_ys is not None:
      # Concatenate self-attention and cross-attention results, before
      # applying the projection layer.
      logging.info("tbase: using cross-attention.")
      assert cross_attn_ys.shape == (batch_size, sequence_length,
                                     self.num_heads, self.head_size)
      attn_ys = jnp.concatenate([attn_ys, cross_attn_ys], axis=2)
      att_ys_num_heads = self.num_heads * 2
    else:
      # Only use self-attention.
      att_ys_num_heads = self.num_heads

    logging.info("tbase: attn_ys = %r", attn_ys)
    attn_ys = attn_ys.reshape(
        (batch_size, sequence_length, att_ys_num_heads * self.head_size))

    if self.single_gate:
      logging.info("tbase: single gate.")
      assert not self.skip_ffn
      # Skip the post-attention linear projection and residual connection.
      ys_hidden = xs    # The FFN (below) will be gated onto xs (the input).
      ffn_in = attn_ys  # The input to the FFN is the output of attention.
    else:
      logging.info("tbase: post-attention MLP.")
      # Standard transformer architecture.
      # The post-attention MLP applies a linear projection to project attn_ys
      # to embedding space.  It then uses a residual connection or gate to
      # combine the projection with xs.  Post-attention dropout is applied
      # before the residual/gate.
      post_attn_ys = self.post_attn_mlp(
          attn_ys, xs,
          apply_dropout=self.post_attn_dropout and not no_dropout,
          dropout_rate=self.dropout_rate,
          drop_tile_shape=drop_tile_shape,
          rng_function=self._get_dropout_rng)

      # The FFN (below) will be gated onto post_attn_ys (which gates onto xs).
      ys_hidden = post_attn_ys

      if self.skip_ffn:
        logging.info("tbase: skip final FFN. ys = %r", ys_hidden)
        return ys_hidden

      # The input to the FFN; layernorm is applied before the FFN.
      ffn_in = self.pre_ffn_layernorm(ys_hidden)
      logging.info("tbase: pre-FFN layernorm = %r", ffn_in)

      # Pre-FFN dropout.
      if self.pre_ffn_dropout:
        logging.info("tbase: pre-FFN dropout.")
        ffn_in = nn_components.tiled_dropout(
            ffn_in, drop_tile_shape, self.dropout_rate,
            rng_function=self._get_dropout_rng, deterministic=no_dropout)

    # FFN layer.
    # Large MLP with hidden layers, followed by a residual connection or gate.
    # The MLP will apply post-FFN dropout before the gate.
    logging.info("tbase: final FFN")
    ys = self.ffn(ffn_in, ys_hidden,
                  apply_dropout=self.post_ffn_dropout and not no_dropout,
                  dropout_rate=self.dropout_rate,
                  drop_tile_shape=drop_tile_shape,
                  rng_function=self._get_dropout_rng)

    logging.info("tbase: ys = %r", ys)
    return ys