| """ |
| An implementation of the Blockwise Parallel Transformer (https://arxiv.org/abs/2305.19370). |
| Also includes a reference implementation of the memory-efficient transformer (https://arxiv.org/abs/2112.05682). |
| """ |
|
|
| import functools |
| from typing import NamedTuple |
|
|
| import flax.linen as nn |
| import jax |
| import jax.lax as lax |
| import jax.numpy as jnp |
| from einops import rearrange |
|
|
| """ |
| Compute the FFN blockwise without materializing the large hidden tensor, which allows |
| training on sequences 4x longer than the memory-efficient transformer. |
| Blockwise parallel transformer https://arxiv.org/abs/2305.19370 (Liu et al. 2023). |
| """ |
def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
    """Apply a feed-forward module to ``inputs`` one slice at a time.

    The sequence axis of ``inputs`` is factored as ``(c n)`` with
    ``c = chunk_size`` (einops pattern ``'b (c n) d -> b c n d'``), and the
    module is scanned over the ``n`` axis, so every scan step receives a
    ``(b, chunk_size, d)`` slice instead of the whole sequence.

    Args:
        remat_ffn: module called as ``remat_ffn(hidden_states,
            deterministic=...)``; its "params" collection is broadcast to
            every scan step.
            NOTE(review): presumably a rematerialized Flax FFN -- confirm
            against the caller.
        inputs: array laid out as ``(b, seq, d)``; ``seq`` has to factor as
            ``chunk_size * n`` for the rearrange to succeed.
        chunk_size: size of the ``c`` axis handed to the module per step.
        deterministic: forwarded unchanged to ``remat_ffn``.

    Returns:
        The per-step outputs merged back into a ``(b, seq, d)`` layout.
    """
    chunked = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)

    def scan_ffn(ffn, carry, hidden_states):
        # The carry is unused (None); it is only threaded through for scan.
        return carry, ffn(hidden_states, deterministic=deterministic)

    # For the 4-D chunked tensor this is axis 2, i.e. the `n` axis.
    step_axis = chunked.ndim - 2
    scanned_ffn = nn.scan(
        scan_ffn,
        variable_broadcast="params",
        split_rngs={"params": False, "dropout": True},
        in_axes=step_axis,
        out_axes=step_axis,
    )
    _, chunk_outputs = scanned_ffn(remat_ffn, None, chunked)
    return rearrange(chunk_outputs, 'b c n d -> b (c n) d')
