| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import math |
| from collections.abc import Callable |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.generation import GenerationMixin |
| from transformers.masking_utils import create_causal_mask |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs, can_return_tuple, logging |
| from transformers.utils.generic import check_model_inputs |
| from .configuration_deepseek_v32 import DeepseekV32Config |
|
|
|
|
# Module-level logger; DeepseekV32Attention uses it for a one-time fallback warning.
logger = logging.get_logger(__name__)
|
|
|
|
class DeepseekV32RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel gain (equivalent to T5LayerNorm)."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Args:
            hidden_size: size of the last (normalised) dimension; also the size of the learned gain.
            eps: added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalise in float32 for stability, cast back to the input dtype, then apply the gain.
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalised = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate query and key tensors with rotary position embeddings (half-split convention).

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos`/`sin` so they broadcast against `q`/`k`. With `cos`/`sin` shaped
            [batch_size, seq_len, head_dim], use 1 for q/k shaped [batch_size, heads, seq_len, head_dim] and 2 for
            q/k shaped [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
def rotate_half(x):
    """Map the last dim `[a, b]` to `[-b, a]`, where `a` is the first `d // 2` channels and `b` the rest."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""Apply rotary embeddings to q/k whose rotary channels are stored interleaved as (x0, x1), (x2, x3), ...

    Each tensor is first de-interleaved to the half-split layout (x0, x2, ..., x1, x3, ...) and then rotated
    exactly like `apply_rotary_pos_emb`.
    TODO: use the original freqs_cis computation to avoid the view/transpose/reshape; this is not optimized.

    Args:
        q (`torch.Tensor`): Query tensor; must be 4D, and its last dim must be even.
        k (`torch.Tensor`): Key tensor; must be 4D, and its last dim must be even.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Unused; kept for signature compatibility.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos`/`sin` so they broadcast against `q`/`k` (1 for [batch, heads, seq, dim] inputs,
            2 for [batch, seq, heads, dim] inputs).
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors (in the de-interleaved channel order).
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _deinterleave(t):
        # (..., d) viewed as (..., d/2, 2) pairs; swapping the last two axes gathers evens first, then odds.
        d0, d1, d2, d3 = t.shape
        return t.view(d0, d1, d2, d3 // 2, 2).transpose(4, 3).reshape(d0, d1, d2, d3)

    q = _deinterleave(q)
    k = _deinterleave(k)

    rotated_q = (q * cos) + (rotate_half(q) * sin)
    rotated_k = (k * cos) + (rotate_half(k) * sin)
    return rotated_q, rotated_k
|
|
|
|
def yarn_get_mscale(scale=1, mscale=1):
    """YaRN magnitude correction: 1.0 when `scale <= 1`, otherwise `1 + 0.1 * mscale * ln(scale)`."""
    return 1.0 if scale <= 1 else 1.0 + 0.1 * mscale * math.log(scale)
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times along dim 1, like `torch.repeat_interleave(x, dim=1, repeats=n_rep)`:
    (batch, num_key_value_heads, seqlen, head_dim) -> (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    The input is returned unchanged (same object) when `n_rep == 1`.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # insert a repeat axis after the head axis, broadcast it, then fold it into the head axis
    widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled dot-product attention.

    Keys/values are expanded by `module.num_key_value_groups`. An additive `attention_mask` is sliced to the key
    length and added to the logits. Softmax runs in float32 and is cast back to the query dtype; dropout is applied
    only while `module.training`.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has axes 1 and 2 swapped relative to the `weights @ value`
        product (made contiguous).
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    logits = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        logits = logits + attention_mask[:, :, :, : keys.shape[-2]]

    probs = F.softmax(logits, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = F.dropout(probs, p=dropout, training=module.training)
    output = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return output, probs
|
|
|
|
def _dsa_allowed_mask(
    attention_mask: torch.Tensor | None,
    cache_position: torch.LongTensor | None,
    query_length: int,
    key_length: int,
    device: torch.device,
) -> torch.Tensor:
    """Normalise the layer's attention mask into a boolean "may attend" mask for the sparse (DSA) path.

    Args:
        attention_mask: Mask handed to the attention layer. A 4D bool mask (True = visible) is used as-is; a 4D
            float mask is read as an eager-style additive mask (0 = visible, `finfo.min` = hidden). Anything else
            (`None`, a 2D `(batch, keys)` padding mask, or a non-tensor mask object) is replaced by a causal mask
            built from `cache_position`. A 2D padding mask is folded into that causal mask. Padding carried by a
            non-tensor mask object is NOT recovered.
        cache_position: Absolute positions of the query tokens, shape `(query_length,)`. When `None`, the queries
            are taken to be the last `query_length` of the `key_length` positions.
        query_length: Number of query tokens.
        key_length: Number of key positions (key-state length after the cache update).
        device: Device for tensors created here.

    Returns:
        Bool tensor broadcastable to `(batch, 1, query_length, key_length)`; True marks a visible key.
    """
    if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4:
        mask = attention_mask[..., :key_length]
        if mask.dtype == torch.bool:
            return mask
        # Eager masks hold 0 for visible and finfo.min for hidden keys; half of min is a safe threshold.
        return mask > torch.finfo(mask.dtype).min / 2

    if cache_position is None:
        cache_position = torch.arange(key_length - query_length, key_length, device=device)
    key_positions = torch.arange(key_length, device=device)
    allowed = (key_positions[None, :] <= cache_position.to(device)[:, None])[None, None]

    if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2:
        padding = attention_mask[:, :key_length].to(device=device, dtype=torch.long)
        missing = key_length - padding.shape[-1]
        if missing > 0:
            # Slots past the 2D mask are assumed to be unused cache slots after the last query; causality hides them.
            padding = F.pad(padding, (0, missing), value=1)
        allowed = allowed & padding.bool()[:, None, None, :]
    return allowed


class DeepseekV32Indexer(nn.Module):
    """Lightning indexer of DeepSeek V3.2 sparse attention.

    For every query token it scores every key token as
    `sum_h w[token, h] * relu(q[token, h] . k[key])`. `q` comes from the shared low-rank query latent, the key `k`
    is a single head shared by all indexer heads, and `w` is a learned per-token head weight. It returns the
    `index_topk` best-scoring key positions.

    Indexer keys of earlier tokens are kept across forward calls in a dict attached to the cache object
    (`cache._dsa_indexer_keys[layer_idx]`), because the main KV cache only stores attention keys/values.
    NOTE(review): this side store is not reordered by beam-search cache reordering; confirm before using beams.
    """

    def __init__(self, config: "DeepseekV32Config", index_layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = index_layer_idx

        self.hidden_size: int = config.hidden_size
        self.num_heads: int = config.index_n_heads
        self.num_local_heads: int = config.index_n_heads
        self.head_dim: int = config.index_head_dim
        self.qk_rope_head_dim: int = config.qk_rope_head_dim
        self.index_topk: int = config.index_topk
        self.q_lora_rank: int = config.q_lora_rank

        self.wq_b = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.hidden_size, self.head_dim, bias=False)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.weights_proj = nn.Linear(self.hidden_size, self.num_heads, dtype=torch.get_default_dtype(), bias=False)
        self.softmax_scale = self.head_dim**-0.5

    def _split_rope(self, states: torch.Tensor):
        """Split the last dim into the rotary part (first `qk_rope_head_dim` channels) and the remaining channels."""
        return torch.split(states, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1)

    def _compute_keys(self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor]):
        """Single-head indexer keys for the given tokens, shape `(batch, 1, seq, head_dim)`, with RoPE applied."""
        bsz, seq_len, _ = hidden_states.shape
        cos, sin = position_embeddings
        keys = self.k_norm(self.wk(hidden_states)).view(bsz, 1, seq_len, self.head_dim)
        k_rot, k_pass = self._split_rope(keys)
        # NOTE(review): interleaved RoPE is kept from the original code; confirm against the reference indexer.
        k_rot, _ = apply_rotary_pos_emb_interleave(k_rot, k_rot, cos, sin)
        return torch.cat([k_rot, k_pass], dim=-1)

    @torch.no_grad()
    def cache_keys(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        cache: "Cache | None",
        past_len: int | None = None,
    ) -> torch.Tensor:
        """Compute indexer keys for the new tokens and append them to the keys stored for earlier tokens.

        Args:
            cache: Cache object to persist keys on; `None` means no persistence (only the new keys are returned).
            past_len: Tokens already in the main KV cache for this layer *before* this step. Stored keys are trimmed
                to this length (e.g. after a cache crop); `0` starts fresh; `None` appends to whatever is stored.

        Returns:
            `(batch, 1, total_len, head_dim)` keys for all cached plus new tokens.

        Raises:
            ValueError: if fewer indexer keys are stored than `past_len` (the caches are out of sync).
        """
        new_keys = self._compute_keys(hidden_states, position_embeddings)
        if cache is None:
            return new_keys

        store = getattr(cache, "_dsa_indexer_keys", None)
        if store is None:
            store = {}
            cache._dsa_indexer_keys = store
        previous = store.get(self.layer_idx)

        if past_len is not None:
            if past_len == 0:
                previous = None
            elif previous is None or previous.shape[-2] < past_len:
                raise ValueError(
                    f"DeepSeek V3.2 indexer key cache for layer {self.layer_idx} holds fewer tokens than the KV "
                    f"cache ({0 if previous is None else previous.shape[-2]} < {past_len})."
                )
            else:
                previous = previous[..., :past_len, :]

        keys = new_keys if previous is None else torch.cat([previous, new_keys], dim=-2)
        store[self.layer_idx] = keys
        return keys

    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor,
        q_resid: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values_index: "Cache",
        cache_position: torch.LongTensor | None,
        past_len: int | None = None,
    ) -> torch.LongTensor:
        """Pick, for every query token, the key positions the sparse attention may attend to.

        Args:
            hidden_states: `(batch, seq, hidden_size)` layer input.
            q_resid: `(batch, seq, q_lora_rank)` normalised low-rank query latent shared with the main attention.
            position_embeddings: `(cos, sin)` from the rotary embedding.
            attention_mask: Optional bool "may attend" mask (True = visible) or additive float mask, laid out as
                `(batch, 1, seq, keys)` (assumed); hidden keys cannot win a top-k slot over visible ones.
            past_key_values_index: Cache used to persist indexer keys across steps (may be `None`).
            cache_position: Accepted for interface compatibility; not used here.
            past_len: Forwarded to `cache_keys`.

        Returns:
            `(batch, seq, min(index_topk, keys))` LongTensor of selected key positions.
        """
        bsz, seq_len, _ = hidden_states.shape
        cos, sin = position_embeddings

        queries = self.wq_b(q_resid).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        q_rot, q_pass = self._split_rope(queries)
        q_rot, _ = apply_rotary_pos_emb_interleave(q_rot, q_rot, cos, sin)
        queries = torch.cat([q_rot, q_pass], dim=-1)  # (batch, heads, seq, head_dim)

        keys = self.cache_keys(hidden_states, position_embeddings, past_key_values_index, past_len)
        key_len = keys.shape[-2]

        # relu(q . k) per indexer head; the single key head broadcasts over all query heads.
        logits = torch.matmul(queries.float(), keys.float().transpose(-1, -2)).clamp_min_(0)
        head_weights = self.weights_proj(hidden_states).float() * (self.num_heads**-0.5 * self.softmax_scale)
        # (batch, seq, heads) -> (batch, heads, seq, 1), then weighted sum over heads -> (batch, seq, keys)
        index_scores = (logits * head_weights.transpose(1, 2).unsqueeze(-1)).sum(dim=1)

        if attention_mask is not None:
            mask = attention_mask[..., :key_len]
            if mask.dim() == 4:
                mask = mask[:, 0]  # assumes a singleton head axis — (batch, 1, seq, keys)
            if mask.dtype == torch.bool:
                index_scores = index_scores.masked_fill(~mask, float("-inf"))
            else:
                index_scores = index_scores + mask.float()

        topk = min(self.index_topk, key_len)
        return index_scores.topk(topk, dim=-1).indices


class DeepseekV32Attention(nn.Module):
    """
    DeepSeek V3.2 sparse attention mechanism with indexer.

    This implements the native sparse attention from [DeepSeek V3.2](https://huggingface.co/deepseek-ai/DeepSeek-V3.2)
    which uses an indexer to select top-k tokens for attention computation, making it more efficient for long
    sequences.

    Both the dense and the sparse path store full per-head keys `(batch, heads, seq, qk_head_dim)` and values
    `(batch, heads, seq, v_head_dim)` in the KV cache, so one generation can move from the dense path (short
    context) to the sparse path (long context) without corrupting the cache.
    """

    def __init__(self, config: DeepseekV32Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads

        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim
        self.index_topk = config.index_topk

        self.is_causal = True

        # Query projection: direct, or low-rank (down-proj -> RMSNorm -> up-proj) when q_lora_rank is set.
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        else:
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = DeepseekV32RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # Key/value projection: a compressed latent plus one rotary key channel group shared by all heads.
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV32RMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        # The indexer's own projections (wq_b, wk, k_norm, weights_proj) live on `self.indexer`; this module does
        # not hold separate copies.

        self.scaling = self.qk_head_dim ** (-0.5)
        # `rope_scaling` may be None for configs without RoPE scaling; treat that as the default rope type.
        rope_scaling = getattr(self.config, "rope_scaling", None) or {}
        if rope_scaling.get("rope_type", "default") != "default":
            mscale_all_dim = rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.scaling = self.scaling * mscale * mscale

        self.indexer = DeepseekV32Indexer(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
        """Dispatch between dense attention and top-k sparse attention.

        Dense attention is used in training and whenever the total key count (cached + new) is at most
        `index_topk`; in that case the indexer would keep every key, so dense and sparse attention coincide.
        """
        seq_length = hidden_states.shape[1]
        past_len = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0

        if self.training or past_len + seq_length <= self.index_topk:
            if self.training:
                logger.warning_once(
                    "DeepSeek V3.2 sparse attention is not fully implemented in this version. "
                    "Falling back to standard attention. For production use, please use vLLM or "
                    "other optimized inference engines.",
                )
            if past_key_values is not None:
                # Keep the indexer key cache in step with the KV cache so a later sparse step sees every past token.
                self.indexer.cache_keys(hidden_states, position_embeddings, past_key_values, past_len)
            return self._standard_attention(
                hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs
            )

        return self._dsa_attention(
            hidden_states,
            position_embeddings,
            attention_mask,
            past_key_values,
            cache_position,
            past_len=past_len,
            **kwargs,
        )

    def _project_qkv(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
    ):
        """Project hidden states to MLA queries, keys and values, with RoPE applied to the rotary channels.

        Returns:
            `(q_resid, query_states, key_states, value_states)`. `q_resid` is the normalised low-rank query latent
            (`None` when `q_lora_rank` is None). Queries/keys are `(batch, heads, seq, qk_head_dim)` and values
            are `(batch, heads, seq, v_head_dim)`.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        if self.q_lora_rank is None:
            q_resid = None
            q_states = self.q_proj(hidden_states)
        else:
            q_resid = self.q_a_layernorm(self.q_a_proj(hidden_states))
            q_states = self.q_b_proj(q_resid)
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The rotary key is a single head shared by all query heads.
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)
        return q_resid, query_states, key_states, value_states

    def _standard_attention(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
        """Dense attention (same as DeepSeek V3), dispatched to the configured attention backend."""
        batch_size, seq_length = hidden_states.shape[:-1]
        _, query_states, key_states, value_states = self._project_qkv(hidden_states, position_embeddings)

        if past_key_values is not None:
            cos, sin = position_embeddings
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # flash-attention requires equal q/k and v head dims; pad v here and slice the output back below.
        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    def _dsa_attention(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        past_len: int | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ):
        """Sparse attention: each query attends only to the indexer's top-k keys among the keys it may see.

        Computed eagerly in float32 (no backend dispatch; attention dropout is not applied). Returns the attention
        output and the attention probabilities.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        q_resid, query_states, key_states, value_states = self._project_qkv(hidden_states, position_embeddings)

        if past_key_values is not None:
            if past_len is None:
                past_len = past_key_values.get_seq_length(self.layer_idx)
            cos, sin = position_embeddings
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
        key_length = key_states.shape[-2]

        # Causal/padding visibility, shared by the indexer and the final attention.
        allowed = _dsa_allowed_mask(attention_mask, cache_position, seq_length, key_length, hidden_states.device)
        topk_indices = self.indexer(
            hidden_states,
            q_resid,
            position_embeddings,
            allowed,
            past_key_values_index=past_key_values,
            cache_position=cache_position,
            past_len=past_len,
        )
        selected = torch.zeros(batch_size, seq_length, key_length, dtype=torch.bool, device=hidden_states.device)
        selected.scatter_(-1, topk_indices, torch.ones_like(topk_indices, dtype=torch.bool))
        # Rows with fewer than top-k visible keys may select hidden ones; intersecting with `allowed` drops them.
        keep = allowed & selected.unsqueeze(1)

        scores = torch.matmul(query_states.float(), key_states.float().transpose(-1, -2)) * self.scaling
        scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        attn_weights = F.softmax(scores, dim=-1).to(value_states.dtype)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
|
|
class DeepseekV32MLP(nn.Module):
    """Gated feed-forward block computing `down_proj(act(gate_proj(x)) * up_proj(x))`, all projections bias-free."""

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        if intermediate_size is None:
            intermediate_size = config.intermediate_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
class DeepseekV32TopkRouter(nn.Module):
    """Group-limited top-k router over the routed experts.

    Each token gets a sigmoid affinity to every routed expert. Experts are split into `n_group` equal groups. A
    group's score is the sum of its two best bias-corrected affinities, and only the `topk_group` best groups stay
    eligible. The `top_k` best bias-corrected experts inside those groups are then selected. The returned weights
    are the plain (uncorrected) sigmoid affinities of the chosen experts, optionally renormalised to sum to 1, times
    `routed_scaling_factor`.
    """

    def __init__(self, config: "DeepseekV32Config"):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob

        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
        self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32))

    @torch.no_grad()
    def get_topk_indices(self, scores: torch.Tensor) -> torch.Tensor:
        """Return `(tokens, top_k)` expert indices chosen by group-limited top-k over `scores` `(tokens, experts)`."""
        experts_per_group = self.n_routed_experts // self.n_group
        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
        group_scores = (
            scores_for_choice.view(-1, self.n_group, experts_per_group).topk(2, dim=-1)[0].sum(dim=-1)
        )
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, experts_per_group)
            .reshape(-1, self.n_routed_experts)
        )
        # -inf (not 0) so a masked-out expert can never outrank an eligible one with a negative corrected score.
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))
        return torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]

    def forward(self, hidden_states):
        """Route tokens to experts.

        Args:
            hidden_states: `(..., hidden_size)`; flattened to `(tokens, hidden_size)`.

        Returns:
            `(topk_indices, topk_weights)`, both `(tokens, top_k)`; weights are float32.
        """
        hidden_states = hidden_states.view(-1, self.config.hidden_size)
        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
        scores = router_logits.sigmoid()
        topk_indices = self.get_topk_indices(scores)
        topk_weights = scores.gather(1, topk_indices)
        if self.norm_topk_prob:
            topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights
|
|
|
|
class DeepseekV32MoE(nn.Module):
    """
    Mixture-of-experts block: per-token routed experts (chosen by `self.gate`) plus always-on shared experts.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        routed_experts = [
            DeepseekV32MLP(config, intermediate_size=config.moe_intermediate_size)
            for _ in range(config.n_routed_experts)
        ]
        self.experts = nn.ModuleList(routed_experts)
        self.gate = DeepseekV32TopkRouter(config)
        shared_width = config.moe_intermediate_size * config.n_shared_experts
        self.shared_experts = DeepseekV32MLP(config=config, intermediate_size=shared_width)

    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
        r"""
        Weighted sum of the routed experts' outputs for each token.

        `hidden_states` is `(tokens, hidden)`; `topk_indices`/`topk_weights` are `(tokens, top_k)`. Accumulation
        happens in the weights' dtype, and the result is cast back to the input dtype.

        CALL FOR CONTRIBUTION! Expert weights should be fused to avoid this per-expert Python loop
        (deepseek has 256 experts).
        """
        accumulator = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        # routing[e, t, j] == 1 when slot j of token t selects expert e
        routing = F.one_hot(topk_indices, num_classes=len(self.experts)).permute(2, 0, 1)

        for expert_idx, expert in enumerate(self.experts):
            tokens, slots = torch.where(routing[expert_idx])
            if tokens.numel() == 0:
                continue
            contribution = expert(hidden_states[tokens]) * topk_weights[tokens, slots].unsqueeze(-1)
            accumulator.index_add_(0, tokens, contribution)

        return accumulator.type(hidden_states.dtype)

    def forward(self, hidden_states):
        shared_input = hidden_states
        input_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        flat = hidden_states.view(-1, hidden_states.shape[-1])
        routed = self.moe(flat, topk_indices, topk_weights).view(*input_shape)
        return routed + self.shared_experts(shared_input)
|
|
|
|
|
|
class DeepseekV32DecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder block: x + attn(norm(x)), then x + mlp(norm(x)).

    Layers with index below `config.first_k_dense_replace` use a dense MLP; the rest use the MoE block.
    """

    def __init__(self, config: DeepseekV32Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = DeepseekV32Attention(config, layer_idx)

        is_moe_layer = layer_idx >= config.first_k_dense_replace
        self.mlp = DeepseekV32MoE(config) if is_moe_layer else DeepseekV32MLP(config)

        self.input_layernorm = DeepseekV32RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV32RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # attention sub-block (the attention weights are discarded here)
        attn_out, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # feed-forward sub-block
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + ffn_out
|
|
|
|
class DeepseekV32PreTrainedModel(PreTrainedModel):
    """Shared base for the DeepSeek V3.2 models: declares config type, supported backends and weight init."""

    config: DeepseekV32Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV32DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    # Which module classes' outputs are collected as hidden states / attentions.
    _can_record_outputs = {
        "hidden_states":DeepseekV32DecoderLayer,
        "attentions": DeepseekV32Attention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Apply the default initialisation, then draw router weights from N(0, config.initializer_range)."""
        super()._init_weights(module)
        if isinstance(module, DeepseekV32TopkRouter):
            # The router's weight is a bare nn.Parameter (not an nn.Linear); presumably the generic initialiser
            # does not cover it — TODO confirm.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
|
|
|
class DeepseekV32RotaryEmbedding(nn.Module):
    """Produces rotary `(cos, sin)` tables for given positions, scaled by the rope type's attention factor."""

    inv_freq: torch.Tensor  # registered as a non-persistent buffer in __init__

    def __init__(self, config: DeepseekV32Config, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # A config without RoPE scaling may carry `rope_scaling=None`; treat that as plain ("default") RoPE
        # instead of failing on `None.get`.
        rope_scaling = getattr(self.config, "rope_scaling", None) or {}
        self.rope_type = rope_scaling.get("rope_type", "default")
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: DeepseekV32Config | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration. `rope_scaling` may be None; `partial_rotary_factor` then defaults to 1.0.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (1.0 for this type of RoPE).
        """
        base = config.rope_theta
        rope_scaling = getattr(config, "rope_scaling", None) or {}
        partial_rotary_factor = rope_scaling.get("partial_rotary_factor", 1.0)
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)

        attention_factor = 1.0

        # inv_freq[i] = base ** (-2i / dim) for i in [0, dim / 2)
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return `(cos, sin)` of shape `(batch, seq, dim)` in `x.dtype` for the given `position_ids` `(batch, seq)`."""
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Compute the angles in float32 with autocast disabled (mps is mapped to the "cpu" autocast type).
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
class DeepseekV32Model(DeepseekV32PreTrainedModel):
    """DeepSeek V3.2 decoder stack: token embeddings -> decoder layers -> final RMSNorm."""

    # NOTE(review): checkpoint keys under layer index 78 are dropped on load; confirm this index matches the extra
    # (non-decoder) layer stored in the target checkpoint for this config.
    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.78.*"]

    def __init__(self, config: DeepseekV32Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_layers = [DeepseekV32DecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = DeepseekV32RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given. A `DynamicCache` is created when `use_cache` is
        set without a cache. Missing `cache_position` continues from the cache length, and missing `position_ids`
        default to `cache_position`.
        """
        has_ids = input_ids is not None
        has_embeds = inputs_embeds is not None
        if has_ids == has_embeds:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if not has_embeds:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            already_seen = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + already_seen

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # rotary tables are computed once and shared by every layer
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        for layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
|
|
|
|
class DeepseekV32ForCausalLM(DeepseekV32PreTrainedModel, GenerationMixin):
    """DeepSeek V3.2 decoder with a bias-free language-modeling head over the vocabulary."""

    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = DeepseekV32Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        """Compute logits (and the loss when `labels` are given).

        `logits_to_keep`: an int keeps the last N positions (0 keeps all); a tensor is used directly as the index
        into the sequence axis.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        if isinstance(logits_to_keep, int):
            positions = slice(-logits_to_keep, None)
        else:
            positions = logits_to_keep
        logits = self.lm_head(outputs.last_hidden_state[:, positions, :])

        loss = (
            self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
            if labels is not None
            else None
        )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
# Public classes exported by this module.
__all__ = ["DeepseekV32PreTrainedModel", "DeepseekV32Model", "DeepseekV32ForCausalLM"]
|
|