Spaces:
Runtime error
Runtime error
| # Copyright 2022 Google. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """A single transformer layer.""" | |
| from typing import Any, Mapping, NewType, Optional, Sequence, Tuple | |
| from absl import logging | |
| from flax import linen as nn | |
| import gin | |
| import jax | |
| import jax.numpy as jnp | |
| from transformer import attention | |
| from transformer import memory_factory | |
| from transformer import nn_components | |
| from transformer import position | |
| from transformer import position_fourier | |
| from transformer import position_t5 | |
| from transformer import transformer_base | |
# Type aliases shared by the methods of this module.
Array = jnp.ndarray
# (keys, values, importance) tuple type, as defined by the attention module.
KVITuple = attention.KVITuple
# State for autoregressive decoding: a mapping from names to arrays.
DecoderState = NewType("DecoderState", Mapping[str, Array])
# State carried between windows: (previous kvi, recurrent state), or None.
WindowState = Optional[Tuple[KVITuple, Array]]
class TransformerLayer(nn.Module):
  """Full transformer layer, with attention."""

  # ---- Fields set by DecoderStack. ----
  mode: str
  batch_size: int
  embedding_size: int
  cross_attention: bool = False
  recurrent_attention: bool = False
  memory: Optional[memory_factory.MemoryManager] = None

  # ---- Configurable hyper-parameters. ----
  num_heads: int = gin.REQUIRED
  head_size: int = gin.REQUIRED
  window_length: int = gin.REQUIRED
  use_long_xl_architecture: bool = True
  max_unrolled_windows: int = -1  # -1: always unroll.
  # One of {None, "fourier", "t5", "rotary"}.
  relative_position_type: Optional[str] = "fourier"
  use_causal_mask: bool = True
  attn_dropout_rate: float = 0.0

  # Recurrent attention: number of state vectors, and options that are
  # forwarded to the recurrent TransformerBase.
  recurrent_num_states: int = 0
  recurrent_gate_type: str = "bias"
  recurrent_single_gate: bool = False
  recurrent_skip_ffn: bool = False

  # If True, an importance array is cached alongside the keys and values.
  compute_importance: bool = False

  # External memory: number of neighbors retrieved per query, and whether the
  # memory is reset for batch elements that start a new document.
  memory_num_neighbors: int = 0
  memory_reset_on_new_doc: bool = True

  dtype: Any = jnp.float32

  # Modes which support caching of previous keys and values.
  supported_modes_for_cache: Sequence[str] = ("train", "test")
  # Modes in which the external memory is updated.
  update_memory_modes: Sequence[str] = ("train", "test")
