cpatonn committed
Commit 59eb5df · verified · 1 Parent(s): c929050

Upload folder using huggingface_hub

Files changed (1): modeling_deepseek.py (+1028, −0)
modeling_deepseek.py ADDED
@@ -0,0 +1,1028 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_deepseek_v3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
from functools import partial
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import (
    LossKwargs,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    can_return_tuple,
    is_torch_flex_attn_available,
    logging,
    replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from .configuration_deepseek import DeepseekV3Config


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from transformers.integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DeepseekV3Config"


class DeepseekV3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
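
# RMSNorm in one line: y = weight * x / sqrt(mean(x^2) + eps), computed in float32 for
# numerical stability and cast back to the input dtype; unlike LayerNorm there is no mean
# subtraction and no bias.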


class DeepseekV3RotaryEmbedding(nn.Module):
    def __init__(self, config: DeepseekV3Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
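
# For the default RoPE type, inv_freq[i] = rope_theta ** (-2i / d), with d the rotary
# dimension (in this model, the qk_rope_head_dim slice), so each channel pair rotates at its
# own frequency. cos/sin are built at full rotary width by concatenating freqs with itself,
# matching the rotate_half convention used further below.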


class DeepseekV3MLP(nn.Module):
    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
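
# Standard gated MLP (SwiGLU when hidden_act is "silu"):
# out = down_proj(act_fn(gate_proj(x)) * up_proj(x)).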


class DeepseekV3TopkRouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob

        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
        self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts)))

    @torch.no_grad()
    def get_topk_indices(self, scores):
        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
        group_scores = (
            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        return topk_indices

    def forward(self, hidden_states):
        hidden_states = hidden_states.view(-1, self.config.hidden_size)
        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
        scores = router_logits.sigmoid()
        topk_indices = self.get_topk_indices(scores)
        topk_weights = scores.gather(1, topk_indices)
        if self.norm_topk_prob:
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights
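
# Routing recap: experts are scored with a sigmoid (not a softmax) plus a bias term
# (e_score_correction_bias, DeepSeek-V3's auxiliary-loss-free load balancing). The experts are
# partitioned into n_group groups; each group is ranked by the sum of its top-2 expert scores,
# only the best topk_group groups stay eligible, and the final top_k experts are picked inside
# them. Note the bias only affects expert *selection*: the mixing weights gathered in forward()
# come from the raw sigmoid scores.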


class DeepseekV3MoE(nn.Module):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList(
            [
                DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
                for _ in range(config.n_routed_experts)
            ]
        )
        self.gate = DeepseekV3TopkRouter(config)
        self.shared_experts = DeepseekV3MLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )

    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
        r"""
        CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
        to not have to do a loop here (deepseek has 256 experts soooo yeah).
        """
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
        expert_mask = expert_mask.permute(2, 0, 1)

        for expert_idx in range(len(self.experts)):
            expert = self.experts[expert_idx]
            mask = expert_mask[expert_idx]
            token_indices, weight_indices = torch.where(mask)

            if token_indices.numel() > 0:
                expert_weights = topk_weights[token_indices, weight_indices]
                expert_input = hidden_states[token_indices]
                expert_output = expert(expert_input)
                weighted_output = expert_output * expert_weights.unsqueeze(-1)
                final_hidden_states.index_add_(0, token_indices, weighted_output)

        # in original deepseek, the output of the experts is gathered once we leave this module
        # thus the moe module is itself an IsolatedParallel module
        # and all experts are "local", meaning we shard but we don't gather
        return final_hidden_states.type(hidden_states.dtype)

    def forward(self, hidden_states):
        residuals = hidden_states
        orig_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states
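
# Each token's output is the weighted sum of its top_k routed experts plus the output of the
# always-active shared experts (a single wider MLP of width
# moe_intermediate_size * n_shared_experts applied to the pre-routing residual stream).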


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
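
# Reference (eager) path: softmax(Q K^T * scaling + additive mask) V, with the softmax taken
# in float32. The additive mask holds 0 for visible positions and the dtype's most negative
# value for masked ones, which is why it can simply be added to the logits.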


def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""
    TODO let's just use the original freqs_cis computation to not have the view
    transpose + reshape! This is not optimized!
    Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
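
# Layout note: the original DeepSeek checkpoints store rotary channels interleaved as pairs
# [x0, x1, x2, x3, ...], while rotate_half expects the half-split layout [x0, x2, ..., x1, x3, ...].
# The view(b, h, s, d // 2, 2).transpose(4, 3).reshape(...) above permutes pairs into halves,
# e.g. for d = 4: [x0, x1, x2, x3] -> [x0, x2, x1, x3], after which the standard half-rotation
# applies unchanged.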


def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
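
# YaRN magnitude scaling: m(s) = 0.1 * mscale * ln(s) + 1 for a context-extension factor s > 1.
# DeepseekV3Attention below multiplies its softmax scaling by m(s)**2 when rope_scaling sets
# mscale_all_dim, compensating for the flattening of attention over extended contexts.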


class DeepseekV3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim

        self.is_causal = True
        self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
        self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
        self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        self.scaling = self.qk_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.scaling = self.scaling * mscale * mscale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:  # support using interleaved weights for efficiency
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
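
# Architecture note: this is DeepSeek's Multi-head Latent Attention (MLA). Queries are
# down-projected to rank q_lora_rank and re-expanded; keys/values are compressed into a single
# kv_lora_rank latent, from which kv_b_proj recovers per-head no-RoPE keys and values. Only a
# small decoupled qk_rope_head_dim slice, shared across heads (hence "with_mqa"), carries RoPE,
# so per head qk_head_dim = qk_nope_head_dim + qk_rope_head_dim while values keep v_head_dim.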


class DeepseekV3DecoderLayer(nn.Module):
    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)

        if layer_idx >= config.first_k_dense_replace:
            self.mlp = DeepseekV3MoE(config)
        else:
            self.mlp = DeepseekV3MLP(config)

        self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
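
# Pre-norm residual block: x = x + Attn(RMSNorm(x)); x = x + MLP(RMSNorm(x)). The first
# first_k_dense_replace layers use a dense DeepseekV3MLP; every later layer uses the MoE.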


DEEPSEEK_V3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`DeepseekV3Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DEEPSEEK_V3_START_DOCSTRING,
)
class DeepseekV3PreTrainedModel(PreTrainedModel):
    config_class = DeepseekV3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, DeepseekV3TopkRouter):
            module.weight.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.Parameter):
            # an nn.Parameter has no .weight attribute; initialize the parameter tensor itself
            module.data.normal_(mean=0.0, std=std)


DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""


@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DEEPSEEK_V3_START_DOCSTRING,
)
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]

    Args:
        config: DeepseekV3Config
    """

    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]

    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            if isinstance(attention_mask, BlockMask):
                return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
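
# Mask construction, step by step: start from a (sequence_length, target_length) grid filled
# with the dtype minimum; torch.triu(..., diagonal=1) zeroes the diagonal and below (visible);
# the comparison against cache_position then re-opens key positions at or before each query's
# absolute position (needed when the rows are offset by a KV cache); finally, any slot where
# the visible causal value (0) plus the 0/1 padding mask sums to zero is refilled with the
# minimum, so padded tokens stay unattendable.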


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = DeepseekV3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

        >>> model = DeepseekV3ForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V3")
        >>> tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["DeepseekV3PreTrainedModel", "DeepseekV3Model", "DeepseekV3ForCausalLM"]
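
# Usage sketch: this file ships inside a Hub checkpoint as custom modeling code, so it would
# typically be loaded through the Auto classes with trust_remote_code=True. The repo id below
# is a placeholder, not a real checkpoint:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("<hub-repo-id>", trust_remote_code=True)
#     model = AutoModelForCausalLM.from_pretrained(
#         "<hub-repo-id>", torch_dtype="auto", device_map="auto", trust_remote_code=True
#     )
#     inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
#     print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))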