import math
import warnings
from collections.abc import Callable
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from transformers import initialization as init
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
    GenericForSequenceClassification,
    GenericForTokenClassification,
    GradientCheckpointingLayer,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import logging
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
    DeepseekV3Attention,
    DeepseekV3DecoderLayer,
    DeepseekV3ForCausalLM,
    DeepseekV3MLP,
    DeepseekV3Model,
    DeepseekV3MoE,
    DeepseekV3PreTrainedModel,
    DeepseekV3RMSNorm,
    DeepseekV3RotaryEmbedding,
    apply_rotary_pos_emb_interleave,
    yarn_get_mscale,
)
from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from configuration_deepseek_v32 import DeepseekV32Config
| |
|
| |
|
# Module-level logger obtained through transformers' logging wrapper, keyed by this module's name.
logger = logging.get_logger(name=__name__)
| |
|
| |
|
class DeepseekV32RMSNorm(DeepseekV3RMSNorm):
    """RMS normalisation for DeepSeek V3.2; adds nothing on top of ``DeepseekV3RMSNorm``."""
| |
|
| |
|
class DeepseekV32RotaryEmbedding(DeepseekV3RotaryEmbedding):
    """Rotary position embedding for DeepSeek V3.2; inherits ``DeepseekV3RotaryEmbedding`` unchanged."""
| |
|
| |
|
class DeepseekV32MLP(DeepseekV3MLP):
    """Dense feed-forward block for DeepSeek V3.2; inherits ``DeepseekV3MLP`` unchanged."""
| |
|
| |
|
class DeepseekV32MoE(DeepseekV3MoE):
    """Mixture-of-experts feed-forward block for DeepSeek V3.2; inherits ``DeepseekV3MoE`` unchanged."""
