fxmeng committed
Commit 4e1ffb9 · verified · 1 parent: de60045

Upload mla.py with huggingface_hub

Files changed (1)
mla.py +149 -0
mla.py ADDED
@@ -0,0 +1,149 @@
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from transformers.cache_utils import Cache
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+ from transformers.processing_utils import Unpack
+
+ from transformers.models.gemma2.modeling_gemma2 import (
+     eager_attention_forward,  # for supporting softcap
+     logger
+ )
+ from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
+     apply_rotary_pos_emb_interleave,
+     DeepseekV3RMSNorm
+ )
+
+
+ class MLAAttention(nn.Module):
+     """
+     Modified from `transformers.models.deepseek_v3.modeling_deepseek_v3.DeepseekV3Attention`;
+     adds support for attention bias and attention-logit softcapping.
+     """
+     def __init__(self, config, layer_idx: int):
+
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+
+         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+         self.attention_dropout = config.attention_dropout
+         self.num_heads = config.num_attention_heads
+         self.rope_theta = config.rope_theta
+         self.q_lora_rank = config.q_lora_rank
+         self.kv_lora_rank = config.kv_lora_rank
+         self.qk_rope_head_dim = config.qk_rope_head_dim
+         self.qk_nope_head_dim = config.qk_nope_head_dim
+         self.v_head_dim = config.v_head_dim
+         self.qk_head_dim = config.qk_head_dim
+
+         self.qk_latent_layernorm = getattr(config, "qk_latent_layernorm", True)
+
+         self.is_causal = True
+         if self.q_lora_rank is None:
+             self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=config.attention_bias)
+         else:
+             self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=False)
+             if self.qk_latent_layernorm:
+                 self.q_a_layernorm = DeepseekV3RMSNorm(self.q_lora_rank)
+             self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=config.attention_bias)
+
+         self.kv_a_proj_with_mqa = nn.Linear(
+             config.hidden_size,
+             self.kv_lora_rank + self.qk_rope_head_dim,
+             bias=config.attention_bias,
+         )
+         if self.qk_latent_layernorm:
+             self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
+         self.kv_b_proj = nn.Linear(
+             self.kv_lora_rank,
+             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+             bias=False,
+         )
+
+         self.o_proj = nn.Linear(
+             self.num_heads * self.v_head_dim,
+             config.hidden_size,
+             bias=False,
+         )
+
+         self.scaling = self.qk_head_dim**-0.5
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+         attention_mask: Optional[torch.Tensor],
+         past_key_value: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs: Unpack[FlashAttentionKwargs],
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         batch_size, seq_length = hidden_states.shape[:-1]
+         query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
+         key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
+
+         # Queries: dense projection, or low-rank down/up projection with optional latent layernorm.
+         if self.q_lora_rank is None:
+             q_states = self.q_proj(hidden_states)
+         elif self.qk_latent_layernorm:
+             q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+         else:
+             q_states = self.q_b_proj(self.q_a_proj(hidden_states))
+         q_states = q_states.view(query_shape).transpose(1, 2)
+         q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+         # Keys/values: a compressed latent shared by all heads plus a single rotary key.
+         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+         k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
+         if self.qk_latent_layernorm:
+             k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
+         else:
+             k_pass = self.kv_b_proj(k_pass).view(key_shape).transpose(1, 2)
+         k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+         k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
+
+         # Decoupled RoPE: only the rotary slices of q/k carry positional information.
+         cos, sin = position_embeddings
+         q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
+         k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
+
+         query_states = torch.cat((q_pass, q_rot), dim=-1)
+         key_states = torch.cat((k_pass, k_rot), dim=-1)
+
+         if past_key_value is not None:
+             # sin and cos are specific to RoPE models; cache_position needed for the static cache
+             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+         # flash_attention_2 needs q/k/v to share a head dim, so pad the values up to qk_head_dim.
+         if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+             value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
+
+         attention_interface = eager_attention_forward
+         if self.config._attn_implementation != "eager":
+             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                 logger.warning_once(
+                     "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                     'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                 )
+             else:
+                 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             softcap=getattr(self.config, "attn_logit_softcapping", None),
+             **kwargs,
+         )
+         # Drop the padded value dims (if any) before the output projection.
+         if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+             attn_output = attn_output[:, :, :, : self.v_head_dim]
+         attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
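
For orientation, here is a minimal usage sketch of the module above. Everything in it is illustrative rather than part of the uploaded file: the `types.SimpleNamespace` stand-in config (its field names mirror what `MLAAttention` reads; the sizes are arbitrary), the `attn_logit_softcapping` value, the identity-like cos/sin rotary inputs, and the `from mla import MLAAttention` import are all assumptions made only to exercise the eager path and check output shapes.

import types

import torch

from mla import MLAAttention  # assumes the uploaded mla.py is on the import path

# Hypothetical stand-in config exposing only the attributes MLAAttention reads.
config = types.SimpleNamespace(
    hidden_size=256,
    num_attention_heads=8,
    num_key_value_heads=8,
    attention_dropout=0.0,
    rope_theta=10000.0,
    attention_bias=True,            # exercise the attention-bias support
    attn_logit_softcapping=50.0,    # exercise the softcap support (illustrative value)
    q_lora_rank=64,
    kv_lora_rank=64,
    qk_rope_head_dim=16,
    qk_nope_head_dim=32,
    v_head_dim=32,
    qk_head_dim=48,                 # qk_nope_head_dim + qk_rope_head_dim
    qk_latent_layernorm=True,
    _attn_implementation="eager",   # the gemma2 eager path supports softcap
)

attn = MLAAttention(config, layer_idx=0)

batch_size, seq_length = 2, 5
hidden_states = torch.randn(batch_size, seq_length, config.hidden_size)

# Dummy rotary embeddings (cos=1, sin=0) of shape (batch, seq_len, qk_rope_head_dim);
# a real model would supply the cos/sin produced by its rotary-embedding module.
cos = torch.ones(batch_size, seq_length, config.qk_rope_head_dim)
sin = torch.zeros(batch_size, seq_length, config.qk_rope_head_dim)

attn_output, attn_weights = attn(
    hidden_states,
    position_embeddings=(cos, sin),
    attention_mask=None,  # shape check only; a causal mask would be passed in practice
)
print(attn_output.shape)   # torch.Size([2, 5, 256])
print(attn_weights.shape)  # torch.Size([2, 8, 5, 5])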