| def supports_generate(self) -> bool: | |
| return self.use_long_xl_architecture | |
| def _get_cache_name_from_mode(self, mode: str) -> Tuple[str, bool, bool]: | |
| """Get the name of the cache, and whether to update the cache, from mode.""" | |
| # This is a hack to ensure that "generate" steps generate text as a | |
| # continuation of the text that is stored in the "test" cache, | |
| # but it does not update the "test" cache. | |
| if mode == "generate": | |
| assert "test" in self.supported_modes_for_cache | |
| return ("test", False, False) # Use test cache, but don't update it. | |
| elif mode == "init": | |
| return ("train", False, False) # Use training cache for initialization. | |
| else: | |
| return (mode, True, mode in self.update_memory_modes) | |
| def _allocate_cached_kvi(self, mode: str) -> KVITuple: | |
| """Allocate (keys, values, importance) which can be cached between steps.""" | |
| kv_shape = [self.batch_size, self.window_length, | |
| self.num_heads, self.head_size] | |
| imp_shape = [self.batch_size, self.window_length] | |
| def kv_initializer(shape): | |
| return jnp.zeros(shape, dtype=self.dtype) | |
| def imp_initializer(shape): | |
| return jnp.zeros(shape, dtype=self.dtype) | |
| pkeys = self.variable("state", "previous_keys_" + mode, | |
| kv_initializer, kv_shape) | |
| pvals = self.variable("state", "previous_values_" + mode, | |
| kv_initializer, kv_shape) | |
| if self.compute_importance: | |
| pimportance = self.variable("state", "previous_importance_" + mode, | |
| imp_initializer, imp_shape) | |
| else: | |
| pimportance = None | |
| return (pkeys, pvals, pimportance) | |
| def _allocate_cached_recurrent_state(self, mode: str): | |
| rec_num_states = self.recurrent_num_states | |
| st_shape = [self.batch_size, rec_num_states, self.embedding_size] | |
| def st_initializer(shape): | |
| return jnp.zeros(shape, dtype=self.dtype) | |
| return self.variable("state", "recurrent_state_" + mode, | |
| st_initializer, st_shape) | |
| def setup(self): | |
| # Basic transformer functionality: everything except attention. | |
| self.tbase = transformer_base.TransformerBase( | |
| mode=self.mode, | |
| embedding_size=self.embedding_size, | |
| num_heads=self.num_heads, | |
| head_size=self.head_size, | |
| cross_attention_q=self.recurrent_attention or self.cross_attention, | |
| cross_attention_kv=False, # or True to use separate k,v. | |
| num_position_embeddings=0, | |
| num_cross_position_embeddings=0, # or self.recurrent_num_states w/ k,v. | |
| dtype=self.dtype) | |
| # Recurrent transformer functionality. | |
| self.recurrent_tbase = None | |
| if self.recurrent_attention: | |
| # Recurrent transformer layer. | |
| # We use a learned position embedding so that each element of the state | |
| # can learn to query and compute different summaries. | |
| self.recurrent_tbase = transformer_base.TransformerBase( | |
| mode="pure", # Disable dropout, which breaks jax.lax.scan. | |
| embedding_size=self.embedding_size, | |
| num_heads=self.num_heads, | |
| head_size=self.head_size, | |
| cross_attention_q=True, | |
| cross_attention_kv=False, # or True to use separate k,v. | |
| num_position_embeddings=self.recurrent_num_states, | |
| num_cross_position_embeddings=0, # or self.window_length w/ k,v. | |
| gate_type=self.recurrent_gate_type, | |
| single_gate=self.recurrent_single_gate, | |
| skip_ffn=self.recurrent_skip_ffn, | |
| dtype=self.dtype) | |
| # Initial state at start of document. | |
| # We want this to be initially small, but large enough that adafactor | |
| # will scale updates to a reasonable value. | |
| self.recurrent_initial_state = self.param( | |
| "recurrent_initial_state", | |
| jax.nn.initializers.normal(stddev=0.1), | |
| (self.recurrent_num_states, self.embedding_size), jnp.float32) | |
| # Cached state from previous step for BPTT. | |
| rec_state = {} | |
| for mkey in self.supported_modes_for_cache: | |
| rec_state[mkey] = self._allocate_cached_recurrent_state(mkey) | |
| self.cached_recurrent_state = rec_state | |
| # Set up relative position encoding. | |
| if self.relative_position_type == "fourier": | |
| self.relative_positions = position_fourier.RelativeFourierPositions( | |
| num_heads=self.num_heads, | |
| max_number_of_keys=self.window_length, | |
| dtype=self.dtype) | |
| elif self.relative_position_type == "t5": | |
| self.relative_positions = position_t5.T5RelativePositionBiases( | |
| num_buckets=32, # TODO(delesley): Let Gin configure these. | |
| max_distance=128, | |
| num_heads=self.num_heads, | |
| dtype=self.dtype) | |
| elif self.relative_position_type == "rotary": | |
| # Rotary position encodings (RoPE). No learned bias parameters. | |
| self.relative_positions = None | |
| else: | |
| assert self.relative_position_type is None | |
| self.relative_positions = None | |
| # Set up cache for Transformer-XL style architectures. | |
| # A separate cache is created for each each mode (e.g. train, test) | |
| cached_kvi = {} | |
| if self.use_long_xl_architecture: | |
| for mkey in self.supported_modes_for_cache: | |
| cached_kvi[mkey] = self._allocate_cached_kvi(mkey) | |
| self.cached_kvi = cached_kvi | |
| # Set up external memory. | |
| # A separate memory will be created for each mode (e.g. train, test) | |
| mem_layers = {} | |
| if self.memory is not None: | |
| self.memory_bias = self.param("external_memory_bias", nn.zeros, | |
| (self.num_heads,), "float32") | |
| for mkey in self.supported_modes_for_cache: | |
| mlayer = self.memory.create_memory_layer() | |
| # Use setattr to setup the name and module containership hierarchy. | |
| setattr(self, "mem_layer_" + mkey, mlayer) | |
| mem_layers[mkey] = mlayer | |
| self.mem_layers = mem_layers | |
| def _get_cached_kvi(self, start_of_sequence: Array, | |
| mode: str) -> Optional[KVITuple]: | |
| """Returns cached (keys, values, importance) from the previous step.""" | |
| if not self.use_long_xl_architecture: | |
| return None | |
| if mode not in self.cached_kvi: | |
| # No cache, but we're using XL / sliding window, so return zeros. | |
| logging.info("tlayer: using zero as initial XL cache value.") | |
| kvi_shape = (self.batch_size, self.window_length, | |
| self.num_heads, self.head_size) | |
| return attention.initial_kvi(kvi_shape, | |
| self.compute_importance, dtype=self.dtype) | |
| # New documents start with zero_kv. | |
| # Continuing the same document will attend to previous keys/vals. | |
| (pkeys, pvals, pimportance) = self.cached_kvi[mode] | |
| (zkeys, zvals, zimportance) = attention.initial_kvi( | |
| pkeys.value.shape, self.compute_importance, dtype=self.dtype) | |
| # Broadcast start_of_sequence over non-batch dims. | |
| b = self.batch_size | |
| start_of_sequence_kv = jnp.reshape(start_of_sequence, [b, 1, 1, 1]) | |
| prev_keys = jnp.where(start_of_sequence_kv, zkeys, pkeys.value) | |
| prev_vals = jnp.where(start_of_sequence_kv, zvals, pvals.value) | |
| if self.compute_importance: | |
| start_of_sequence_imp = jnp.reshape(start_of_sequence, [b, 1]) | |
| prev_importance = jnp.where(start_of_sequence_imp, zimportance, | |
| pimportance.value) | |
| else: | |
| prev_importance = None | |
| logging.debug("tlayer: start_of_sequence = %r", start_of_sequence) | |
| logging.info("tlayer: prev_keys[%r] = %r", mode, prev_keys) | |
| logging.debug("tlayer: prev_importance[%r] = %r", mode, prev_importance) | |
| return (prev_keys, prev_vals, prev_importance) | |
| def _set_cached_kvi(self, next_kvi: KVITuple, mode: str): | |
| """Caches the last (keys, values, importance) from the current step.""" | |
| if not self.use_long_xl_architecture: | |
| return | |
| if mode not in self.cached_kvi: | |
| return | |
| (pkeys, pvals, pimportance) = self.cached_kvi[mode] | |
| (nkeys, nvals, nimportance) = next_kvi # From last window | |
| logging.info("tlayer: next_keys[%r] = %r", mode, nkeys) | |
| pkeys.value = nkeys | |
| pvals.value = nvals | |
| if self.compute_importance: | |
| logging.info("tlayer: next_importance[%r] = %r", mode, nimportance) | |
| pimportance.value = nimportance | |
| def _get_cached_recurrent_state(self, start_of_sequence: Array, | |
| mode: str) -> Optional[Array]: | |
| """Returns cached recurrent state from the previous step.""" | |
| if not self.recurrent_attention: | |
| return None | |
| if mode not in self.cached_recurrent_state: | |
| return None | |
| b = self.batch_size | |
| rstate = self.cached_recurrent_state[mode].value | |
| istate = jnp.asarray(self.recurrent_initial_state, dtype=self.dtype) | |
| istate = istate[jnp.newaxis, :, :] # Add batch dimension for broadcast. | |
| logging.info("tlayer: get_cached_recurrent_state %r, %r", istate, rstate) | |
| start_of_sequence_st = jnp.reshape(start_of_sequence, (b, 1, 1)) | |
| return jnp.where(start_of_sequence_st, istate, rstate) | |
| def _set_cached_recurrent_state(self, next_state: Array, mode: str): | |
| """Store the next recurrent state in the cache.""" | |
| if not self.recurrent_attention: | |
| return | |
| if mode not in self.cached_recurrent_state: | |
| return | |
| logging.info("tlayer: set_cached_recurrent_state %r", next_state) | |
| rstate = self.cached_recurrent_state[mode] | |
| rstate.value = next_state | |
| def _query_external_memory(self, keys: Array, values: Array, queries: Array, | |
| start_of_sequence: Array, | |
| mode: str, update_memory: bool): | |
| """Query and update external memory.""" | |
| if self.memory is None: | |
| return None | |
| # Make sure we initialize (allocate) the external memories for all modes. | |
| # Per the flax lazy module initialization scheme, setup() will not be | |
| # invoked on a submodule until that module is actually used. | |
| if mode == "init": | |
| for (_, mlayer) in self.mem_layers.items(): | |
| (_, _) = mlayer.topk_retrieval(queries, self.memory_num_neighbors) | |
| mode = "train" # Pretend we're in training mode during initialization. | |
| if mode not in self.mem_layers: | |
| return None | |
| if self.memory_num_neighbors == 0: | |
| raise ValueError("Using memory, but num_neighbors == 0") | |
| # Grab the appropriate memory layer for the current mode. | |
| memory_layer = self.mem_layers[mode] | |
| # Clear the relevant memories at the start of each new document. | |
| if update_memory and self.memory_reset_on_new_doc: | |
| # The number of "datasets" is batch_dim * num_heads. | |
| # jnp.repeat will "broadcast" start_of_sequence over num_heads. | |
| # E.g. if start_of_sequence = [True, False] and 4 heads, | |
| # jnp.repeat will yield [T, T, T, T, F, F, F, F] | |
| memory_layer.reset(jnp.repeat(start_of_sequence, self.num_heads)) | |
| # Query external memory, with queries. | |
| (rkeys, rvals) = memory_layer.topk_retrieval(queries, | |
| self.memory_num_neighbors) | |
| logging.info("tlayer: query external memory (%r): rvals = %r", mode, rvals) | |
| # Sanity check all dimensions are as expected. | |
| assert rkeys.ndim == 5 # (b, seq_len, num_heads, num_neigh, head_dim) | |
| assert rvals.ndim == 5 | |
| assert rkeys.shape == rvals.shape | |
| assert rkeys.shape[0] == queries.shape[0] # batch size | |
| assert rkeys.shape[1] == queries.shape[1] # sequence length | |
| assert rkeys.shape[2] == self.num_heads | |
| assert rkeys.shape[3] == self.memory_num_neighbors | |
| assert rkeys.shape[4] == self.head_size | |
| # Update external memory, with (keys, values). | |
| if update_memory: | |
| memory_layer.update(keys, values) | |
| return (rkeys, rvals) | |
| def __call__(self, xs: Array, start_of_sequence: Array, | |
| *, | |
| importance: Optional[Array] = None, | |
| cross_attention_kv: Optional[Tuple[Array, Array]] = None, | |
| window_state: Optional[WindowState] = None, | |
| decoder_state: Optional[DecoderState] = None) -> ( | |
| Tuple[Array, Optional[Array], Optional[WindowState], | |
| Optional[DecoderState], Any]): | |
| """Computes attention over a sequence of inputs. | |
| Args: | |
| xs: input sequence of shape (batch_size, sequence_length, num_hidden) | |
| start_of_sequence: An input array of shape (batch_size) | |
| --- The following must be passed by keyword only. --- | |
| importance: Array of shape (batch_size, sequence_length). | |
| An importance bias for attention. | |
| cross_attention_kv: Keys and values from encoder for cross-attention. | |
| window_state: State object which contains context from the prior | |
| window when using a transformer-XL or sliding window. | |
| Initially created with load_window_state(). | |
| decoder_state: State object for autoregressive decoding, initially | |
| created with from init_decoder_state(). | |
| Returns: | |
| (ys: outputs of shape (batch_size, sequence_length, num_hidden), | |
| importance: importance values for the next layer, | |
| next_window_state: state to pass to the next window, | |
| next_decoder_state: next decoder state for autoregressive decoding, | |
| viz_dict: dictionary of visualizations | |
| ) | |
| """ | |
| xs = jnp.asarray(xs, dtype=self.dtype) | |
| logging.info("tlayer: xs = %r", xs) | |
| logging.info("tlayer: recurrent = %r", self.recurrent_attention) | |
| logging.info("tlayer: cross-attention = %r", cross_attention_kv is not None) | |
| is_training = (self.mode == "train") | |
| # Compute keys, values and queries. | |
| # --------------------------------- | |
| logging.info("tlayer: compute keys,values,queries.") | |
| (keys, values, queries, queries2) = self.tbase.kvq(xs) | |
| attention_scale_factors = self.tbase.attention_scale_factors() | |
| (_, sequence_length, num_heads, _) = queries.shape # (b, k, h, d) | |
| # Get biases and masks that are shared across windows. | |
| # ---------------------------------------------------- | |
| if decoder_state is not None: | |
| logging.info("tlayer: using autoregressive decoder.") | |
| # When decoding, prior keys,values are loaded from the decoder state. | |
| # Other values are precomputed, and loaded from the decoder state. | |
| # The decoder state will be updated with the current token. | |
| assert window_state is None | |
| prev_kvi = None | |
| recurrent_state = None # Use precomputed recurrent_kvq. | |
| cross_attention_kv = None | |
| rel_position_bias = decoder_state["relative_position_bias"] | |
| causal_mask = None | |
| dropout_multiplier = None | |
| # Reuse cached recurrent keys,values for each token. | |
| cached_recurrent_kvq = decoder_state["recurrent_kvq"] | |
| if cached_recurrent_kvq is not None: | |
| assert cross_attention_kv is None | |
| cross_attention_kv = (cached_recurrent_kvq[0], cached_recurrent_kvq[1]) | |
| del cached_recurrent_kvq | |
| # Get a full window of keys,values and update decoder state. | |
| (decoder_state, keys, values) = self._next_decoder_state( | |
| decoder_state, keys, values) | |
| # Each query attends to window_length prior keys. | |
| assert keys.shape[1] == self.window_length | |
| kq_relative_offset = self.window_length | |
| else: | |
| logging.info("tlayer: windowed attention.") | |
| # When training, attention is done using windows or chunks, and prior | |
| # context (e.g. keys,values from the previous window) is stored in the | |
| # window_state object. | |
| (prev_kvi, recurrent_state) = window_state # pytype: disable=attribute-error | |
| # Get the size of the sliding window for pos bias, dropout, & causal mask. | |
| (num_queries, num_keys) = attention.sliding_attention_window_shape( | |
| (keys, values, importance), prev_kvi, queries, | |
| window_length=self.window_length) | |
| kq_relative_offset = num_keys - num_queries | |
| # Get the relative position bias. | |
| # The bias doesn't depend on the query content, and so can be precomputed. | |
| if self.relative_positions is not None: | |
| rel_position_bias = self.relative_positions(num_queries, num_keys, | |
| bidirectional=False) | |
| logging.info("tlayer: %s relative bias = %r", | |
| self.relative_position_type, rel_position_bias) | |
| else: | |
| rel_position_bias = None | |
| # Get causal mask. | |
| if self.use_causal_mask: | |
| causal_mask = position.causal_mask(num_queries, num_keys, | |
| window_length=self.window_length) | |
| logging.info("tlayer: causal mask = %r", causal_mask) | |
| else: | |
| causal_mask = None | |
| # Apply dropout to the attention matrix. | |
| # The mask will be broadcast across batches and windows. | |
| if self.attn_dropout_rate > 0.0 and is_training: | |
| dropout_rng = self.make_rng("dropout") | |
| attn_shape = (self.num_heads, num_queries, num_keys) | |
| dropout_multiplier = nn_components.dropout_multiplier_mask( | |
| dropout_rng, self.attn_dropout_rate, attn_shape, self.dtype) | |
| logging.info("tlayer: attn_dropout = %r", dropout_multiplier) | |
| else: | |
| dropout_multiplier = None | |
| # Load and store values into external memory, if memory is not None. | |
| # ------------------------------------------------------------------ | |
| (mode, _, update_memory) = self._get_cache_name_from_mode(self.mode) | |
| external_kv = self._query_external_memory( | |
| keys, values, queries, | |
| start_of_sequence=start_of_sequence, mode=mode, | |
| update_memory=decoder_state is None and update_memory) | |
| if self.memory is not None: | |
| external_memory_bias = jnp.asarray(self.memory_bias, dtype=self.dtype) | |
| external_memory_bias = jnp.reshape(external_memory_bias, | |
| (1, 1, num_heads, 1)) | |
| external_memory_bias = jax.nn.sigmoid(external_memory_bias) | |
| else: | |
| external_memory_bias = None | |
| # Compute the number of windows. | |
| # ------------------------------ | |
| if sequence_length < self.window_length: | |
| num_windows = 1 # Happens with autoregressive decoding. | |
| elif sequence_length == self.window_length: | |
| num_windows = 1 | |
| if self.use_long_xl_architecture: | |
| assert prev_kvi is not None | |
| else: | |
| if not self.use_long_xl_architecture: | |
| raise ValueError("Can only use sliding window with Transformer XL.") | |
| num_windows = sequence_length // self.window_length | |
| if (num_windows * self.window_length) != sequence_length: | |
| raise ValueError(f"Window length {self.window_length} must be a " + | |
| f"multiple of sequence length {sequence_length}") | |
| logging.info("tlayer: num_windows = %d.", num_windows) | |
| # Define the function to do attention within a single window. | |
| # --------------------------------------------------------- | |
| def single_window_attention(carry, inputs_w): | |
| # This function uses the following variables from the outer scope. | |
| # They are listed here for clarity. | |
| nonlocal rel_position_bias | |
| nonlocal causal_mask | |
| nonlocal kq_relative_offset | |
| nonlocal dropout_multiplier | |
| nonlocal attention_scale_factors | |
| nonlocal external_memory_bias | |
| nonlocal cross_attention_kv # externally supplied. | |
| # keys,values,queries over the whole sequence will be split into chunks. | |
| # xs_w, kvqi_w, etc. are the chunk for the current window. | |
| (prev_kvi_w, rec_state) = carry # carried from one window to the next. | |
| (kvqi_w, external_kv_w) = inputs_w # inputs to the current window. | |
| # (keys_curr_w, values_curr_w, _, _, importance_curr_w) = kvqi_w | |
| # Concatenate keys,values from the previous window with the current | |
| # window to implement sliding window attention. | |
| (kvqi_w, next_kvi_w) = attention.concat_kvqi(kvqi_w, prev_kvi_w) | |
| (keys_w, values_w, queries_w, queries2_w, importance_w) = kvqi_w | |
| # Perform recurrent attention within the current window to get the next | |
| # recurrent state, and set up cross attention. | |
| if rec_state is not None: | |
| logging.info("tlayer: recurrent attention.") | |
| # NOTE -- recurrent states and input tokens are handled separately, | |
| # because they have separate learned positional embeddings. Due to | |
| # the way TransformerBase does cross-attention, this means that we use | |
| # separate key,value layers for rec_state and tokens_w. | |
| # Keys, values, queries from recurrent state. | |
| logging.info("tlayer: recurrent kvq.") | |
| rec_kvq = self.recurrent_tbase.kvq(rec_state) | |
| r_scale_factors = self.recurrent_tbase.attention_scale_factors() | |
| (r_keys, r_values, r_queries, r_queries2) = rec_kvq | |
| # Joint attention over both recurrent states and input tokens. | |
| logging.info("tlayer: recurrent self-attention.") | |
| r_attn_ys = attention.simple_attention( | |
| r_keys, r_values, r_queries, None, | |
| scale_factor=r_scale_factors[0], | |
| dtype=self.dtype) | |
| logging.info("tlayer: recurrent cross-attention.") | |
| r_cross_attn_ys = attention.simple_attention( | |
| keys_w, values_w, r_queries2, importance_w, | |
| scale_factor=r_scale_factors[1], | |
| dtype=self.dtype) | |
| # Recurrent post-attention FFN. | |
| logging.info("tlayer: recurrent ffn.") | |
| next_rec_state = self.recurrent_tbase.post_attn_ffn( | |
| rec_state, r_attn_ys, r_cross_attn_ys) | |
| # Get keys and values for cross-attention from recurrent state. | |
| assert cross_attention_kv is None | |
| local_cross_attention_kv = (r_keys, r_values) | |
| else: | |
| # Get keys and values for cross-attention from external argument. | |
| next_rec_state = None | |
| local_cross_attention_kv = cross_attention_kv | |
| # If using RoPE, keys and queries are rotated before self-attention. | |
| if self.relative_position_type == "rotary": | |
| logging.info("Using rotary position encodings (RoPE), offset = %d", | |
| kq_relative_offset) | |
| (keys_w, queries_w) = position.rotate_kq(keys_w, queries_w, | |
| max_wavelength=10_000, | |
| offset=kq_relative_offset) | |
| # Self-attention over input tokens. | |
| logging.info("tlayer: self-attention.") | |
| attn_ys_w = attention.simple_attention( | |
| keys_w, values_w, queries_w, importance_w, | |
| relative_position_bias=rel_position_bias, | |
| scale_factor=attention_scale_factors[0], | |
| causal_mask=causal_mask, | |
| dropout_multiplier=dropout_multiplier, | |
| dtype=self.dtype) | |
| # Attention over external memory. | |
| if external_kv_w is not None: | |
| (external_keys_w, external_values_w) = external_kv_w | |
| y_ext = attention.external_attention( | |
| external_keys_w, external_values_w, queries_w, | |
| scale_factor=attention_scale_factors[0]) | |
| if external_memory_bias is not None: | |
| ebias = external_memory_bias | |
| logging.info("tlayer: using external memory bias = %r", ebias) | |
| attn_ys_w = (attn_ys_w * (1 - ebias)) + (y_ext * ebias) | |
| else: | |
| attn_ys_w += y_ext | |
| # Cross attention from input tokens to encoder or recurrent state. | |
| if local_cross_attention_kv is not None: | |
| logging.info("tlayer: cross-attention.") | |
| (c_keys, c_values) = local_cross_attention_kv | |
| # Cross-attention using queries2. | |
| cross_attn_ys_w = attention.simple_attention( | |
| c_keys, c_values, queries2_w, None, | |
| scale_factor=attention_scale_factors[1], | |
| dtype=self.dtype) | |
| else: | |
| cross_attn_ys_w = None | |
| # End function single_window_attention(...) | |
| return ((next_kvi_w, next_rec_state), | |
| (attn_ys_w, cross_attn_ys_w)) | |
| # Initialize recurrent_tbase before calling jax.lax.scan. | |
| # Otherwise flax will throw a tantrum. | |
| if (self.recurrent_attention and 0 <= self.max_unrolled_windows and | |
| self.max_unrolled_windows < num_windows): | |
| logging.info("tlayer: force initialization of recurrent_tbase.") | |
| self.recurrent_tbase.force_init(recurrent_state) | |
| # Perform sliding window attention over all keys,values,queries. | |
| # -------------------------------------------------------------- | |
| initial_carry = (prev_kvi, recurrent_state) # window state. | |
| kvqi = (keys, values, queries, queries2, importance) | |
| attn_inputs = (kvqi, external_kv) | |
| (next_carry, attn_outputs) = attention.split_and_scan( | |
| single_window_attention, | |
| initial_carry, | |
| attn_inputs, | |
| sections=num_windows, | |
| axis=1, | |
| max_unrolled_windows=self.max_unrolled_windows) | |
| (attn_ys, cross_attn_ys) = attn_outputs | |
| logging.info("tlayer: End windows.") | |
| # Post-attention MLP, resnet, and FFN. | |
| # ------------------------------------ | |
| logging.info("tlayer: final FFN.") | |
| ys = self.tbase.post_attn_ffn(xs, attn_ys, cross_attn_ys) | |
| importance_output = None | |
| next_window_state = next_carry if window_state is not None else None | |
| viz_dict = {} # Visualizations, not currently enabled. | |
| return (ys, importance_output, next_window_state, decoder_state, viz_dict) | |
| def load_window_state(self, start_of_sequence: Array) -> WindowState: | |
| """Load cached state that is passed from one window to the next.""" | |
| (mode, _, _) = self._get_cache_name_from_mode(self.mode) | |
| prev_kvi = self._get_cached_kvi(start_of_sequence, mode) | |
| rec_state = self._get_cached_recurrent_state(start_of_sequence, mode) | |
| if prev_kvi is not None: | |
| logging.info("tlayer: Loaded keys,values for mode %s from cache %s", | |
| self.mode, mode) | |
| else: | |
| logging.info("tlayer: Skipping XL cache for mode %s.", self.mode) | |
| if rec_state is not None: | |
| logging.info("tlayer: Loaded recurrent state for mode %s from cache %s.", | |
| self.mode, mode) | |
| return (prev_kvi, rec_state) | |
| def store_window_state(self, window_state: WindowState): | |
| """Write window state to the cache.""" | |
| (mode, update_cache, _) = self._get_cache_name_from_mode(self.mode) | |
| (next_kvi, next_rec_state) = window_state # pytype: disable=attribute-error | |
| if update_cache and next_kvi is not None: | |
| logging.info("tlayer: Storing keys,values for mode %s in cache %s.", | |
| self.mode, mode) | |
| self._set_cached_kvi(next_kvi, mode) | |
| else: | |
| logging.info("tlayer: Skipping XL cache update for mode %s.", self.mode) | |
| if update_cache and next_rec_state is not None: | |
| logging.info("tlayer: Storing recurrent state for mode %s in cache %s.", | |
| self.mode, mode) | |
| self._set_cached_recurrent_state(next_rec_state, mode) | |
| def get_recurrent_kv(self, window_state: WindowState): | |
| """Get the recurrent keys,values from window_state.""" | |
| # TODO(delesley): optimize. | |
| # This isn't ideal, because we wind up computing the recurrent keys,values | |
| # twice -- once within the sliding window above, and again in the | |
| # DecoderStack, so they can be passed to other layers. However, the | |
| # plumbing is a lot simpler this way. | |
| if window_state is None: | |
| return None | |
| (_, rec_state) = window_state | |
| if rec_state is None: | |
| return None | |
| logging.info("tlayer: get_recurrent_kv.") | |
| (r_keys, r_values, _, _) = self.recurrent_tbase.kvq(rec_state) | |
| return (r_keys, r_values) | |
| def init_decoder_state(self, sequence_length: int, | |
| start_of_sequence: Array) -> DecoderState: | |
| """Initialize decoder state for autoregressive generation. | |
| Args: | |
| sequence_length: The maximum length of the sequence to generate. | |
| start_of_sequence: Array of boolean of shape (batch_size,) | |
| True if starting a new sequence (with no prefix). | |
| Returns: | |
| A state object that can be passed to __call__. | |
| """ | |
| # Note that generate always uses a local context of size window_length. | |
| # Training should be set up appropriately. | |
| if not self.use_long_xl_architecture: | |
| raise ValueError("Generation is only supported for transformer XL.") | |
| if not self.use_causal_mask: | |
| raise ValueError("Generator must have been trained with a causal mask.") | |
| (mode, _, _) = self._get_cache_name_from_mode(self.mode) | |
| # Get relative position bias. | |
| if self.relative_positions is not None: | |
| # Relative positions for all tokens *prior* to the current token. | |
| # The causal mask prevents each token from attending to itself. | |
| rel_position_bias = self.relative_positions(1, self.window_length, | |
| offset=self.window_length, | |
| bidirectional=False) | |
| else: | |
| rel_position_bias = None | |
| # Initialize autoregressive storage for (key, value) pairs. | |
| # Include space for a prefix of window_length tokens. | |
| num_keys = sequence_length + self.window_length | |
| stored_shape = (self.batch_size, num_keys, self.num_heads, self.head_size) | |
| stored_keys = jnp.zeros(stored_shape, dtype=self.dtype) | |
| stored_values = jnp.zeros(stored_shape, dtype=self.dtype) | |
| start_index = self.window_length | |
| # Copy keys,values from cache into storage, for use as a prefix. | |
| prev_kvi = self._get_cached_kvi(start_of_sequence, mode) | |
| if prev_kvi is not None: | |
| (pkeys, pvals, prev_imps) = prev_kvi | |
| assert prev_imps is None # Not yet supported. | |
| assert pkeys.ndim == 4 | |
| assert pkeys.shape[1] == self.window_length # (b, wlen, num_heads, d) | |
| stored_keys = jax.lax.dynamic_update_slice_in_dim( | |
| stored_keys, pkeys, 0, axis=1) | |
| stored_values = jax.lax.dynamic_update_slice_in_dim( | |
| stored_values, pvals, 0, axis=1) | |
| # Grab the current recurrent_state, and precompute keys,values,queries. | |
| rstate = self._get_cached_recurrent_state(start_of_sequence, mode) | |
| if rstate is not None: | |
| recurrent_kvq = self.recurrent_tbase.kvq(rstate) | |
| else: | |
| recurrent_kvq = None | |
| decoder_state_dict = { | |
| "keys": stored_keys, | |
| "values": stored_values, | |
| "current_index": start_index, | |
| "relative_position_bias": rel_position_bias, | |
| "recurrent_kvq": recurrent_kvq | |
| } | |
| return DecoderState(decoder_state_dict) | |
| def _next_decoder_state(self, decoder_state: DecoderState, | |
| keys: Array, values: Array) -> Tuple[ | |
| DecoderState, Array, Array]: | |
| """Compute the next decoder state, and return keys,values to attend to. | |
| The keys,values returned from this function are drawn from the prior | |
| decoding state, and comprise a full window of local context. | |
| Args: | |
| decoder_state: The current decoder state, initially created using | |
| init_decoder_state(). | |
| keys: The key for the current token, of shape (batch_size, 1, dim) | |
| values: The value for the current token of shape (batch_size, 1, dim) | |
| Returns: | |
| (next_decoder_state, | |
| window of keys of shape (batch_size, window_length, dim), | |
| window of values of shape (batch_size, window_length, dim)) | |
| """ | |
| assert keys.shape[1] == 1 # single-token autoregressive decoding. | |
| logging.info("attn_layer: next decoder state; key = %r", keys) | |
| # Unpack decoder_state | |
| stored_keys = decoder_state["keys"] | |
| stored_values = decoder_state["values"] | |
| curr_index = decoder_state["current_index"] | |
| # Slice to get window_length-sized chunk of previous keys,values. | |
| out_decoder_state = {} | |
| curr_win_index = curr_index - self.window_length | |
| out_keys = jax.lax.dynamic_slice_in_dim( | |
| stored_keys, curr_win_index, self.window_length, axis=1) | |
| out_values = jax.lax.dynamic_slice_in_dim( | |
| stored_values, curr_win_index, self.window_length, axis=1) | |
| # Write current keys,values to stored keys, values. | |
| stored_keys = jax.lax.dynamic_update_slice_in_dim( | |
| stored_keys, keys, curr_index, axis=1) | |
| stored_values = jax.lax.dynamic_update_slice_in_dim( | |
| stored_values, values, curr_index, axis=1) | |
| curr_index = curr_index + 1 | |
| # Pack a new decoder_state object. | |
| out_decoder_state["keys"] = stored_keys | |
| out_decoder_state["values"] = stored_values | |
| out_decoder_state["current_index"] = curr_index | |
| out_decoder_state["relative_position_bias"] = ( | |
| decoder_state["relative_position_bias"]) | |
| out_decoder_state["recurrent_kvq"] = decoder_state["recurrent_kvq"] | |
| return (DecoderState(out_decoder_state), out_keys, out_values) | |