# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer attention functions."""

import typing
from typing import Any, Callable, Mapping, NewType, Optional, Sequence, Tuple, Union

from absl import logging
from flax import linen as nn
import jax
import jax.numpy as jnp

from transformer import nn_components
from transformer import position


Array = jnp.ndarray
ArrayTree = Union[Array, Tuple["ArrayTree", ...]]
DecoderState = NewType("DecoderState", Mapping[str, Array])

# Tuple of keys, values, importance.
KVITuple = Tuple[Array, Array, Optional[Array]]

# Tuple of keys, values, queries, queries2, importance.
KVQITuple = Tuple[Array, Array, Array, Optional[Array], Optional[Array]]

# Tuple of scale factors.  See TransformerBase.attention_scale_factors().
AttnScaleTuple = Tuple[Optional[Array], Optional[Array]]


def initial_kvi(shape: Sequence[int], use_importance: bool, dtype: Any):
  """Returns initial (zero) keys/values/importance to pass as prev_kvi."""
  z = jnp.zeros(shape, dtype=dtype)
  if use_importance:
    i = jnp.zeros((shape[0], shape[1]), dtype=dtype)  # (bsize, window_length)
  else:
    i = None
  return (z, z, i)
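

# A minimal usage sketch (illustrative, not part of the original API; the
# function name and shapes below are assumptions): build an all-zero cache
# for a sliding window of length 128.
def _example_initial_kvi():
  (keys, values, importance) = initial_kvi(
      shape=(2, 128, 8, 64),  # (batch_size, window_length, num_heads, head_dim)
      use_importance=True,
      dtype=jnp.float32)
  assert keys.shape == values.shape == (2, 128, 8, 64)
  assert importance.shape == (2, 128)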


def concat_kvqi(kvqi: KVQITuple, prev_kvi: Optional[KVITuple]) -> (
    Tuple[KVQITuple, Optional[KVITuple]]):
  """Concatenate previous keys,values with current keys,values.

  Args:
    kvqi: Current keys, values, queries, queries2, importance.
    prev_kvi: Previous keys, values, importance.

  Returns:
    (kvqi: Concatenated (keys, values, queries, queries2, importance),
     next_kvi: Next (keys, values, importance))  (from kvqi)
  """
  (keys, values, queries, queries2, importance) = kvqi
  # The current keys,values,importance will be passed to the next window.
  next_kvi = (keys, values, importance)
  (batch_size, _, num_heads, head_dim) = keys.shape  # (b, _, h, d)

  if prev_kvi is None:
    return (kvqi, None)  # If prev_kvi is None, next_kvi should be None.

  # Unpack prev_kvi and check shapes.
  (pkeys, pvalues, pimportance) = prev_kvi
  num_pkeys = pkeys.shape[1]
  assert pkeys.shape == (batch_size, num_pkeys, num_heads, head_dim)
  assert pkeys.shape == pvalues.shape
  if pimportance is not None:
    assert pimportance.shape == (batch_size, num_pkeys)

  # Concatenate keys and values.
  keys = jnp.concatenate([pkeys, keys], axis=1)        # (b, k, h, d)
  values = jnp.concatenate([pvalues, values], axis=1)  # (b, k, h, d)
  if importance is not None:
    assert pimportance is not None
    importance = jnp.concatenate([pimportance, importance], axis=1)  # (b, k)
    logging.info("attn: importance = %r", importance)

  return ((keys, values, queries, queries2, importance), next_kvi)
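

# A minimal usage sketch (illustrative; the name and shapes are assumptions):
# with a cached previous window, the keys and values double in length along
# the sequence axis, and the current window becomes the cache for next time.
def _example_concat_kvqi():
  (b, n, h, d) = (2, 128, 8, 64)
  k = jnp.zeros((b, n, h, d))
  kvqi = (k, k, k, None, None)  # queries2 and importance unused here
  prev_kvi = (k, k, None)       # e.g. from initial_kvi(...)
  (ckvqi, next_kvi) = concat_kvqi(kvqi, prev_kvi)
  assert ckvqi[0].shape == (b, 2 * n, h, d)  # keys grew along axis 1
  assert next_kvi[0].shape == (b, n, h, d)   # current keys become next cache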


def simple_attention(keys: Array,
                     values: Array,
                     queries: Array,
                     importance: Optional[Array],
                     *,
                     relative_position_bias: Optional[Array] = None,
                     scale_factor: Optional[Array] = None,
                     causal_mask: Optional[Array] = None,
                     dropout_multiplier: Optional[Array] = None,
                     dtype: Any = jnp.float32) -> Array:
  """Simple attention from a set of queries to a set of keys,values.

  Args:
    keys: of shape [batch_size, num_keys, num_heads, head_dim].
    values: of shape [batch_size, num_keys, num_heads, head_dim].
    queries: of shape [batch_size, num_queries, num_heads, head_dim].
    importance: of shape [batch_size, num_keys].

    *: ---- the following arguments are passed by keyword only ----
    relative_position_bias: A positional attention matrix of shape
        [num_heads, num_queries, num_keys].
    scale_factor: Learned scale factor for use with normalized keys,queries,
        of shape [num_heads].
    causal_mask: A boolean array of shape [num_heads, num_queries, num_keys].
    dropout_multiplier: A random mask of either 0.0 or 1.0/keep_prob,
        of shape [num_heads, num_queries, num_keys].
    dtype: data type in which to perform the attention computation.

  Returns:
    Attention outputs of shape [batch_size, num_queries, num_heads, head_dim].
  """
  # (batch_size, num_keys, num_heads, head_dim)
  (batch_size, num_keys, num_heads, head_dim) = keys.shape  # (b, k, h, d)
  num_queries = queries.shape[1]
  assert keys.shape == values.shape
  assert queries.shape == (batch_size, num_queries, num_heads, head_dim)
  if importance is not None:
    assert importance.shape == (batch_size, num_keys)

  logging.info("attn: keys = %r", keys)
  logging.info("attn: queries = %r", queries)

  # Compute attention matrix.
  attn = jnp.einsum("...qhd,...khd->...hqk", queries, keys)  # (b, h, q, k)
  logging.info("attn: content attn = %r", attn)

  # Apply relative position bias.
  if relative_position_bias is not None:
    logging.info("attn: pbias = %r", relative_position_bias)
    relative_position_bias = jnp.asarray(relative_position_bias, dtype=dtype)
    pbias = position.broadcast_mask(relative_position_bias, attn)
    attn = attn + pbias

  # Apply learned attention scale.
  if scale_factor is not None:
    logging.info("attn: learned attention scale: %s", scale_factor)
    # Broadcast scale over batch/keys/queries.
    scale_factor = jnp.asarray(scale_factor, dtype=dtype)
    scale_factor = scale_factor.reshape((1, num_heads, 1, 1))
    attn = attn * scale_factor

  # Apply causal mask.
  if causal_mask is not None:
    causal_mask = position.broadcast_mask(causal_mask, attn)
    attn = jnp.where(causal_mask, attn, jnp.asarray(-1_000_000.0, dtype=dtype))
  logging.info("attn: pre-softmax attn = %r", attn)

  # Normalize attention matrix with softmax.
  # min_x should be much smaller than the minimum expected values in attn, but
  # much larger than the masked-out values created by the causal mask.  That
  # way, if all tokens are masked out, softmax will attend to nothing, rather
  # than attend to everything equally.
  min_x = jnp.asarray(-1000.0, dtype=dtype)
  attn = nn_components.safe_softmax(attn, axis=-1, min_x=min_x)  # (b, h, q, k)

  # Apply dropout to attention matrix.
  if dropout_multiplier is not None:
    logging.debug("attn: drop = %r", dropout_multiplier)
    dropout_multiplier = jnp.asarray(dropout_multiplier, dtype=dtype)
    attn = attn * dropout_multiplier

  logging.info("attn: final attn = %r", attn)

  # Compute output -- values weighted by attention matrix.
  y = jnp.einsum("...hqk,...khd->...qhd", attn, values)  # (b, q, h, d)
  logging.info("attn: y = %r", y)
  return y
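

# A minimal usage sketch (illustrative; the name and shapes are assumptions):
# unmasked attention over random keys/values/queries of matching shape.
def _example_simple_attention():
  (kr, vr, qr) = jax.random.split(jax.random.PRNGKey(0), 3)
  keys = jax.random.normal(kr, (2, 16, 4, 32))     # (b, num_keys, h, d)
  values = jax.random.normal(vr, (2, 16, 4, 32))
  queries = jax.random.normal(qr, (2, 16, 4, 32))  # (b, num_queries, h, d)
  y = simple_attention(keys, values, queries, importance=None)
  assert y.shape == (2, 16, 4, 32)                 # (b, num_queries, h, d)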


def external_attention(external_keys: Array,
                       external_values: Array,
                       queries: Array,
                       *,
                       scale_factor: Optional[Array] = None,
                       dtype: Any = jnp.float32) -> Array:
  """Attention over (keys, values) retrieved from external memory.

  Args:
    external_keys: per-query keys from external memory, of shape
        [batch_size, num_queries, num_heads, num_neighbors, head_size]
    external_values: per-query values from external memory, of shape
        [batch_size, num_queries, num_heads, num_neighbors, head_size]
    queries: current queries, of shape
        [batch_size, num_queries, num_heads, head_size]

    *: ---- the following arguments are passed by keyword only ----
    scale_factor: Learned scale factor for use with normalized keys,queries,
        of shape [num_heads].
    dtype: data type in which to perform the attention computation.

  Returns:
    Attention outputs of shape [batch_size, num_queries, num_heads, head_size]
  """
  (batch_size, num_queries, num_heads, _, head_dim) = external_keys.shape
  assert queries.shape == (batch_size, num_queries, num_heads, head_dim)
  assert external_values.shape == external_keys.shape

  # Build attention matrix.
  logging.info("ext_attn: external keys = %r", external_keys)
  ext_attn = jnp.einsum("...qhd,...qhid->...hqi", queries, external_keys)
  logging.info("ext_attn: external_mem_attn: %s", ext_attn)
  if scale_factor is not None:
    scale_factor = jnp.asarray(scale_factor, dtype=dtype)
    scale_factor = scale_factor.reshape((1, num_heads, 1, 1))
    logging.info("ext_attn: scaling external_mem_attn by %s", scale_factor)
    ext_attn = ext_attn * scale_factor
  ext_attn = nn.softmax(ext_attn, axis=-1)

  # Compute weighted sum of values.
  ext_y = jnp.einsum("...hqi,...qhid->...qhd", ext_attn, external_values)
  logging.info("ext_attn: ext_y = %r", ext_y)
  return ext_y
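

# A minimal usage sketch (illustrative; num_neighbors=8 and the other shapes
# are assumptions): each query attends over its own retrieved neighbors.
def _example_external_attention():
  (b, q, h, i, d) = (2, 16, 4, 8, 32)  # i = num_neighbors
  (er, qr) = jax.random.split(jax.random.PRNGKey(0))
  ext_keys = jax.random.normal(er, (b, q, h, i, d))
  ext_values = jnp.ones((b, q, h, i, d))
  queries = jax.random.normal(qr, (b, q, h, d))
  y = external_attention(ext_keys, ext_values, queries)
  assert y.shape == (b, q, h, d)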


def sliding_attention_window_shape(kvi: KVITuple,
                                   prev_kvi: Optional[KVITuple],
                                   queries: Array,
                                   window_length: int) -> Tuple[int, int]:
  """Return (num_queries, num_keys) for the sliding attention window."""
  # Do error checking here.
  (keys, values, importance) = kvi
  assert keys.shape == queries.shape
  assert values.shape == queries.shape

  # Get sizes...
  (batch_size, sequence_length, _, _) = queries.shape
  if importance is not None:
    assert importance.ndim == 2
    assert importance.shape == (batch_size, sequence_length)

  assert window_length > 0
  if window_length >= sequence_length:
    # No sliding window.
    num_queries = sequence_length
    num_keys = sequence_length
    if prev_kvi is not None:
      num_keys += prev_kvi[0].shape[1]
  else:
    # Sliding window.
    if prev_kvi is not None:
      assert prev_kvi[0].shape[1] == window_length
    num_queries = window_length
    num_keys = window_length * 2
  return (num_queries, num_keys)
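

# A minimal usage sketch (illustrative; the name and shapes are assumptions):
# with a sequence of length 256 and a window of 128, each window of 128
# queries attends to its own keys plus the 128 cached keys from the previous
# window.
def _example_sliding_window_shape():
  (b, n, h, d) = (2, 256, 4, 32)
  x = jnp.zeros((b, n, h, d))
  w = jnp.zeros((b, 128, h, d))  # cached previous window
  (num_queries, num_keys) = sliding_attention_window_shape(
      kvi=(x, x, None), prev_kvi=(w, w, None), queries=x, window_length=128)
  assert (num_queries, num_keys) == (128, 256)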


def split_tree(tree: ArrayTree, sections: int, axis: int = 0) -> (
    Sequence[ArrayTree]):
  """Recursively splits a possibly nested tuple of arrays along the given axis.

  Args:
    tree: A nested tree of tuples and arrays.
    sections: The number of sections to split the tree into.
    axis: The axis to split arrays along.

  Returns:
    A list of trees, of length sections, where each has the same structure as
    the original, but with arrays that are 1/sections of their original size
    along the given axis.
  """
  if tree is None:
    return [None] * sections
  elif isinstance(tree, jnp.ndarray):
    return jnp.split(tree, sections, axis=axis)
  elif isinstance(tree, tuple):
    # Recursively split each element of the tuple into a list.
    branch_lists = [split_tree(tree_i, sections, axis=axis) for tree_i in tree]
    # Rearrange the tuple of lists into a list of tuples.
    return [tuple([brs[i] for brs in branch_lists]) for i in range(sections)]
  else:
    raise ValueError("Argument %r must be an ndarray or tuple." % tree)


def concat_trees(tree_list: Sequence[ArrayTree], axis: int = 0) -> ArrayTree:
  """Merges a list of trees into a single tree by concatenating their elements.

  Args:
    tree_list: A list of trees, all of the same structure.
    axis: The axis to concatenate arrays on.

  Returns:
    A single tree, with the same structure as the trees in tree_list.
  """
  # All trees in the list are required to have the same structure, and we
  # return a single tree with that same structure.
  first_tree = tree_list[0]
  if first_tree is None:
    # Merge a list of None into a single None.
    for tree_i in tree_list:
      assert tree_i is None
    return None
  elif isinstance(first_tree, jnp.ndarray):
    # Concatenate a list of arrays.
    for tree_i in tree_list:
      assert isinstance(tree_i, jnp.ndarray)
    return jnp.concatenate(tree_list, axis=axis)
  elif isinstance(first_tree, tuple):
    # Reassemble a list of tuples into a tuple of concatenated lists.
    for tree_i in tree_list:
      assert isinstance(tree_i, tuple) and len(tree_i) == len(first_tree)
    num_branches = len(first_tree)
    return tuple([concat_trees([tree[b] for tree in tree_list], axis=axis)
                  for b in range(num_branches)])
  else:
    raise ValueError("Argument %r must be an ndarray or tuple." % first_tree)
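

# A minimal round-trip sketch (illustrative): concat_trees inverts split_tree.
# The nested tuple below is an assumption, chosen to exercise the None branch.
def _example_split_concat_round_trip():
  tree = (jnp.arange(8), (jnp.ones((8, 3)), None))
  parts = split_tree(tree, sections=4, axis=0)  # 4 trees, axis-0 size 2 each
  rebuilt = concat_trees(parts, axis=0)
  assert rebuilt[0].shape == (8,)
  assert rebuilt[1][0].shape == (8, 3)
  assert rebuilt[1][1] is None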


def reshape_transpose_tree(tree: ArrayTree, sections: int, axis: int = 0) -> (
    ArrayTree):
  """Reshape and transpose arrays so that the window is dimension 0."""
  # We could use jax tree utils for this, but we do it the hard way so the
  # implementation can be compared with split_tree.
  if tree is None:
    return None
  elif isinstance(tree, jnp.ndarray):
    tree = typing.cast(Array, tree)  # Tell type-checker about isinstance.
    ndim = tree.ndim
    wlen = tree.shape[axis] // sections
    assert sections * wlen == tree.shape[axis]  # Must be evenly divisible.

    # Break the axis dimension into sections * window_size.
    arr = tree
    sh = list(arr.shape)
    nshape = sh[0:axis] + [sections, wlen] + sh[axis + 1:]
    arr = jnp.reshape(arr, nshape)

    # Transpose sections to be dimension 0.
    tdims = [axis] + list(range(0, axis)) + list(range(axis + 1, ndim + 1))
    arr = jnp.transpose(arr, tdims)
    return arr
  elif isinstance(tree, tuple):
    return tuple([reshape_transpose_tree(b, sections, axis) for b in tree])
  else:
    raise ValueError("Argument %r must be an ndarray or tuple." % tree)


def transpose_reshape_tree(tree: ArrayTree, sections: int, axis: int = 0) -> (
    ArrayTree):
  """Inverse of reshape_transpose_tree: folds dimension 0 back into the axis."""
  # We could use jax tree utils for this, but we do it the hard way so the
  # implementation can be compared with split_tree.
  if tree is None:
    return None
  elif isinstance(tree, jnp.ndarray):
    tree = typing.cast(Array, tree)  # Tell type-checker about isinstance.
    ndim = tree.ndim - 1  # Input tree has 1 extra dimension on front.
    assert axis < ndim
    wlen = tree.shape[axis + 1]  # Window length.

    # Transpose dimension 0 back to its proper place.
    arr = tree
    tdims = list(range(1, axis + 1)) + [0] + list(range(axis + 1, ndim + 1))
    arr = jnp.transpose(arr, tdims)

    # Combine the sections and window_size dimensions.
    sh = list(arr.shape)
    nshape = sh[0:axis] + [sections * wlen] + sh[axis + 2:]
    arr = jnp.reshape(arr, nshape)
    return arr
  elif isinstance(tree, tuple):
    return tuple([transpose_reshape_tree(b, sections, axis) for b in tree])
  else:
    raise ValueError("Argument %r must be an ndarray or tuple." % tree)
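

# A minimal round-trip sketch (illustrative; shapes are assumptions):
# reshape_transpose_tree moves 4 windows of length 3 to dimension 0, and
# transpose_reshape_tree restores the original layout.
def _example_reshape_transpose_round_trip():
  x = jnp.arange(24).reshape((2, 12))
  xt = reshape_transpose_tree(x, sections=4, axis=1)
  assert xt.shape == (4, 2, 3)  # (sections, batch, window_length)
  xr = transpose_reshape_tree(xt, sections=4, axis=1)
  assert xr.shape == (2, 12) and (xr == x).all()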


def split_and_scan(fn: Callable[[ArrayTree, ArrayTree],
                                Tuple[ArrayTree, ArrayTree]],
                   carry: ArrayTree, input_arrays: ArrayTree,
                   sections: int, axis: int = 0,
                   max_unrolled_windows: int = -1) -> (
                       Tuple[ArrayTree, ArrayTree]):
  """Scan over a set of input arrays in chunks.

  Splits each array in 'input_arrays' into the number of chunks given by
  'sections', and then loops over the chunks with a scan operation.
  Returns a concatenation of the results.

  Args:
    fn: A function from (carry, input_i) -> (carry, output_i).
    carry: The initial state for the scan, which will be passed from one
        iteration to the next.
    input_arrays: A nested tree of tuples of arrays.
    sections: The number of sections or chunks for the split.
    axis: The axis to split each array along.
    max_unrolled_windows: If 0 <= max_unrolled_windows < sections,
        use jax.lax.scan rather than unrolling the windows with a python loop.

  Returns:
    (carry, output)
  """
  if sections == 1:
    logging.info("Single window, no scan.")
    return fn(carry, input_arrays)

  if axis < 0:
    raise ValueError(f"Axis must be non-negative. Got {axis}")

  logging.info("Scanning over %d windows", sections)

  if 0 <= max_unrolled_windows < sections:
    logging.info("Using jax.lax.scan.")
    in_arrs = reshape_transpose_tree(input_arrays, sections, axis)
    (carry, out_arrs) = jax.lax.scan(fn, carry, in_arrs)
    output_arrays = transpose_reshape_tree(out_arrs, sections, axis)
    return (carry, output_arrays)

  logging.info("Using unrolled for-loop.")
  in_list = split_tree(input_arrays, sections, axis=axis)
  out_list = []
  for (k, in_chunk) in enumerate(in_list):
    logging.info("Processing window %d", k)
    (carry, out_chunk) = fn(carry, in_chunk)
    out_list.append(out_chunk)
  output_arrays = concat_trees(out_list, axis=axis)
  return (carry, output_arrays)
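

# A minimal usage sketch (illustrative; the windowed function below is an
# assumption): sum each chunk into the carry while passing the chunks through.
def _example_split_and_scan():
  def add_window(carry, x):
    return (carry + jnp.sum(x), x + 1.0)

  xs = jnp.ones((8, 2))
  (total, ys) = split_and_scan(
      add_window, carry=jnp.asarray(0.0), input_arrays=xs,
      sections=4, axis=0)
  assert total == 16.0       # sum of all 8 * 2 ones
  assert ys.shape == (8, 2)  # per-chunk outputs re-concatenated along axis 0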