| |
|
| |
|
class DeepseekV32SparseAttention(nn.Module):
    """
    DeepSeek V3.2 attention layer: multi-head latent attention plus the token-indexer weights.

    The layer creates the parameters of the V3.2 token indexer (``wq_b``, ``wk``,
    ``k_norm``, ``weights_proj``) so that V3.2 checkpoints can be loaded. However, the
    top-k token selection those weights drive is NOT implemented here. Every call runs
    dense attention over all cached and new tokens (see ``_standard_attention``).

    Selecting the top ``index_topk`` keys out of at most ``index_topk`` keys keeps every
    key, so dense attention is exact in that regime. Beyond it, outputs may differ from
    the reference sparse model, and ``forward`` emits a ``UserWarning``.
    """

    def __init__(self, config: DeepseekV32Config, layer_idx: int):
        super().__init__()
        self.config = config
        # Index of this layer; used to address its slot in the KV cache.
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads

        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim
        # Number of keys the (unimplemented) sparse selection would keep per query.
        self.index_topk = config.index_topk

        self.is_causal = True

        # Query projection: one dense projection, or a low-rank factorisation with a norm in between.
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        else:
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = DeepseekV32RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # Keys/values: one projection yields a compressed latent (kv_lora_rank) plus a
        # rotary key part (qk_rope_head_dim); the latent is then expanded per head.
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV32RMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        # Indexer parameters. They exist so checkpoints load; `forward` does not use them.
        # `wq_b` takes the low-rank query width. `nn.Linear` needs an integer width, so
        # fall back to `hidden_size` when the model has no low-rank query (q_lora_rank is None).
        index_q_in_features = self.q_lora_rank if self.q_lora_rank is not None else config.hidden_size
        self.wq_b = nn.Linear(index_q_in_features, self.num_heads * self.qk_head_dim, bias=False)
        self.wk = nn.Linear(config.hidden_size, self.qk_head_dim, bias=config.attention_bias)
        self.k_norm = DeepseekV32RMSNorm(self.qk_head_dim)
        self.weights_proj = nn.Linear(config.hidden_size, self.num_heads, bias=False)

        # Softmax scale 1/sqrt(qk_head_dim). For a non-default rope type with `mscale_all_dim`
        # set, multiply by mscale**2 from `yarn_get_mscale`.
        self.scaling = self.qk_head_dim ** (-0.5)
        if self.config.rope_parameters.get("rope_type", "default") != "default":
            mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_parameters["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.scaling = self.scaling * mscale * mscale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run this layer's attention. Dense attention is always used.

        A warning is raised only when the number of keys attended to (tokens already in
        the cache for this layer plus the new ones) exceeds ``index_topk``. That is
        exactly the case in which top-k sparse attention would drop tokens, so results
        can diverge from the reference model.

        Returns:
            ``(attn_output, attn_weights)`` as produced by ``_standard_attention``.
        """
        _, seq_length = hidden_states.shape[:-1]

        past_length = 0
        if past_key_values is not None:
            # Tokens already cached for this layer, before this call's tokens are appended.
            past_length = int(past_key_values.get_seq_length(self.layer_idx))

        if past_length + seq_length > self.index_topk:
            warnings.warn(
                "DeepSeek V3.2 sparse attention is not fully implemented in this version. "
                "Falling back to standard attention. For production use, please use vLLM or "
                "other optimized inference engines.",
                UserWarning,
            )

        return self._standard_attention(
            hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs
        )

    def _standard_attention(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Dense multi-head latent attention, intended to match DeepSeek V3's attention.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` is projected back to
            ``hidden_size`` by ``o_proj``; ``attn_weights`` is whatever the selected
            attention backend returns.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        # Queries -> (batch, heads, seq, qk_head_dim), then split into non-rotary / rotary parts.
        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Split the compressed projection into the KV latent and the rotary key part.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # Expand the normalised latent into per-head non-rotary keys and values.
        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The rotary key has a singleton head axis: it is shared by all heads and
        # broadcast to every head after RoPE is applied.
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # NOTE(review): presumably the flash-attention-2 path needs equal q/k/v head dims,
        # so values are zero-padded to qk_head_dim here and the padding is sliced off below.
        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
| |
|
| |
|
class DeepseekV32DecoderLayer(GradientCheckpointingLayer):
    """
    One DeepSeek V3.2 decoder block: pre-norm attention and a pre-norm feed-forward,
    each wrapped in a residual connection.

    The feed-forward is a dense ``DeepseekV32MLP`` for layers with index below
    ``config.first_k_dense_replace`` and a ``DeepseekV32MoE`` from that index on.

    Deriving from transformers' ``GradientCheckpointingLayer`` (an ``nn.Module``
    subclass) lets the model's gradient-checkpointing switch apply to this layer. With
    a plain ``nn.Module`` base, enabling gradient checkpointing would not checkpoint
    these layers.
    """

    def __init__(self, config: DeepseekV32Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = DeepseekV32SparseAttention(config=config, layer_idx=layer_idx)

        if layer_idx >= config.first_k_dense_replace:
            self.mlp = DeepseekV32MoE(config)
        else:
            self.mlp = DeepseekV32MLP(config)

        self.input_layernorm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """
        Apply the block to ``hidden_states`` and return the updated hidden states.

        Extra keyword arguments are forwarded unchanged to the attention layer. The
        attention weights are discarded here; the attention module's own return value
        still carries them.
        """
        # Attention sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
| |
|
| |
|
class DeepseekV32PreTrainedModel(DeepseekV3PreTrainedModel):
    """
    Base class for DeepSeek V3.2 models: binds the V3.2 config and class-level metadata.

    Attributes that name concrete layer classes are re-pointed at the V3.2 classes
    defined in this module.
    NOTE(review): the inherited DeepseekV3 values presumably reference V3 classes
    (e.g. ``DeepseekV3DecoderLayer``), which never occur in a V3.2 module tree.
    """

    config_class = DeepseekV32Config
    _can_compile_fullgraph = False
    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
    # Keep each decoder layer on a single device when dispatching with `device_map`.
    _no_split_modules = ["DeepseekV32DecoderLayer"]
    # Modules whose outputs are captured for `output_hidden_states` / `output_attentions`.
    _can_record_outputs = {
        "hidden_states": DeepseekV32DecoderLayer,
        "attentions": DeepseekV32SparseAttention,
    }
| |
|
| |
|
class DeepseekV32Model(DeepseekV3Model):
    """
    DeepSeek V3.2 decoder stack: token embeddings, ``DeepseekV32DecoderLayer`` layers,
    a final RMS norm and the rotary embedding.

    The forward pass is inherited from ``DeepseekV3Model``. The attention layers
    currently compute dense attention; DeepSeek V3.2's top-k sparse token selection is
    not implemented (``DeepseekV32SparseAttention`` warns when that matters).
    """

    config_class = DeepseekV32Config
    # NOTE(review): layer 61 is presumably an extra checkpoint-only layer (e.g. multi-token
    # prediction) with no module here; its weights are ignored on load — confirm.
    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]
    # This class derives from DeepseekV3Model rather than DeepseekV32PreTrainedModel, so it
    # re-points the layer-class metadata at the V3.2 classes itself.
    _no_split_modules = ["DeepseekV32DecoderLayer"]
    _can_record_outputs = {
        "hidden_states": DeepseekV32DecoderLayer,
        "attentions": DeepseekV32SparseAttention,
    }

    def __init__(self, config: DeepseekV32Config):
        # Bypass DeepseekV3Model.__init__ and run the base PreTrainedModel initialisation
        # directly; all submodules are built explicitly below with V3.2 classes.
        DeepseekV3PreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [DeepseekV32DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = DeepseekV32RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Weight initialisation and final transformers setup.
        self.post_init()
| |
|
| |
|
class DeepseekV32ForCausalLM(DeepseekV3ForCausalLM):
    """
    Causal language-modelling head for DeepSeek V3.2.

    Wraps a ``DeepseekV32Model`` backbone with a bias-free linear ``lm_head`` that maps
    hidden states to vocabulary logits. Forward and generation logic are inherited from
    ``DeepseekV3ForCausalLM``.
    """

    config_class = DeepseekV32Config
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Skip DeepseekV3ForCausalLM.__init__ and run the base initialiser that follows it
        # in the MRO; the V3.2 backbone and head are created explicitly below.
        DeepseekV3PreTrainedModel.__init__(self, config)
        self.model = DeepseekV32Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=config.vocab_size, bias=False)

        # Weight initialisation and final transformers setup.
        self.post_init()
| |
|
| |
|
class DeepseekV32ForSequenceClassification(GenericForSequenceClassification, DeepseekV32PreTrainedModel):
    """Sequence-classification model for DeepSeek V3.2, built from transformers' generic implementation."""
| |
|
| |
|
class DeepseekV32ForTokenClassification(GenericForTokenClassification, DeepseekV32PreTrainedModel):
    """Token-classification model for DeepSeek V3.2, built from transformers' generic implementation."""
| |
|
| |
|
# Public model classes exported by `from <this module> import *`.
# The building-block layers (attention, decoder layer, norm, MLP/MoE) are intentionally not listed.
__all__ = [
    "DeepseekV32PreTrainedModel",
    "DeepseekV32Model",
    "DeepseekV32ForCausalLM",
    "DeepseekV32ForSequenceClassification",
    "DeepseekV32ForTokenClassification",
]