from functools import partial

import torch
from torch.nn import functional as F
import xformers.ops as xops


def get_rectangular_causal_mask(shape, q_seq_len, k_seq_len, device, dtype):
    """Create a rectangular causal mask.

    This is especially useful when the query length is smaller than the key length (e.g., decoding with a KV cache).
    It also ensures that the returned bias is a view into a tensor whose last dimension is a multiple of 8, as
    required by xformers. Masked positions are filled with torch.finfo(dtype).min, shown as -inf below for
    readability.

    >>> get_rectangular_causal_mask((1, 1), 2, 2, "cpu", torch.float32)
    tensor([[[[0., -inf],
              [0., 0.]]]])
    >>> get_rectangular_causal_mask((1, 1), 3, 5, "cpu", torch.float32)
    tensor([[[[0., 0., 0., -inf, -inf],
              [0., 0., 0., 0., -inf],
              [0., 0., 0., 0., 0.]]]])
    >>> get_rectangular_causal_mask((1, 1), 5, 5, "cpu", torch.float32)
    tensor([[[[0., -inf, -inf, -inf, -inf],
              [0., 0., -inf, -inf, -inf],
              [0., 0., 0., -inf, -inf],
              [0., 0., 0., 0., -inf],
              [0., 0., 0., 0., 0.]]]])
    """
    # Pad the key dimension up to a multiple of 8 so the underlying storage satisfies xformers' alignment requirement.
    next_multiple_8 = (k_seq_len + 7) // 8 * 8

    # Boolean mask: queries are aligned with the *last* q_seq_len keys, and each query attends to all earlier keys.
    mask = torch.ones((q_seq_len, k_seq_len), device=device, dtype=torch.bool)
    mask[:, -q_seq_len:] = torch.tril(mask[:, -q_seq_len:], diagonal=0)

    # Build the padded additive bias, fill disallowed positions with the dtype minimum, and slice back to k_seq_len.
    output_mask = torch.zeros((*shape, q_seq_len, next_multiple_8), device=device, dtype=dtype)
    output_mask[:, :, :, :k_seq_len].masked_fill_(~mask, torch.finfo(dtype).min)
    return output_mask[:, :, :, :k_seq_len]


def apply_attention_mask_(bias, attention_mask, queries_dtype):
    """Applies an attention mask (e.g., from HuggingFace generate) to an attention bias, in-place.

    Args:
        bias (torch.Tensor): additive bias of shape (batch_size, num_heads, q_seq_len, k_seq_len); modified in-place.
        attention_mask (torch.Tensor): mask of shape (batch_size, sequence_len), with 0 marking padding positions.
        queries_dtype: queries.dtype; used to get the minimum value for masked positions.

    Returns:
        None. The bias tensor is modified in-place.
    """
    assert attention_mask.dim() == 2

    mask_length = attention_mask.shape[-1]

    # Mask key positions that are still unmasked in the bias (== 0) but are padding according to attention_mask (== 0).
    padding_mask = bias[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
    min_dtype = torch.finfo(queries_dtype).min
    bias[..., :mask_length] = bias[..., :mask_length].masked_fill(padding_mask, min_dtype)

    # If every key position of a row ends up masked, zero the whole row instead so softmax does not produce NaNs.
    bias.mul_(~torch.all(bias == min_dtype, dim=-1, keepdim=True))
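# Usage sketch for apply_attention_mask_ (a minimal illustration; the shapes below are made up for the example):
#
#     bias = get_rectangular_causal_mask((2, 4), 3, 7, "cpu", torch.float32)  # (batch=2, heads=4, q=3, k=7)
#     attention_mask = torch.ones(2, 7, dtype=torch.long)
#     attention_mask[0, :2] = 0  # first two key positions of sample 0 are padding
#     apply_attention_mask_(bias, attention_mask, queries_dtype=torch.float32)
#     # bias now masks the padded key positions in addition to the causal structure.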


def xformers_attn(queries, keys, values, is_causal, attention_mask=None):
    # xformers expects queries/keys/values with layout [batch, seq_len, heads, head_dim].
    # If there is a single query (queries.shape[1] == 1) and no attention_mask, causal masking is a no-op,
    # so the bias stays None.
    bias = None
    if is_causal and queries.shape[1] == keys.shape[1] and attention_mask is None:
        # Square case: use xformers' built-in lower-triangular mask.
        bias = xops.LowerTriangularMask()
    elif is_causal and (queries.shape[1] > 1 or attention_mask is not None):
        # Rectangular case (e.g., decoding against a longer KV cache): queries align with the last keys.
        batch, q_seq_len, heads, _ = queries.shape
        k_seq_len = keys.shape[1]
        bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
        if attention_mask is not None:
            apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype)
    elif not is_causal and attention_mask is not None:
        raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.")
    return xops.memory_efficient_attention(queries, keys, values, attn_bias=bias)
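# Usage sketch for xformers_attn (illustrative shapes; assumes xformers is installed and a CUDA device is available,
# with inputs laid out as [batch, seq_len, heads, head_dim]):
#
#     q = torch.randn(2, 4, 8, 64, device="cuda", dtype=torch.float16)   # 4 new query tokens
#     k = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)  # 16 cached key tokens
#     v = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)
#     out = xformers_attn(q, k, v, is_causal=True)  # -> (2, 4, 8, 64), uses the rectangular causal mask path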


def torch_attn(queries, keys, values, is_causal, attention_mask=None):
    # queries/keys/values have layout [batch, seq_len, heads, head_dim]; F.scaled_dot_product_attention expects
    # [batch, heads, seq_len, head_dim], hence the transposes. The final .contiguous() keeps later .view() calls valid.
    if is_causal and keys.shape[1] > queries.shape[1] > 1:
        # Rectangular causal case (e.g., decoding several tokens against a longer KV cache).
        q_seq_len = queries.shape[1]
        k_seq_len = keys.shape[1]
        mask = get_rectangular_causal_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype)
        if attention_mask is not None:
            apply_attention_mask_(mask, attention_mask, queries_dtype=queries.dtype)
        return (
            F.scaled_dot_product_attention(
                queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask
            )
            .transpose(1, 2)
            .contiguous()
        )
    else:
        if attention_mask is None:
            bias = None
            # With a single query there is nothing to mask causally.
            if queries.shape[1] == 1:
                is_causal = False
        else:
            if not is_causal:
                raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.")
            batch, q_seq_len, heads, _ = queries.shape
            k_seq_len = keys.shape[1]
            bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
            apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype)
            # Causality is already encoded in the bias, so disable is_causal to avoid masking twice.
            is_causal = False
        return (
            F.scaled_dot_product_attention(
                queries.transpose(1, 2),
                keys.transpose(1, 2),
                values.transpose(1, 2),
                attn_mask=bias,
                is_causal=is_causal,
            )
            .transpose(1, 2)
            .contiguous()
        )
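# Usage sketch for torch_attn in the decode-with-cache case (q_seq_len < k_seq_len), which takes the rectangular
# causal mask branch above; the shapes are illustrative and the call runs on CPU:
#
#     q = torch.randn(1, 3, 4, 32)   # 3 new query tokens
#     k = torch.randn(1, 10, 4, 32)  # 10 cached key tokens
#     v = torch.randn(1, 10, 4, 32)
#     out = torch_attn(q, k, v, is_causal=True)  # -> (1, 3, 4, 32)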


ATTN_ACTIVATIONS = {
    "relu": F.relu,
    "relu_squared": lambda x: torch.pow(F.relu(x), 2),
    "softplus": F.softplus,
    "identity": lambda x: x,
    "relu6": F.relu6,
    "sigmoid": F.sigmoid,
    "softmax": partial(F.softmax, dim=-1),
}


# Each entry maps the key sequence length to the normalizer used in custom_attn (raised to the power -alpha).
ATTN_SEQ_SCALARS = {
    "max": lambda x: x,
    "avg": lambda x: (x - 1) / 2 + 1,
    "none": lambda _: 1,
}


def custom_attn(
    queries,
    keys,
    values,
    attn_activation,
    attn_seq_scalar,
    alpha,
    is_causal=False,
    attention_mask=None,
) -> torch.Tensor:
    # Attention with a configurable activation in place of softmax and a sequence-length-dependent rescaling.
    if attention_mask is not None:
        raise NotImplementedError("attention_mask not yet implemented for custom_attn.")

    batch, q_seq_len, heads, embed_dim = queries.shape
    _, k_seq_len, _, _ = keys.shape

    # Additive causal bias (zero when not causal, or when there is only a single query).
    attn_bias = torch.zeros(batch, heads, q_seq_len, k_seq_len, device=queries.device, dtype=queries.dtype)
    if is_causal and queries.shape[1] > 1:
        attn_bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)

    # Scaled dot-product logits: Q K^T / sqrt(d), plus the causal bias.
    inner_scale = embed_dim**-0.5
    attn_weight = torch.einsum("bqhd,bkhd->bhqk", inner_scale * queries, keys)
    attn_weight += attn_bias

    # Apply the chosen activation, then rescale by the sequence-length scalar raised to -alpha.
    outer_scale = ATTN_SEQ_SCALARS[attn_seq_scalar](k_seq_len) ** -alpha
    attn_weight = outer_scale * ATTN_ACTIVATIONS[attn_activation](attn_weight)

    return torch.einsum("bhqk,bkhd->bqhd", attn_weight, values)
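# In effect, custom_attn computes out = s(k_seq_len) ** (-alpha) * act(Q @ K^T / sqrt(d) + causal_bias) @ V, with
# `act` taken from ATTN_ACTIVATIONS and `s` from ATTN_SEQ_SCALARS. A minimal CPU sketch (illustrative shapes):
#
#     q = torch.randn(2, 5, 4, 32)
#     k = torch.randn(2, 5, 4, 32)
#     v = torch.randn(2, 5, 4, 32)
#     out = custom_attn(q, k, v, attn_activation="relu", attn_seq_scalar="none", alpha=1.0, is_causal=True)
#     # -> (2, 5, 4, 32), same [batch, seq_len, heads, head_dim] layout as the inputs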


def get_attn_func(
    attn_name,
    attn_activation=None,
    attn_seq_scalar=None,
    alpha=None,
):
    if attn_name == "auto":
        return xformers_attn if torch.cuda.is_available() else torch_attn
    elif attn_name == "xformers_attn":
        return xformers_attn
    elif attn_name == "xformers_attn_variable_length":
        # When sequence lengths vary, the xformers output may not be contiguous, which breaks later .view() calls
        # that collapse the last two dimensions, so force a contiguous output here.
        return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous()
    elif attn_name == "torch_attn":
        return torch_attn
    elif attn_name == "custom_attn":
        assert (
            attn_activation is not None and attn_seq_scalar is not None and alpha is not None
        ), "must provide attn-activation, attn-seq-scalar, attn-seq-scalar-alpha"
        # Bind by keyword so that queries/keys/values remain the positional arguments of the returned function.
        return partial(
            custom_attn,
            attn_activation=attn_activation,
            attn_seq_scalar=attn_seq_scalar,
            alpha=alpha,
        )
    else:
        raise ValueError(f"Unsupported attn-name: {attn_name}")
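# Usage sketch for get_attn_func (illustrative; "torch_attn" runs on CPU, whereas the xformers variants typically
# require a CUDA device):
#
#     attn_fn = get_attn_func("torch_attn")
#     x = torch.randn(2, 5, 4, 32)
#     out = attn_fn(x, x, x, is_causal=True)  # causal self-attention over a 5-token sequence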