fix
modeling_deepseek.py (+271 -4)
@@ -651,12 +651,280 @@ class DeepseekV3Attention(nn.Module):
         return attn_output, attn_weights, past_key_value


+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3
 class DeepseekV3FlashAttention2(DeepseekV3Attention):
     """
-
+    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
     """
-
-
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # DeepseekV3FlashAttention2 attention does not support output_attentions
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop("padding_mask")
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+        kv = (
+            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .transpose(1, 2)
+        )
+
+        k_nope, value_states = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+        kv_seq_len = value_states.shape[-2]
+
+        kv_seq_len = value_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+
+        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+
+        if self.q_head_dim != self.v_head_dim:
+            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (DeepseekV3RMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            else:
+                target_dtype = (
+                    self.q_proj.weight.dtype
+                    if self.q_lora_rank is None
+                    else self.q_a_proj.weight.dtype
+                )
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            softmax_scale=self.softmax_scale,
+        )
+        if self.q_head_dim != self.v_head_dim:
+            attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+        attn_output = attn_output.reshape(
+            bsz, q_len, self.num_heads * self.v_head_dim
+        ).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            (
+                query_states,
+                key_states,
+                value_states,
+                indices_q,
+                cu_seq_lens,
+                max_seq_lens,
+            ) = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(
+                attn_output_unpad, indices_q, batch_size, query_length
+            )
+        else:
+            attn_output = flash_attn_func(
+                query_states,
+                key_states,
+                value_states,
+                dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+        return attn_output
+
+    def _upad_input(
+        self, query_layer, key_layer, value_layer, attention_mask, query_length
+    ):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
+                indices_k,
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
+                query_layer, attention_mask
+            )
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )


 ATTENTION_CLASSES = {
@@ -664,7 +932,6 @@ ATTENTION_CLASSES = {
     "flash_attention_2": DeepseekV3FlashAttention2,
 }

-
 class DeepseekV3DecoderLayer(nn.Module):
     def __init__(self, config: DeepseekV3Config, layer_idx: int):
         super().__init__()