|
|
|
|
| """ |
| Compute attention blockwise without materializing the full attention matrix, |
| as initially proposed in the memory-efficient transformer https://arxiv.org/abs/2112.05682 (Rabe et al. 2021). |
| Flash attention https://arxiv.org/abs/2205.14135 (Dao et al. 2022) proposes a CUDA-efficient |
| implementation; the blockwise parallel transformer https://arxiv.org/abs/2305.19370 |
| (Liu et al. 2023) proposes computing both attention and FFN blockwise, enabling 4x |
| longer sequences than memory-efficient/flash attention and fusion of attention and FFN. |
| """ |
def blockwise_attn(
    query, key, value,
    bias=None,
    deterministic=True,
    dropout_rng=None,
    attn_pdrop=0.0,
    causal=True,
    query_chunk_size=2048,
    key_chunk_size=2048,
    dtype=jnp.float32,
    policy=jax.checkpoint_policies.nothing_saveable(),
    precision=None,
    float32_logits=True,
    prevent_cse=True,
):
    """Softmax attention computed chunk by chunk with a running softmax.

    The full ``(q_len, kv_len)`` logit matrix is never formed: an outer scan
    walks over query chunks and, for each of them, an inner scan walks over
    key/value chunks while maintaining a numerically stable running
    numerator / denominator (see ``Carry``).

    Args:
        query: array of shape ``(batch, q_len, num_heads, dim_per_head)``.
        key: array of shape ``(batch, kv_len, num_heads, dim_per_head)``.
        value: array of shape ``(batch, kv_len, num_heads, dim_per_head)``.
        bias: optional additive logit bias; each of its four dimensions must
            be 1 or equal the matching entry of
            ``(batch, num_heads, q_len, kv_len)``.
        deterministic: when False and ``attn_pdrop > 0`` a dropout mask is
            sampled and applied as a large negative logit bias.
        dropout_rng: PRNG key; only read when dropout is active.
        attn_pdrop: probability that a (query, key) pair is masked out.
        causal: mask every key whose position is greater than the query's.
        query_chunk_size: queries handled per outer step; must divide q_len.
        key_chunk_size: keys/values handled per inner step; must divide
            kv_len.
        dtype: dtype of the bias/mask terms and of the returned array.
        policy: forwarded to ``jax.checkpoint`` for the inner step.
        precision: forwarded to both ``jnp.einsum`` calls.
        float32_logits: cast query and key to float32 before the logits are
            computed.
        prevent_cse: forwarded to ``jax.checkpoint`` for the inner step.

    Returns:
        Array of shape ``(batch, q_len, num_heads, dim_per_head)`` and dtype
        ``dtype``.

    Raises:
        ValueError: if ``q_len`` is not a multiple of ``query_chunk_size`` or
            ``kv_len`` is not a multiple of ``key_chunk_size``.
    """
    # Scale the queries once so every chunk product is already 1/sqrt(d).
    query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    batch, q_len, num_heads, dim_per_head = query.shape
    batch, kv_len, num_heads, dim_per_head = key.shape
    batch, kv_len, num_heads, dim_per_head = value.shape

    # Fail early with a readable message instead of an opaque reshape error.
    if q_len % query_chunk_size != 0:
        raise ValueError(
            f"query length {q_len} is not a multiple of "
            f"query_chunk_size {query_chunk_size}"
        )
    if kv_len % key_chunk_size != 0:
        raise ValueError(
            f"key/value length {kv_len} is not a multiple of "
            f"key_chunk_size {key_chunk_size}"
        )

    num_q = q_len // query_chunk_size
    num_kv = kv_len // key_chunk_size

    # (batch, len, h, d) -> (num_chunks, batch, chunk, h, d): lax.scan
    # iterates over the leading axis.
    query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    query = jnp.moveaxis(query, 1, 0)
    key = jnp.moveaxis(key, 1, 0)
    value = jnp.moveaxis(value, 1, 0)

    if bias is not None:
        for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
            assert bias_dim == 1 or bias_dim == broadcast_dim

    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
        # True marks a (query, key) pair that will be masked out.
        attn_dropout = jax.random.bernoulli(
            attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len)
        )
    else:
        attn_dropout = None

    _chunk_bias_fn = functools.partial(
        _chunk_attention_bias,
        query_chunk_size, key_chunk_size, bias, deterministic,
        attn_dropout, attn_pdrop, causal, dtype)

    def scan_attention(args):
        """Attend one query chunk to every key/value chunk."""
        query_chunk, query_chunk_idx = args

        @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
        def scan_kv_block(carry, args):
            """Fold one key/value chunk into the running softmax state."""
            key_chunk, value_chunk, key_chunk_idx = args
            (numerator, denominator, prev_max_score) = carry
            attn_weights = jnp.einsum(
                'bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision
            )
            # The bias comes back as (b, h, q, k); the logits are (b, q, h, k).
            bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
            bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
            attn_weights = attn_weights + bias_chunk

            max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
            max_score = jnp.maximum(prev_max_score, max_score)
            max_score = jax.lax.stop_gradient(max_score)
            exp_weights = jnp.exp(attn_weights - max_score)
            exp_values = jnp.einsum(
                'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
            )
            # Rescale what was accumulated under the previous running max.
            correction = jnp.exp(prev_max_score - max_score)
            numerator = numerator * correction + exp_values
            denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
            return Carry(numerator, denominator, max_score), None

        def skip_upper_half(carry, args):
            """Skip key chunks that lie entirely in the causal future."""
            key_chunk, value_chunk, key_chunk_idx = args
            skip_block = jnp.array(False)
            if causal:
                # Compare positions, not chunk indices: with different query
                # and key chunk sizes a key chunk with a larger index can
                # still overlap the positions visible to this query chunk.
                last_query_pos = (query_chunk_idx + 1) * query_chunk_size - 1
                first_key_pos = key_chunk_idx * key_chunk_size
                skip_block = first_key_pos > last_query_pos
            return jax.lax.cond(
                skip_block,
                lambda carry, args: (carry, None),
                scan_kv_block,
                carry,
                args,
            )

        init_carry = Carry(
            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
            jnp.full((batch, query_chunk_size, num_heads, 1), -jnp.inf, dtype=query.dtype),
        )
        (numerator, denominator, max_score), _ = lax.scan(
            skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
        )
        outputs = (numerator / denominator).astype(dtype)
        return outputs

    _, res = lax.scan(
        lambda _, x: ((), scan_attention(x)),
        (), xs=(query, jnp.arange(0, num_q))
    )
    # (num_q, batch, chunk, h, d) -> (batch, q_len, h, d); mirrors the
    # reshape/moveaxis used above to split the sequence.
    res = jnp.moveaxis(res, 0, 1).reshape((batch, q_len, num_heads, dim_per_head))
    return res


