| | """A single transformer layer in inference mode.
|
| |
|
| | Modified
|
| | https://github.com/google-research/meliad/blob/main/transformer/transformer_layer.py
|
| | To accommodate sequence packing + kv cache + relative position during test time.
|
| | """

from typing import Callable, Mapping, NewType, Optional, Tuple

from absl import logging
import gin
import jax
import jax.numpy as jnp
from transformer import attention
from transformer import nn_components
from transformer import position
from transformer import transformer_layer


Array = jnp.ndarray
DecoderState = NewType("DecoderState", Mapping[str, Array])
WindowState = Optional[Tuple[attention.KVITuple, Array]]


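# Both helpers below are vmapped over the batch dimension, so each example in
# the batch can read or write its KV cache at a different index.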
@jax.vmap
def update_slice_in_dim_1(array: Array, update: Array, idx: Array) -> Array:
  """Update a stored keys/values slice for different-lengthed seqs in batch."""
  return jax.lax.dynamic_update_slice_in_dim(array, update, idx, axis=0)


def slice_in_dim_1(window_length: int) -> Callable[[Array, Array], Array]:
  @jax.vmap
  def fn(array: Array, idx: Array) -> Array:
    return jax.lax.dynamic_slice_in_dim(array, idx, window_length, axis=0)

  return fn
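
# Shape note: as used below, each per-example cache has shape
# (num_keys, num_heads, head_size) once vmap strips the batch axis, `idx` is a
# scalar per example, and the slice/update act on the time axis (axis=0).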


@gin.configurable
class TransformerLayerGenerate(transformer_layer.TransformerLayer):
  """Full transformer layer, with attention."""

  def _next_decoder_state(
      self, decoder_state: DecoderState, keys: Array, values: Array
  ) -> Tuple[DecoderState, Array, Array]:
    """Compute the next decoder state, and return keys,values to attend to.

    The keys,values returned from this function are drawn from the prior
    decoding state, and comprise a full window of local context.

    Args:
      decoder_state: The current decoder state, initially created using
        init_decoder_state().
      keys: The key for the current token, of shape (batch_size, 1, dim)
      values: The value for the current token, of shape (batch_size, 1, dim)

    Returns:
      (next_decoder_state,
       window of keys of shape (batch_size, window_length, dim),
       window of values of shape (batch_size, window_length, dim))
    """

    assert keys.shape[1] == 1
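
    # Unpack the decoder state: the per-example KV cache and write position.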
    stored_keys = decoder_state["keys"]
    stored_values = decoder_state["values"]
    curr_index = decoder_state["current_index"]

    out_decoder_state = {}
    curr_win_index = curr_index - self.window_length
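
    # Slice out the most recent window of keys/values; for each example it
    # ends just before that example's current write position.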
    out_keys = slice_in_dim_1(self.window_length)(stored_keys, curr_win_index)
    out_values = slice_in_dim_1(self.window_length)(
        stored_values, curr_win_index
    )
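
    # Write the current token's key/value into the cache at curr_index, then
    # advance the write position by one.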
    stored_keys = update_slice_in_dim_1(stored_keys, keys, curr_index)
    stored_values = update_slice_in_dim_1(stored_values, values, curr_index)
    curr_index = curr_index + 1
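
    # Pack the updated decoder state; the precomputed position bias and any
    # cached recurrent keys/values are carried over unchanged.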
    out_decoder_state["keys"] = stored_keys
    out_decoder_state["values"] = stored_values
    out_decoder_state["current_index"] = curr_index
    out_decoder_state["relative_position_bias"] = decoder_state[
        "relative_position_bias"
    ]
    out_decoder_state["recurrent_kvq"] = decoder_state["recurrent_kvq"]

    return (DecoderState(out_decoder_state), out_keys, out_values)

  def __call__(
      self,
      xs: Array,
      start_of_sequence: Array,
      *,
      importance: Optional[Array] = None,
      cross_attention_kv: Optional[Tuple[Array, Array]] = None,
      window_state: Optional[WindowState] = None,
      decoder_state: Optional[DecoderState] = None,
  ):
    """Computes attention over a sequence of inputs.

    Args:
      xs: input sequence of shape (batch_size, sequence_length, num_hidden)
      start_of_sequence: An input array of shape (batch_size).
        --- The following must be passed by keyword only. ---
      importance: Array of shape (batch_size, sequence_length). An importance
        bias for attention.
      cross_attention_kv: Keys and values from encoder for cross-attention.
      window_state: State object which contains context from the prior window
        when using a transformer-XL or sliding window. Initially created with
        load_window_state().
      decoder_state: State object for autoregressive decoding, initially
        created with init_decoder_state().

    Returns:
      (ys: outputs of shape (batch_size, sequence_length, num_hidden),
       importance_score: importance score for the next layer,
       next_window_state: state to pass to the next window,
       next_decoder_state: next decoder state for autoregressive decoding,
       viz_dict: dictionary of visualizations
      )
    """

    xs = jnp.asarray(xs, dtype=self.dtype)
    logging.info("tlayer: recurrent = %r", self.recurrent_attention)
    logging.info("tlayer: compute_importance = %r", self.compute_importance)

    is_training = self.mode == "train"
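
    # Compute keys, values, and two sets of queries from the input; the
    # second set of queries (queries2) is used for cross-attention.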
    logging.info("tlayer: compute keys,values,queries.")
    (keys, values, queries, queries2) = self.tbase.kvq(xs)
    attention_scale_factors = self.tbase.attention_scale_factors()
    (_, sequence_length, num_heads, _) = queries.shape
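
    # The biases and masks below are shared across windows, so they are set
    # up once: from the decoder state when decoding autoregressively, or from
    # scratch for windowed attention.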
    if decoder_state is not None:
      logging.info("tlayer: using autoregressive decoder.")

      # When decoding, prior keys and values are loaded from the decoder
      # state; other values were precomputed when the state was initialized.
      assert window_state is None

      prev_kvi = None
      recurrent_state = None
      cross_attention_kv = None
      rel_position_bias = decoder_state["relative_position_bias"]
      causal_mask = None
      dropout_multiplier = None
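
      # Reuse the recurrent keys/values that were cached in the decoder
      # state; they serve as cross-attention inputs for every decoded token.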
      cached_recurrent_kvq = decoder_state["recurrent_kvq"]
      if cached_recurrent_kvq is not None:
        assert cross_attention_kv is None
        cross_attention_kv = (cached_recurrent_kvq[0], cached_recurrent_kvq[1])
      del cached_recurrent_kvq
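
      # Pull a full window of prior keys/values from the cache, and store the
      # current token's keys/values into it.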
      (decoder_state, keys, values) = self._next_decoder_state(
          decoder_state, keys, values
      )

      # The single query attends to the last window_length cached keys.
      assert keys.shape[1] == self.window_length
      kq_relative_offset = self.window_length

      if not self.use_long_xl_architecture:
        kqpos = position.relative_positions(
            1, self.window_length, offset=0
        )
        current_idx = decoder_state["current_index"]

        # Add (batch, heads) dims to kqpos, broadcasting across heads.
        kqpos = jnp.expand_dims(kqpos, axis=(0, 1))
        kqpos = jnp.tile(kqpos, (1, self.num_heads, 1, 1))

        # Add (heads, queries, keys) dims to current_idx.
        current_idx = jnp.expand_dims(current_idx, axis=(1, 2, 3))

        # The mask depends on how far each example has decoded so far.
        causal_mask = kqpos > self.window_length * 2 - current_idx
    else:
      logging.info("tlayer: windowed attention.")

      # When training, attention is computed over the whole sequence by
      # splitting it into windows; prior context comes from window_state.
      (prev_kvi, recurrent_state) = window_state

      # Get the size of the sliding window for the position bias, dropout,
      # and causal mask.
      (num_queries, num_keys) = attention.sliding_attention_window_shape(
          (keys, values, importance),
          prev_kvi,
          queries,
          window_length=self.window_length,
      )
      kq_relative_offset = num_keys - num_queries

      # The relative position bias does not depend on query content, so it
      # can be precomputed.
      if self.relative_positions is not None:
        rel_position_bias = self.relative_positions(
            num_queries, num_keys, bidirectional=False
        )
      else:
        rel_position_bias = None

      # Causal mask for within-window attention.
      if self.use_causal_mask:
        causal_mask = position.causal_mask(
            num_queries, num_keys, window_length=self.window_length
        )
      else:
        causal_mask = None

      # Attention dropout: the mask is broadcast across batches and windows.
      if self.attn_dropout_rate > 0.0 and is_training:
        dropout_rng = self.make_rng("dropout")
        attn_shape = (self.num_heads, num_queries, num_keys)
        dropout_multiplier = nn_components.dropout_multiplier_mask(
            dropout_rng, self.attn_dropout_rate, attn_shape, self.dtype
        )
      else:
        dropout_multiplier = None
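
    # Query (and possibly update) external memory, if configured; the memory
    # is only updated when not decoding autoregressively.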
    (mode, _, update_memory) = self._get_cache_name_from_mode(self.mode)
    external_kv = self._query_external_memory(
        keys,
        values,
        queries,
        start_of_sequence=start_of_sequence,
        mode=mode,
        update_memory=decoder_state is None and update_memory,
    )
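
    # With TRAINABLE_WEIGHTED_MEAN, a learned per-head gate (a sigmoid over
    # memory_bias) blends external-memory attention with local attention.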
    if (
        self.memory is not None
        and self.memory_combine_with_local == "TRAINABLE_WEIGHTED_MEAN"
    ):
      external_memory_bias = jnp.asarray(self.memory_bias, dtype=self.dtype)
      external_memory_bias = jnp.reshape(
          external_memory_bias, (1, 1, num_heads, 1)
      )
      external_memory_bias = jax.nn.sigmoid(external_memory_bias)
    else:
      external_memory_bias = None
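
    # Determine the number of attention windows. A sequence no longer than
    # window_length forms a single window (e.g. one-token decoding); longer
    # sequences are split into equal-sized windows for Transformer-XL.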
    if sequence_length < self.window_length:
      num_windows = 1
    elif sequence_length == self.window_length:
      num_windows = 1
      if self.use_long_xl_architecture:
        assert prev_kvi is not None
    else:
      if not self.use_long_xl_architecture:
        raise ValueError("Can only use sliding window with Transformer XL.")
      num_windows = sequence_length // self.window_length
      if (num_windows * self.window_length) != sequence_length:
        raise ValueError(
            f"Sequence length {sequence_length} must be a "
            + f"multiple of window length {self.window_length}"
        )
    logging.info("tlayer: num_windows = %d.", num_windows)
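
    # single_window_attention computes attention over one window; it is
    # scanned (or unrolled) over all windows below, carrying the previous
    # window's keys/values and the recurrent state from window to window.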
    def single_window_attention(
        carry: tuple[Array, Array], inputs_w: tuple[Array, Array]
    ) -> tuple[tuple[Array, Array], tuple[Array, Array]]:
      # These values are shared across all windows and read from the
      # enclosing scope.
      nonlocal rel_position_bias
      nonlocal causal_mask
      nonlocal kq_relative_offset
      nonlocal dropout_multiplier
      nonlocal attention_scale_factors
      nonlocal external_memory_bias
      nonlocal cross_attention_kv

      (prev_kvi_w, rec_state) = carry
      (kvqi_w, external_kv_w) = inputs_w

      # Concatenate keys,values from the previous window with those of the
      # current window to implement sliding-window attention.
      (kvqi_w, next_kvi_w) = attention.concat_kvqi(kvqi_w, prev_kvi_w)
      (keys_w, values_w, queries_w, queries2_w, importance_w) = kvqi_w

      # Recurrent attention within the current window, producing the next
      # recurrent state and keys/values for cross-attention.
      if rec_state is not None:
        logging.info("tlayer: recurrent attention.")

        # Keys, values, queries from the recurrent state. Recurrent states
        # and input tokens use separate key,value layers because they have
        # separate learned positional embeddings.
        logging.info("tlayer: recurrent kvq.")
        rec_kvq = self.recurrent_tbase.kvq(rec_state)
        r_scale_factors = self.recurrent_tbase.attention_scale_factors()
        (r_keys, r_values, r_queries, r_queries2) = rec_kvq

        logging.info("tlayer: recurrent self-attention.")
        r_attn_ys = attention.simple_attention(
            r_keys,
            r_values,
            r_queries,
            None,
            scale_factor=r_scale_factors[0],
            dtype=self.dtype,
        )

        logging.info("tlayer: recurrent cross-attention.")
        r_cross_attn_ys = attention.simple_attention(
            keys_w,
            values_w,
            r_queries2,
            importance_w,
            scale_factor=r_scale_factors[1],
            dtype=self.dtype,
        )

        logging.info("tlayer: recurrent ffn.")
        next_rec_state = self.recurrent_tbase.post_attn_ffn(
            rec_state, r_attn_ys, r_cross_attn_ys
        )

        # The recurrent state supplies keys,values for cross-attention.
        assert cross_attention_kv is None
        local_cross_attention_kv = (r_keys, r_values)
      else:
        next_rec_state = None
        local_cross_attention_kv = cross_attention_kv

      # With RoPE, keys and queries are rotated before self-attention.
      if self.relative_position_type == "rotary":
        logging.info(
            "Using rotary position encodings (RoPE), offset = %d",
            kq_relative_offset,
        )
        (keys_w, queries_w) = position.rotate_kq(
            keys_w, queries_w, max_wavelength=10_000, offset=kq_relative_offset
        )

      # Self-attention over the current window.
      logging.info("tlayer: self-attention.")
      attn_ys_w = attention.simple_attention(
          keys_w,
          values_w,
          queries_w,
          importance_w,
          relative_position_bias=rel_position_bias,
          scale_factor=attention_scale_factors[0],
          causal_mask=causal_mask,
          dropout_multiplier=dropout_multiplier,
          dtype=self.dtype,
      )

      # Attention over external memory, combined with local attention
      # according to memory_combine_with_local.
      if external_kv_w is not None:
        (external_keys_w, external_values_w) = external_kv_w
        y_ext = attention.external_attention(
            external_keys_w,
            external_values_w,
            queries_w,
            scale_factor=attention_scale_factors[0],
        )
        if external_memory_bias is not None:
          ebias = external_memory_bias
          attn_ys_w = (attn_ys_w * (1 - ebias)) + (y_ext * ebias)
        elif self.memory_combine_with_local == "ADD":
          attn_ys_w += y_ext
        elif self.memory_combine_with_local == "STOP_FORWARD":
          # Forward pass uses the memory output; gradients still flow to
          # local attention.
          attn_ys_w = y_ext + (attn_ys_w - jax.lax.stop_gradient(attn_ys_w))
        else:
          raise ValueError(
              f"Unexpected setting: {self.memory_combine_with_local = }"
          )

      # Cross-attention from input tokens to the encoder output or to the
      # recurrent state, using the second set of queries.
      if local_cross_attention_kv is not None:
        logging.info("tlayer: cross-attention.")
        (c_keys, c_values) = local_cross_attention_kv

        cross_attn_ys_w = attention.simple_attention(
            c_keys,
            c_values,
            queries2_w,
            None,
            scale_factor=attention_scale_factors[1],
            dtype=self.dtype,
        )
      else:
        cross_attn_ys_w = None

      return ((next_kvi_w, next_rec_state), (attn_ys_w, cross_attn_ys_w))
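
    # Force initialization of recurrent_tbase before the scan, so that its
    # parameters are created outside of jax.lax.scan.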
    if (
        self.recurrent_attention
        and 0 <= self.max_unrolled_windows
        and self.max_unrolled_windows < num_windows
    ):
      logging.info("tlayer: force initialization of recurrent_tbase.")
      self.recurrent_tbase.force_init(recurrent_state)
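
    # Scan single_window_attention over all windows along the sequence axis,
    # carrying (prev_kvi, recurrent_state) from one window to the next.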
    initial_carry = (prev_kvi, recurrent_state)
    kvqi = (keys, values, queries, queries2, importance)
    attn_inputs = (kvqi, external_kv)
    (next_carry, attn_outputs) = attention.split_and_scan(
        single_window_attention,
        initial_carry,
        attn_inputs,
        sections=num_windows,
        axis=1,
        max_unrolled_windows=self.max_unrolled_windows,
    )
    (attn_ys, cross_attn_ys) = attn_outputs

    logging.info("tlayer: End windows.")

    # Post-attention MLP, residual connection, and FFN.
    logging.info("tlayer: final FFN.")
    ys = self.tbase.post_attn_ffn(xs, attn_ys, cross_attn_ys)

    # Compute an importance score for each token, if requested.
    if self.compute_importance:
      (batch_size, sequence_length, _) = ys.shape
      importance_score = self.importance_layer(ys)
      importance_score = importance_score.reshape((batch_size, sequence_length))
    else:
      importance_score = None

    next_window_state = next_carry if window_state is not None else None
    viz_dict = {}  # No visualizations are collected here.
    return (ys, importance_score, next_window_state, decoder_state, viz_dict)

  def init_decoder_state_vanilla(
      self, sequence_length: int, start_of_sequence: Array
  ) -> DecoderState:
    """Initialize decoder state for autoregressive generation.

    Args:
      sequence_length: The maximum length of the sequence to generate.
      start_of_sequence: Array of booleans of shape (batch_size,): True if
        starting a new sequence (with no prefix).

    Returns:
      A state object that can be passed to __call__.
    """

    if not self.use_causal_mask:
      raise ValueError("Generator must have been trained with a causal mask.")
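
    # Precompute the relative position bias for a single query attending to
    # a full window of keys; it is the same at every decode step.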
    rel_position_bias = self.relative_positions(
        1, self.window_length, offset=self.window_length, bidirectional=False
    )
    rel_position_bias = jnp.tile(rel_position_bias, (self.batch_size, 1, 1, 1))
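
    # Allocate the KV cache with window_length extra slots of zero padding at
    # the front, so early tokens see a full (padded) window; writing starts
    # at index window_length.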
    num_keys = sequence_length + self.window_length
    stored_shape = (self.batch_size, num_keys, self.num_heads, self.head_size)
    stored_keys = jnp.zeros(stored_shape, dtype=self.dtype)
    stored_values = jnp.zeros(stored_shape, dtype=self.dtype)

    recurrent_kvq = None
    current_index = jnp.array([self.window_length] * self.batch_size)

    decoder_state_dict = {
        "keys": stored_keys,
        "values": stored_values,
        "current_index": current_index,
        "relative_position_bias": rel_position_bias,
        "recurrent_kvq": recurrent_kvq,
    }
    return DecoderState(decoder_state_dict)
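

# A minimal decode-loop sketch (hypothetical; assumes `layer` is a bound
# TransformerLayerGenerate instance, `x_t` is an embedded token of shape
# (batch_size, 1, num_hidden), and `sample` / `embed` are user-supplied):
#
#   state = layer.init_decoder_state_vanilla(max_len, start_of_sequence)
#   for _ in range(max_len):
#     (y_t, _, _, state, _) = layer(x_t, start_of_sequence,
#                                   decoder_state=state)
#     x_t = embed(sample(y_t))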