Upload attention.py
Browse files- attention.py +388 -0
attention.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Attention layers."""
|
| 2 |
+
import math
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import Any, Optional
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import transformers
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from packaging import version
|
| 10 |
+
from torch import nn
|
| 11 |
+
from .fc import FC_CLASS_REGISTRY
|
| 12 |
+
from .norm import NORM_CLASS_REGISTRY
|
| 13 |
+
|
| 14 |
+
def is_flash_v2_installed(v2_version: str='2.0.0'):
|
| 15 |
+
assert version.parse(v2_version) >= version.parse('2.0.0')
|
| 16 |
+
try:
|
| 17 |
+
import flash_attn as flash_attn
|
| 18 |
+
except:
|
| 19 |
+
return False
|
| 20 |
+
return version.parse(flash_attn.__version__) >= version.parse(v2_version)
|
| 21 |
+
|
| 22 |
+
def is_flash_v1_installed():
|
| 23 |
+
try:
|
| 24 |
+
import flash_attn as flash_attn
|
| 25 |
+
except:
|
| 26 |
+
return False
|
| 27 |
+
return version.parse(flash_attn.__version__) < version.parse('2.0.0')
|
| 28 |
+
|
| 29 |
+
def is_transformers_version_gte(hf_version: str) -> bool:
|
| 30 |
+
return version.parse(transformers.__version__) >= version.parse(hf_version)
|
| 31 |
+
|
| 32 |
+
def check_alibi_support(attention_impl: str) -> bool:
|
| 33 |
+
return attention_impl != 'flash' or is_flash_v2_installed(v2_version='v2.4.2')
|
| 34 |
+
if is_flash_v1_installed():
|
| 35 |
+
import transformers
|
| 36 |
+
transformers.utils.is_flash_attn_available = lambda : False
|
| 37 |
+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
| 38 |
+
|
| 39 |
+
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool:
|
| 40 |
+
if original_is_causal and num_query_tokens != num_key_tokens:
|
| 41 |
+
if num_query_tokens != 1:
|
| 42 |
+
raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
|
| 43 |
+
else:
|
| 44 |
+
return False
|
| 45 |
+
return original_is_causal
|
| 46 |
+
|
| 47 |
+
def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 48 |
+
"""Perform repeat of kv heads along a particular dimension.
|
| 49 |
+
|
| 50 |
+
hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
|
| 51 |
+
n_rep: amount of repetitions of kv_n_heads
|
| 52 |
+
Unlike torch.repeat_interleave, this function avoids allocating new memory.
|
| 53 |
+
"""
|
| 54 |
+
if n_rep == 1:
|
| 55 |
+
return hidden
|
| 56 |
+
(b, s, kv_n_heads, d) = hidden.shape
|
| 57 |
+
hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
|
| 58 |
+
return hidden.reshape(b, s, kv_n_heads * n_rep, d)
|
| 59 |
+
|
| 60 |
+
def scaled_multihead_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: int, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
| 61 |
+
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
|
| 62 |
+
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
|
| 63 |
+
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
|
| 64 |
+
if past_key_value is not None:
|
| 65 |
+
if len(past_key_value) != 0:
|
| 66 |
+
k = torch.cat([past_key_value[0], k], dim=3)
|
| 67 |
+
v = torch.cat([past_key_value[1], v], dim=2)
|
| 68 |
+
past_key_value = (k, v)
|
| 69 |
+
(b, _, s_q, d) = q.shape
|
| 70 |
+
s_k = k.size(-1)
|
| 71 |
+
if kv_n_heads > 1 and kv_n_heads < n_heads:
|
| 72 |
+
k = repeat_kv_for_gqa(k.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
|
| 73 |
+
v = repeat_kv_for_gqa(v.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
|
| 74 |
+
if softmax_scale is None:
|
| 75 |
+
softmax_scale = 1 / math.sqrt(d)
|
| 76 |
+
attn_weight = q.matmul(k) * softmax_scale
|
| 77 |
+
if attn_bias is not None:
|
| 78 |
+
_s_q = max(0, attn_bias.size(2) - s_q)
|
| 79 |
+
_s_k = max(0, attn_bias.size(3) - s_k)
|
| 80 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
| 81 |
+
if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
|
| 82 |
+
raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
|
| 83 |
+
attn_weight = attn_weight + attn_bias
|
| 84 |
+
min_val = torch.finfo(q.dtype).min
|
| 85 |
+
if key_padding_mask is not None:
|
| 86 |
+
if attn_bias is not None:
|
| 87 |
+
warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
|
| 88 |
+
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
| 89 |
+
if is_causal and (not q.size(2) == 1):
|
| 90 |
+
s = max(s_q, s_k)
|
| 91 |
+
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
|
| 92 |
+
causal_mask = causal_mask.tril()
|
| 93 |
+
causal_mask = causal_mask.to(torch.bool)
|
| 94 |
+
causal_mask = ~causal_mask
|
| 95 |
+
causal_mask = causal_mask[-s_q:, -s_k:]
|
| 96 |
+
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
|
| 97 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
| 98 |
+
if dropout_p:
|
| 99 |
+
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
|
| 100 |
+
out = attn_weight.to(v.dtype).matmul(v)
|
| 101 |
+
out = rearrange(out, 'b h s d -> b s (h d)')
|
| 102 |
+
if needs_weights:
|
| 103 |
+
return (out, attn_weight, past_key_value)
|
| 104 |
+
return (out, None, past_key_value)
|
| 105 |
+
|
| 106 |
+
def check_valid_inputs(*tensors: torch.Tensor, valid_dtypes: Optional[list[torch.dtype]]=None):
|
| 107 |
+
if valid_dtypes is None:
|
| 108 |
+
valid_dtypes = [torch.float16, torch.bfloat16]
|
| 109 |
+
for tensor in tensors:
|
| 110 |
+
if tensor.dtype not in valid_dtypes:
|
| 111 |
+
raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
|
| 112 |
+
if not tensor.is_cuda:
|
| 113 |
+
raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
|
| 114 |
+
|
| 115 |
+
def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: int, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False, should_repeat_kv_for_gqa: Optional[bool]=True, sliding_window_size: int=-1, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
| 116 |
+
if key_padding_mask is not None:
|
| 117 |
+
raise ValueError('key_padding_mask should be None for flash attn.')
|
| 118 |
+
del key_padding_mask
|
| 119 |
+
if flash_attn_padding_info is None:
|
| 120 |
+
raise ValueError('flash_attn_padding_info is required for flash attn.')
|
| 121 |
+
try:
|
| 122 |
+
from flash_attn import bert_padding, flash_attn_interface
|
| 123 |
+
except:
|
| 124 |
+
raise RuntimeError('Please install flash-attn==1.0.9 or flash-attn==2.3.6')
|
| 125 |
+
check_valid_inputs(query, key, value)
|
| 126 |
+
if past_key_value is not None:
|
| 127 |
+
if len(past_key_value) != 0:
|
| 128 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
| 129 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
| 130 |
+
past_key_value = (key, value)
|
| 131 |
+
if attn_bias is not None:
|
| 132 |
+
raise NotImplementedError(f'attn_bias not implemented for flash attn.')
|
| 133 |
+
(batch_size, seqlen) = query.shape[:2]
|
| 134 |
+
indices_q = flash_attn_padding_info['indices_q']
|
| 135 |
+
indices_k = flash_attn_padding_info['indices_k']
|
| 136 |
+
indices_v = flash_attn_padding_info['indices_v']
|
| 137 |
+
cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q']
|
| 138 |
+
cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k']
|
| 139 |
+
max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
|
| 140 |
+
max_seqlen_k = flash_attn_padding_info['max_seqlen_k']
|
| 141 |
+
query_unpad = bert_padding.index_first_axis(rearrange(query, 'b s ... -> (b s) ...'), indices_q)
|
| 142 |
+
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
|
| 143 |
+
key_unpad = bert_padding.index_first_axis(rearrange(key, 'b s ... -> (b s) ...'), indices_k)
|
| 144 |
+
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
|
| 145 |
+
value_unpad = bert_padding.index_first_axis(rearrange(value, 'b s ... -> (b s) ...'), indices_v)
|
| 146 |
+
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
|
| 147 |
+
if kv_n_heads < n_heads and (not is_flash_v2_installed()) and (not should_repeat_kv_for_gqa):
|
| 148 |
+
raise ValueError('For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.')
|
| 149 |
+
if should_repeat_kv_for_gqa:
|
| 150 |
+
if kv_n_heads == 1:
|
| 151 |
+
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
|
| 152 |
+
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
|
| 153 |
+
elif kv_n_heads < n_heads:
|
| 154 |
+
key_unpad = repeat_kv_for_gqa(key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1), n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
|
| 155 |
+
value_unpad = repeat_kv_for_gqa(value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1), n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)
|
| 156 |
+
dropout_p = dropout_p if training else 0.0
|
| 157 |
+
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
| 158 |
+
if is_flash_v1_installed():
|
| 159 |
+
output_unpad = flash_attn_interface.flash_attn_unpadded_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
|
| 160 |
+
elif is_flash_v2_installed():
|
| 161 |
+
alibi_kwargs = {}
|
| 162 |
+
if check_alibi_support('flash'):
|
| 163 |
+
alibi_kwargs = {'alibi_slopes': alibi_slopes}
|
| 164 |
+
elif alibi_slopes is not None:
|
| 165 |
+
raise ValueError('alibi_slopes is only supported for flash-attn>=2.4.2')
|
| 166 |
+
output_unpad = flash_attn_interface.flash_attn_varlen_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights, window_size=(sliding_window_size, sliding_window_size), **alibi_kwargs)
|
| 167 |
+
else:
|
| 168 |
+
raise RuntimeError('flash-attn==1.0.9 or flash-attn==2.4.2 is required.')
|
| 169 |
+
output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
|
| 170 |
+
return (output, None, past_key_value)
|
| 171 |
+
|
| 172 |
+
def triton_flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: int, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
| 173 |
+
try:
|
| 174 |
+
from .flash_attn_triton import flash_attn_func
|
| 175 |
+
except:
|
| 176 |
+
_installed = False
|
| 177 |
+
if version.parse(torch.__version__) < version.parse('2.0.0'):
|
| 178 |
+
_installed = True
|
| 179 |
+
try:
|
| 180 |
+
from flash_attn.flash_attn_triton import flash_attn_func
|
| 181 |
+
except:
|
| 182 |
+
_installed = False
|
| 183 |
+
if not _installed:
|
| 184 |
+
raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU ' + 'and `pip install .[gpu]` if installing from llm-foundry source or ' + '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` ' + 'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). ' + 'Note: (1) requires you have CMake and PyTorch already installed.')
|
| 185 |
+
check_valid_inputs(query, key, value)
|
| 186 |
+
if past_key_value is not None:
|
| 187 |
+
if len(past_key_value) != 0:
|
| 188 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
| 189 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
| 190 |
+
past_key_value = (key, value)
|
| 191 |
+
if attn_bias is not None:
|
| 192 |
+
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
| 193 |
+
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
| 194 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
| 195 |
+
if dropout_p:
|
| 196 |
+
raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
|
| 197 |
+
dropout_p = dropout_p if training else 0.0
|
| 198 |
+
if needs_weights:
|
| 199 |
+
raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
|
| 200 |
+
if key_padding_mask is not None:
|
| 201 |
+
warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
|
| 202 |
+
(b_size, s_k) = key_padding_mask.shape[:2]
|
| 203 |
+
if attn_bias is None:
|
| 204 |
+
attn_bias = query.new_zeros(b_size, 1, 1, s_k)
|
| 205 |
+
attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
|
| 206 |
+
query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
|
| 207 |
+
key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
|
| 208 |
+
value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)
|
| 209 |
+
if kv_n_heads == 1:
|
| 210 |
+
key = key.repeat(1, 1, n_heads, 1)
|
| 211 |
+
value = value.repeat(1, 1, n_heads, 1)
|
| 212 |
+
elif kv_n_heads < n_heads:
|
| 213 |
+
key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
|
| 214 |
+
value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
|
| 215 |
+
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
| 216 |
+
attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
|
| 217 |
+
output = attn_output.view(*attn_output.shape[:2], -1)
|
| 218 |
+
return (output, None, past_key_value)
|
| 219 |
+
|
| 220 |
+
class GroupedQueryAttention(nn.Module):
|
| 221 |
+
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
|
| 222 |
+
|
| 223 |
+
and Multi-query attention (MQA).
|
| 224 |
+
|
| 225 |
+
This allows the user to set a variable of number of kv_n_heads, rather than
|
| 226 |
+
just n_heads or 1, as in MHA and MQA. Using torch or triton attention
|
| 227 |
+
implementation enables user to also use additive bias.
|
| 228 |
+
"""
|
| 229 |
+
|
| 230 |
+
def __init__(self, d_model: int, n_heads: int, kv_n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
|
| 231 |
+
super().__init__()
|
| 232 |
+
self.attn_impl = attn_impl
|
| 233 |
+
self.clip_qkv = clip_qkv
|
| 234 |
+
self.qk_ln = qk_ln
|
| 235 |
+
self.qk_gn = qk_gn
|
| 236 |
+
self.d_model = d_model
|
| 237 |
+
self.n_heads = n_heads
|
| 238 |
+
self.kv_n_heads = kv_n_heads
|
| 239 |
+
self.sliding_window_size = sliding_window_size
|
| 240 |
+
self.head_dim = d_model // n_heads
|
| 241 |
+
if self.kv_n_heads <= 0:
|
| 242 |
+
raise ValueError('kv_n_heads should be greater than zero.')
|
| 243 |
+
if self.kv_n_heads > self.n_heads:
|
| 244 |
+
raise ValueError('The number of KV heads should be less than or equal to Q heads.')
|
| 245 |
+
if self.n_heads % self.kv_n_heads != 0:
|
| 246 |
+
raise ValueError('Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.')
|
| 247 |
+
if qk_ln and qk_gn:
|
| 248 |
+
raise ValueError('Only one of qk_ln and qk_gn can be set to True.')
|
| 249 |
+
self.softmax_scale = softmax_scale
|
| 250 |
+
if self.softmax_scale is None:
|
| 251 |
+
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
|
| 252 |
+
self.attn_dropout_p = attn_pdrop
|
| 253 |
+
fc_kwargs: dict[str, Any] = {'bias': bias}
|
| 254 |
+
if fc_type != 'te':
|
| 255 |
+
fc_kwargs['device'] = device
|
| 256 |
+
self.Wqkv = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs)
|
| 257 |
+
fuse_splits = [i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)]
|
| 258 |
+
self.Wqkv._fused = (0, fuse_splits)
|
| 259 |
+
if self.qk_ln or self.qk_gn:
|
| 260 |
+
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
| 261 |
+
norm_size = self.head_dim if qk_gn else d_model
|
| 262 |
+
self.q_ln = norm_class(norm_size, device=device)
|
| 263 |
+
if qk_ln:
|
| 264 |
+
norm_size = self.head_dim * kv_n_heads
|
| 265 |
+
self.k_ln = norm_class(norm_size, device=device)
|
| 266 |
+
if self.attn_impl == 'flash':
|
| 267 |
+
self.attn_fn = flash_attn_fn
|
| 268 |
+
elif self.attn_impl == 'triton':
|
| 269 |
+
self.attn_fn = triton_flash_attn_fn
|
| 270 |
+
elif self.attn_impl == 'torch':
|
| 271 |
+
self.attn_fn = scaled_multihead_dot_product_attention
|
| 272 |
+
else:
|
| 273 |
+
raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
|
| 274 |
+
self.out_proj = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model, **fc_kwargs)
|
| 275 |
+
self.out_proj._is_residual = True
|
| 276 |
+
|
| 277 |
+
def forward(self, x: torch.Tensor, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, rotary_emb_w_meta_info: Optional[dict]=None, is_causal: bool=True, needs_weights: bool=False, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
| 278 |
+
qkv = self.Wqkv(x)
|
| 279 |
+
if self.clip_qkv:
|
| 280 |
+
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
| 281 |
+
(query, key, value) = qkv.split([self.d_model, self.kv_n_heads * self.head_dim, self.kv_n_heads * self.head_dim], dim=2)
|
| 282 |
+
key_padding_mask = attention_mask
|
| 283 |
+
if self.qk_ln or self.qk_gn:
|
| 284 |
+
(q_shape, k_shape) = (query.shape, key.shape)
|
| 285 |
+
if self.qk_gn:
|
| 286 |
+
(b, s) = query.shape[:2]
|
| 287 |
+
query = query.view(b, s, self.n_heads, -1)
|
| 288 |
+
key = key.view(b, s, self.kv_n_heads, -1)
|
| 289 |
+
dtype = query.dtype
|
| 290 |
+
query = self.q_ln(query).to(dtype).view(q_shape)
|
| 291 |
+
key = self.k_ln(key).to(dtype).view(k_shape)
|
| 292 |
+
if rotary_emb_w_meta_info is not None:
|
| 293 |
+
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
|
| 294 |
+
seq_len = rotary_emb_w_meta_info['seq_len']
|
| 295 |
+
offset_info = rotary_emb_w_meta_info['offset_info']
|
| 296 |
+
(bsz, seqlen) = query.shape[:2]
|
| 297 |
+
query = query.view(bsz, seqlen, -1, self.head_dim)
|
| 298 |
+
key = key.view(bsz, seqlen, -1, self.head_dim)
|
| 299 |
+
if rotary_emb_w_meta_info['impl'] == 'dail':
|
| 300 |
+
value = value.view(bsz, seqlen, -1, self.head_dim)
|
| 301 |
+
kv = torch.stack([key, value], dim=2)
|
| 302 |
+
(query, kv) = rotary_emb(query, kv, seqlen_offset=offset_info, max_seqlen=seq_len)
|
| 303 |
+
[key, value] = torch.unbind(kv, dim=2)
|
| 304 |
+
value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
|
| 305 |
+
elif rotary_emb_w_meta_info['impl'] == 'hf':
|
| 306 |
+
(cos, sin) = rotary_emb(value, seq_len)
|
| 307 |
+
if is_transformers_version_gte('4.36'):
|
| 308 |
+
(query, key) = apply_rotary_pos_emb(query, key, cos, sin, offset_info, unsqueeze_dim=2)
|
| 309 |
+
else:
|
| 310 |
+
query = query.transpose(1, 2)
|
| 311 |
+
key = key.transpose(1, 2)
|
| 312 |
+
(query, key) = apply_rotary_pos_emb(query, key, cos, sin, offset_info)
|
| 313 |
+
query = query.transpose(1, 2)
|
| 314 |
+
key = key.transpose(1, 2)
|
| 315 |
+
query = query.view(bsz, seqlen, self.d_model)
|
| 316 |
+
key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
|
| 317 |
+
extra_attn_kwargs = {}
|
| 318 |
+
if self.attn_impl == 'flash':
|
| 319 |
+
key_padding_mask = None
|
| 320 |
+
extra_attn_kwargs = {'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, 'flash_attn_padding_info': flash_attn_padding_info}
|
| 321 |
+
(context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, self.kv_n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, **extra_attn_kwargs)
|
| 322 |
+
return (self.out_proj(context), attn_weights, past_key_value)
|
| 323 |
+
|
| 324 |
+
class MultiheadAttention(GroupedQueryAttention):
|
| 325 |
+
"""Multi-head self attention.
|
| 326 |
+
|
| 327 |
+
Using torch or triton attention implementation enables user to also use
|
| 328 |
+
additive bias.
|
| 329 |
+
"""
|
| 330 |
+
|
| 331 |
+
def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
|
| 332 |
+
super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
|
| 333 |
+
|
| 334 |
+
class MultiQueryAttention(GroupedQueryAttention):
|
| 335 |
+
"""Multi-Query self attention.
|
| 336 |
+
|
| 337 |
+
Using torch or triton attention implementation enables user to also use
|
| 338 |
+
additive bias.
|
| 339 |
+
"""
|
| 340 |
+
|
| 341 |
+
def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
|
| 342 |
+
super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
|
| 343 |
+
|
| 344 |
+
def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
|
| 345 |
+
if attn_impl == 'flash':
|
| 346 |
+
return None
|
| 347 |
+
elif attn_impl in ['torch', 'triton']:
|
| 348 |
+
if alibi:
|
| 349 |
+
if (prefix_lm or not causal) or use_sequence_id:
|
| 350 |
+
return (1, n_heads, seq_len, seq_len)
|
| 351 |
+
return (1, n_heads, 1, seq_len)
|
| 352 |
+
elif prefix_lm or use_sequence_id:
|
| 353 |
+
return (1, 1, seq_len, seq_len)
|
| 354 |
+
return None
|
| 355 |
+
else:
|
| 356 |
+
raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
|
| 357 |
+
|
| 358 |
+
def build_attn_bias(attn_impl: str, attn_bias: torch.Tensor, n_heads: int, seq_len: int, causal: bool=False, alibi: bool=False, alibi_bias_max: int=8) -> Optional[torch.Tensor]:
|
| 359 |
+
if attn_impl == 'flash':
|
| 360 |
+
return None
|
| 361 |
+
elif attn_impl in ['torch', 'triton']:
|
| 362 |
+
if alibi:
|
| 363 |
+
(device, dtype) = (attn_bias.device, attn_bias.dtype)
|
| 364 |
+
attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
|
| 365 |
+
return attn_bias
|
| 366 |
+
else:
|
| 367 |
+
raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
|
| 368 |
+
|
| 369 |
+
def gen_slopes(n_heads: int, alibi_bias_max: int=8, device: Optional[torch.device]=None, return_1d: bool=False) -> torch.Tensor:
|
| 370 |
+
_n_heads = 2 ** math.ceil(math.log2(n_heads))
|
| 371 |
+
m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
|
| 372 |
+
m = m.mul(alibi_bias_max / _n_heads)
|
| 373 |
+
slopes = 1.0 / torch.pow(2, m)
|
| 374 |
+
if _n_heads != n_heads:
|
| 375 |
+
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
|
| 376 |
+
if return_1d:
|
| 377 |
+
return slopes
|
| 378 |
+
return slopes.view(1, n_heads, 1, 1)
|
| 379 |
+
|
| 380 |
+
def build_alibi_bias(n_heads: int, seq_len: int, full: bool=False, alibi_bias_max: int=8, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None) -> torch.Tensor:
|
| 381 |
+
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
|
| 382 |
+
if full:
|
| 383 |
+
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
|
| 384 |
+
alibi_bias = alibi_bias.abs().mul(-1)
|
| 385 |
+
slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
|
| 386 |
+
alibi_bias = alibi_bias * slopes
|
| 387 |
+
return alibi_bias.to(dtype=dtype)
|
| 388 |
+
ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention, 'grouped_query_attention': GroupedQueryAttention}
|