class Carry(NamedTuple):
    """Running softmax state threaded through the scan over key chunks."""

    # Sum over the keys seen so far of exp(logit - max_so_far) * value.
    numerator: jax.Array
    # Sum over the keys seen so far of exp(logit - max_so_far).
    denominator: jax.Array
    # Largest logit seen so far for each query (kept with a trailing 1-axis).
    max_so_far: jax.Array


def _chunk_attention_bias(query_chunk_size, key_chunk_size,
            bias, deterministic, attn_dropout, attn_pdrop, causal,
            dtype, query_chunk_idx, key_chunk_idx):
    """Build the additive logit bias for one (query chunk, key chunk) pair.

    Combines, in this order, the slice of the user ``bias``, the causal mask
    and the dropout mask. Masked entries receive ``jnp.finfo(dtype).min``.

    Args:
        query_chunk_size: number of queries per chunk.
        key_chunk_size: number of keys per chunk.
        bias: optional 4-D bias indexed as (batch, heads, query, key); an
            axis of size 1 is kept at size 1 so it can broadcast.
        deterministic: dropout is skipped when True.
        attn_dropout: boolean mask indexed like ``bias`` (True = drop); only
            read when dropout is active.
        attn_pdrop: dropout probability; dropout is active only if > 0.
        causal: whether to mask keys positioned after the query.
        dtype: dtype of the returned bias.
        query_chunk_idx: index of the query chunk (may be a traced value).
        key_chunk_idx: index of the key chunk (may be a traced value).

    Returns:
        A 4-D array of dtype ``dtype`` that broadcasts against
        ``(batch, heads, query_chunk_size, key_chunk_size)``.
    """
    query_offset = query_chunk_idx * query_chunk_size
    key_offset = key_chunk_idx * key_chunk_size

    chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
    if bias is not None:
        chunk_bias = lax.dynamic_slice(
            bias,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(
                *bias.shape[:2],
                min(bias.shape[-2], query_chunk_size),
                min(bias.shape[-1], key_chunk_size),
            ),
        )

    if causal:
        query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
        key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
        # Shift the local query index so that it is comparable to the local
        # key index: (global query position) - key_offset.
        query_idx = query_idx + (query_offset - key_offset)
        causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
        chunk_bias = chunk_bias + causal_mask_value.reshape(1, 1, *causal_mask_value.shape)

    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_slice = lax.dynamic_slice(
            attn_dropout,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(
                *attn_dropout.shape[:2],
                min(attn_dropout.shape[-2], query_chunk_size),
                min(attn_dropout.shape[-1], key_chunk_size),
            ),
        )
        chunk_bias = chunk_bias + attn_dropout_slice * jnp.finfo(dtype).min

    return chunk_bias.astype(dtype)
|
|
|
|
if __name__ == '__main__':
    # Self-check: the blockwise result must match plain softmax attention.
    def reference_attn(query, key, value, causal, dtype):
        """Plain attention that materializes the full logit matrix."""
        scaled_query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
        scores = jnp.einsum("bqhc,bkhc->bhqk", scaled_query, key)
        if causal:
            _, q_seq_len, _, _ = query.shape
            _, kv_seq_len, _, _ = key.shape
            grid = (q_seq_len, kv_seq_len)
            rows = jax.lax.broadcasted_iota(jnp.int32, grid, 0)
            cols = jax.lax.broadcasted_iota(jnp.int32, grid, 1)
            # A key is "in the future" when its column index exceeds the row.
            is_future = (rows < cols)[None, None, :, :]
            scores = scores + jnp.where(is_future, jnp.finfo(scores.dtype).min, 0.0)
        probs = jax.nn.softmax(scores, axis=-1)
        return jnp.einsum("bhqk,bkhc->bqhc", probs, value)

    # Random inputs laid out as (batch, seq_len, num_heads, dim_per_head).
    shape = (1, 32, 8, 64)
    query = jax.random.normal(jax.random.PRNGKey(0), shape)
    key = jax.random.normal(jax.random.PRNGKey(1), shape)
    value = jax.random.normal(jax.random.PRNGKey(2), shape)

    causal = True
    chunk_size = 4
    policy = jax.checkpoint_policies.nothing_saveable()

    blockwise = blockwise_attn(
        query, key, value,
        bias=None,
        deterministic=False,
        dropout_rng=None,
        attn_pdrop=0.0,
        causal=causal,
        query_chunk_size=chunk_size,
        key_chunk_size=chunk_size,
        dtype=jnp.float32,
        policy=policy,
        precision='float32',
        float32_logits=True,
        prevent_cse=False,
    )
    reference = reference_attn(query, key, value, causal, 'float32')

    assert jnp.allclose(reference, blockwise, atol=1e-6)
|
|