Mingke977 commited on
Commit
2a3ee0b
·
verified ·
1 Parent(s): f2121c6

Add files using upload-large-folder tool

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/joyai-logo.png filter=lfs diff=lfs merge=lfs -text
.vscode/settings.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "python-envs.pythonProjects": [
3
+ {
4
+ "path": ".",
5
+ "envManager": "ms-python.python:venv",
6
+ "packageManager": "ms-python.python:pip"
7
+ }
8
+ ]
9
+ }
docs/deploy_guidance.md ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Deployment Guide
2
+
3
+ > [!Note]
4
+ > This guide offers a selection of deployment command examples for JoyAI-LLM Flash, which may not be the optimal configuration. Given the rapid evolution of inference engines, we recommend referring to their official documentation for the latest updates to ensure peak performance.
5
+
6
+ > Support for JoyAI-LLM Flash’s dense MTP architecture is currently being integrated into vLLM and SGLang. Until these PRs are merged into a stable release, please use the nightly Docker image for access to these features.
7
+
8
+ ## vLLM Deployment
9
+
10
+ Here is the example to serve this model on a single GPU card via vLLM:
11
+
12
+ 1. pull the Docker image.
13
+ ```bash
14
+ docker pull jdopensource/joyai-llm-vllm:v0.15.1-joyai_llm_flash
15
+ ```
16
+ 2. launch JoyAI-LLM Flash model with dense MTP.
17
+ ```bash
18
+ vllm serve jdopensource/JoyAI-LLM-Flash-INT4 --tp 1 --trust-remote-code \
19
+ --tool-call-parser qwen3_coder --enable-auto-tool-choice \
20
+ --speculative-config $'{"method": "mtp", "num_speculative_tokens": 3}'
21
+ ```
22
+ **Key notes**
23
+ - `--tool-call-parser qwen3_coder`: Required for enabling tool calling
24
+
25
+ ## SGLang Deployment
26
+
27
+ Similarly, here is the example to run on a single GPU card via SGLang:
28
+
29
+ 1. pull the Docker image.
30
+ ```bash
31
+ docker pull jdopensource/joyai-llm-sglang:v0.5.8-joyai_llm_flash
32
+ ```
33
+ 2. launch JoyAI-LLM Flash model with dense MTP.
34
+
35
+ ```bash
36
+ python3 -m sglang.launch_server --model-path jdopensource/JoyAI-LLM-Flash-INT4 --tp-size 1 --trust-remote-code \
37
+ --tool-call-parser qwen3_coder \
38
+ --speculative-algorithm EAGLE \
39
+ --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4
40
+ ```
41
+ **Key notes:**
42
+ - `--tool-call-parser qwen3_coder`: Required when enabling tool usage.
figures/joyai-logo.png ADDED

Git LFS Details

  • SHA256: 4ea9d6a20a7707ca8dc427d6dcb5db6e2489f7730d5bffea26d8db20b1c54365
  • Pointer size: 131 Bytes
  • Size of remote file: 250 kB
model-1-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00342c45cd62e28fe183e2529ce61e2cecf0d8ea5451b2a8fe4137ae5e50e901
3
+ size 140785016
model-12-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79507569a9d651fd9a219a6180ba8b93f5766d2e618b7924788db129dd74f290
3
+ size 818458104
model-13-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eeeb6a180a6dff03c038a0ddc8805ba306210b64105d4b03f101b6cfdae06b7c
3
+ size 818458104
model-16-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:450c4b413c061e2a8e833931b0e8d96d8c225e88de2bb20095e9fd11be6c1ba3
3
+ size 818458104
model-17-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0a257aba4913146e4b38afd0ecda70da43a5d4a16183bea50d16c485941e90c
3
+ size 818458104
model-18-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33f0bdfef6e90798925e0a8e23698e3ad51493d2b79ca47161e38772e9b830db
3
+ size 818458104
model-19-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7438a2ede536f77eb835413eea60b9e2d03ad21ccd2e4f10720cd5e1f411364
3
+ size 818458104
model-22-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4d6023ef3af8790fa1239e07d507f70ce51d839c5244cdfbcdc62c151341753
3
+ size 818458104
model-23-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:095d872d4fc5e06776d7caca99f15b97ead640591b710616771746fe1bf5feb4
3
+ size 818458104
model-24-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fa4902f051050dd6b9f504ac3a50799b4115d80601daff86bc175f0fb8e9203
3
+ size 818458104
model-3-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d7f3c057be0e324901ba437e7123b38d064befc3cac8e43c0a0242e42809e30
3
+ size 818455784
model-30-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ab178211b3c63f8a3f1a012b41562c2ed7155040eb18dc0dae428d5d0257ef7
3
+ size 818458104
model-32-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:acaf4effa401cedc72705ac7b96ea608d2389ae75a31d577fc52f80d3529de84
3
+ size 818458104
model-34-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17926839ef263a68a4b7a2ec0e414f2d9b2954604c67c6292c540e50b40b9bd1
3
+ size 818458104
model-35-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4c25a70260109e1bc2f7444da247b6e502e7e1b4ed151012f90cc8e05646e53
3
+ size 818458104
model-36-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21a27401e067cb13952bb25122ce5af25e687ec7d2f5c0bd95b5e4f69ea56ea8
3
+ size 818458104
model-37-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12a610b42a9e5d35b3d030499f2d70468c5dbece65e38eb05c7cc9e5da77837d
3
+ size 818458104
model-38-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63f32214047b459c0a56f5c7fddbf5e329d573eebf790a4ea4bc974324b02c15
3
+ size 818458104
model-39-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:658cd299896798e797c7848ebf5ab161501266377a3f09ffab0b531fc2abe518
3
+ size 818458104
model-4-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:676aff1c56518d5e9d439297918b8a7c8a9d2a7f9aaef4b0f8f8d7dc34a0c2cc
3
+ size 818455784
model-40-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b66d3df132309f6c9eebf9b44acb04a8ee1a14de5022b04689f3b17e0432cca7
3
+ size 818458104
model-5-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:355cb6bc5edf28e6336b5c1cec74e085ab678518a213d360f5ff11096bc9f39b
3
+ size 818455784
model-6-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ace27ba26f0132f963cc111910c0e33387e657b279811dc7eceac0d886aea953
3
+ size 818455784
model-7-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a178f673023324cc99ddb08df93101ab454af7f1a7fd05c85b1773927b1e32b0
3
+ size 818455784
model-8-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4067f851d46d6ffb734d23d26709ec2f08bf4002fe79fd302a9ce8a6aa52a6b1
3
+ size 818455784
model-9-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b4463d3d33b5c558417d3d696849417a6dc77c650cdfd23c3774d6349a7577b
3
+ size 818455784
model-non-layer.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd760d732c11c23778a0dbf2280b62431d77d1f4ebc4f01f111cf716786981f0
3
+ size 1059066184
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_deepseek.py ADDED
@@ -0,0 +1,1028 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_deepseek_v3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ import math
8
+ from functools import partial
9
+ from typing import Callable, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
17
+ from transformers.generation import GenerationMixin
18
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
19
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
20
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
21
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
22
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
23
+ from transformers.processing_utils import Unpack
24
+ from transformers.utils import (
25
+ LossKwargs,
26
+ add_start_docstrings,
27
+ add_start_docstrings_to_model_forward,
28
+ can_return_tuple,
29
+ is_torch_flex_attn_available,
30
+ logging,
31
+ replace_return_docstrings,
32
+ )
33
+ from transformers.utils.deprecation import deprecate_kwarg
34
+ from .configuration_deepseek import DeepseekV3Config
35
+
36
+
37
+ if is_torch_flex_attn_available():
38
+ from torch.nn.attention.flex_attention import BlockMask
39
+
40
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+ _CONFIG_FOR_DOC = "DeepseekV3Config"
45
+
46
+
47
+ class DeepseekV3RMSNorm(nn.Module):
48
+ def __init__(self, hidden_size, eps=1e-6):
49
+ """
50
+ DeepseekV3RMSNorm is equivalent to T5LayerNorm
51
+ """
52
+ super().__init__()
53
+ self.weight = nn.Parameter(torch.ones(hidden_size))
54
+ self.variance_epsilon = eps
55
+
56
+ def forward(self, hidden_states):
57
+ input_dtype = hidden_states.dtype
58
+ hidden_states = hidden_states.to(torch.float32)
59
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
60
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
61
+ return self.weight * hidden_states.to(input_dtype)
62
+
63
+ def extra_repr(self):
64
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
65
+
66
+
67
+ class DeepseekV3RotaryEmbedding(nn.Module):
68
+ def __init__(self, config: DeepseekV3Config, device=None):
69
+ super().__init__()
70
+ # BC: "rope_type" was originally "type"
71
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
72
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
73
+ else:
74
+ self.rope_type = "default"
75
+ self.max_seq_len_cached = config.max_position_embeddings
76
+ self.original_max_seq_len = config.max_position_embeddings
77
+
78
+ self.config = config
79
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
80
+
81
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
82
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
83
+ self.original_inv_freq = self.inv_freq
84
+
85
+ @torch.no_grad()
86
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
87
+ def forward(self, x, position_ids):
88
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
89
+ position_ids_expanded = position_ids[:, None, :].float()
90
+
91
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
92
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
93
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
94
+ emb = torch.cat((freqs, freqs), dim=-1)
95
+ cos = emb.cos() * self.attention_scaling
96
+ sin = emb.sin() * self.attention_scaling
97
+
98
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
99
+
100
+
101
+ class DeepseekV3MLP(nn.Module):
102
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
103
+ super().__init__()
104
+ self.config = config
105
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
106
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
107
+
108
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
109
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
110
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
111
+ self.act_fn = ACT2FN[config.hidden_act]
112
+
113
+ def forward(self, x):
114
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
115
+ return down_proj
116
+
117
+
118
+ class DeepseekV3TopkRouter(nn.Module):
119
+ def __init__(self, config):
120
+ super().__init__()
121
+ self.config = config
122
+ self.top_k = config.num_experts_per_tok
123
+ self.n_routed_experts = config.n_routed_experts
124
+ self.routed_scaling_factor = config.routed_scaling_factor
125
+ self.n_group = config.n_group
126
+ self.topk_group = config.topk_group
127
+ self.norm_topk_prob = config.norm_topk_prob
128
+
129
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
130
+ self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts)))
131
+
132
+ @torch.no_grad()
133
+ def get_topk_indices(self, scores):
134
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
135
+ group_scores = (
136
+ scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
137
+ .topk(2, dim=-1)[0]
138
+ .sum(dim=-1)
139
+ )
140
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
141
+ group_mask = torch.zeros_like(group_scores)
142
+ group_mask.scatter_(1, group_idx, 1)
143
+ score_mask = (
144
+ group_mask.unsqueeze(-1)
145
+ .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
146
+ .reshape(-1, self.n_routed_experts)
147
+ )
148
+ scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
149
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
150
+ return topk_indices
151
+
152
+ def forward(self, hidden_states):
153
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
154
+ router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
155
+ scores = router_logits.sigmoid()
156
+ topk_indices = self.get_topk_indices(scores)
157
+ topk_weights = scores.gather(1, topk_indices)
158
+ if self.norm_topk_prob:
159
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
160
+ topk_weights /= denominator
161
+ topk_weights = topk_weights * self.routed_scaling_factor
162
+ return topk_indices, topk_weights
163
+
164
+
165
+ class DeepseekV3MoE(nn.Module):
166
+ """
167
+ A mixed expert module containing shared experts.
168
+ """
169
+
170
+ def __init__(self, config):
171
+ super().__init__()
172
+ self.config = config
173
+ self.experts = nn.ModuleList(
174
+ [
175
+ DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
176
+ for _ in range(config.n_routed_experts)
177
+ ]
178
+ )
179
+ self.gate = DeepseekV3TopkRouter(config)
180
+ self.shared_experts = DeepseekV3MLP(
181
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
182
+ )
183
+
184
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
185
+ r"""
186
+ CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
187
+ to not have to do a loop here (deepseek has 256 experts soooo yeah).
188
+ """
189
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
190
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
191
+ expert_mask = expert_mask.permute(2, 0, 1)
192
+
193
+ for expert_idx in range(len(self.experts)):
194
+ expert = self.experts[expert_idx]
195
+ mask = expert_mask[expert_idx]
196
+ token_indices, weight_indices = torch.where(mask)
197
+
198
+ if token_indices.numel() > 0:
199
+ expert_weights = topk_weights[token_indices, weight_indices]
200
+ expert_input = hidden_states[token_indices]
201
+ expert_output = expert(expert_input)
202
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
203
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
204
+
205
+ # in original deepseek, the output of the experts are gathered once we leave this module
206
+ # thus the moe module is itelsf an IsolatedParallel module
207
+ # and all expert are "local" meaning we shard but we don't gather
208
+ return final_hidden_states.type(hidden_states.dtype)
209
+
210
+ def forward(self, hidden_states):
211
+ residuals = hidden_states
212
+ orig_shape = hidden_states.shape
213
+ topk_indices, topk_weights = self.gate(hidden_states)
214
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
215
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
216
+ hidden_states = hidden_states + self.shared_experts(residuals)
217
+ return hidden_states
218
+
219
+
220
+ def rotate_half(x):
221
+ """Rotates half the hidden dims of the input."""
222
+ x1 = x[..., : x.shape[-1] // 2]
223
+ x2 = x[..., x.shape[-1] // 2 :]
224
+ return torch.cat((-x2, x1), dim=-1)
225
+
226
+
227
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
228
+ """Applies Rotary Position Embedding to the query and key tensors.
229
+
230
+ Args:
231
+ q (`torch.Tensor`): The query tensor.
232
+ k (`torch.Tensor`): The key tensor.
233
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
234
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
235
+ position_ids (`torch.Tensor`, *optional*):
236
+ Deprecated and unused.
237
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
238
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
239
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
240
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
241
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
242
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
243
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
244
+ Returns:
245
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
246
+ """
247
+ cos = cos.unsqueeze(unsqueeze_dim)
248
+ sin = sin.unsqueeze(unsqueeze_dim)
249
+ q_embed = (q * cos) + (rotate_half(q) * sin)
250
+ k_embed = (k * cos) + (rotate_half(k) * sin)
251
+ return q_embed, k_embed
252
+
253
+
254
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
255
+ """
256
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
257
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
258
+ """
259
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
260
+ if n_rep == 1:
261
+ return hidden_states
262
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
263
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
264
+
265
+
266
+ def eager_attention_forward(
267
+ module: nn.Module,
268
+ query: torch.Tensor,
269
+ key: torch.Tensor,
270
+ value: torch.Tensor,
271
+ attention_mask: Optional[torch.Tensor],
272
+ scaling: float,
273
+ dropout: float = 0.0,
274
+ **kwargs,
275
+ ):
276
+ key_states = repeat_kv(key, module.num_key_value_groups)
277
+ value_states = repeat_kv(value, module.num_key_value_groups)
278
+
279
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
280
+ if attention_mask is not None:
281
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
282
+ attn_weights = attn_weights + causal_mask
283
+
284
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
285
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
286
+ attn_output = torch.matmul(attn_weights, value_states)
287
+ attn_output = attn_output.transpose(1, 2).contiguous()
288
+
289
+ return attn_output, attn_weights
290
+
291
+
292
+ def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
293
+ r"""
294
+ TODO let's just use the original freqcis computation to not have the view
295
+ transpose + reshape! This is not optimized!
296
+ Applies Rotary Position Embedding to the query and key tensors.
297
+
298
+ Args:
299
+ q (`torch.Tensor`): The query tensor.
300
+ k (`torch.Tensor`): The key tensor.
301
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
302
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
303
+ position_ids (`torch.Tensor`):
304
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
305
+ used to pass offsetted position ids when working with a KV-cache.
306
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
307
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
308
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
309
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
310
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
311
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
312
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
313
+ Returns:
314
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
315
+ """
316
+ cos = cos.unsqueeze(unsqueeze_dim)
317
+ sin = sin.unsqueeze(unsqueeze_dim)
318
+
319
+ b, h, s, d = q.shape
320
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
321
+
322
+ b, h, s, d = k.shape
323
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
324
+
325
+ q_embed = (q * cos) + (rotate_half(q) * sin)
326
+ k_embed = (k * cos) + (rotate_half(k) * sin)
327
+ return q_embed, k_embed
328
+
329
+
330
+ def yarn_get_mscale(scale=1, mscale=1):
331
+ if scale <= 1:
332
+ return 1.0
333
+ return 0.1 * mscale * math.log(scale) + 1.0
334
+
335
+
336
+ class DeepseekV3Attention(nn.Module):
337
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
338
+
339
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
340
+ super().__init__()
341
+ self.config = config
342
+ self.layer_idx = layer_idx
343
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
344
+ self.attention_dropout = config.attention_dropout
345
+ self.num_heads = config.num_attention_heads
346
+ self.rope_theta = config.rope_theta
347
+ self.q_lora_rank = config.q_lora_rank
348
+ self.qk_rope_head_dim = config.qk_rope_head_dim
349
+ self.kv_lora_rank = config.kv_lora_rank
350
+ self.v_head_dim = config.v_head_dim
351
+ self.qk_nope_head_dim = config.qk_nope_head_dim
352
+ self.qk_head_dim = config.qk_head_dim
353
+
354
+ self.is_causal = True
355
+ self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
356
+ self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
357
+ self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
358
+
359
+ self.kv_a_proj_with_mqa = nn.Linear(
360
+ config.hidden_size,
361
+ self.kv_lora_rank + self.qk_rope_head_dim,
362
+ bias=config.attention_bias,
363
+ )
364
+ self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
365
+ self.kv_b_proj = nn.Linear(
366
+ self.kv_lora_rank,
367
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
368
+ bias=False,
369
+ )
370
+
371
+ self.o_proj = nn.Linear(
372
+ self.num_heads * self.v_head_dim,
373
+ config.hidden_size,
374
+ bias=config.attention_bias,
375
+ )
376
+
377
+ self.scaling = self.qk_head_dim ** (-0.5)
378
+ if self.config.rope_scaling is not None:
379
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
380
+ scaling_factor = self.config.rope_scaling["factor"]
381
+ if mscale_all_dim:
382
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
383
+ self.scaling = self.scaling * mscale * mscale
384
+
385
+ def forward(
386
+ self,
387
+ hidden_states: torch.Tensor,
388
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
389
+ attention_mask: Optional[torch.Tensor],
390
+ past_key_value: Optional[Cache] = None,
391
+ cache_position: Optional[torch.LongTensor] = None,
392
+ **kwargs: Unpack[FlashAttentionKwargs],
393
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
394
+ batch_size, seq_length = hidden_states.shape[:-1]
395
+ query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
396
+ key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
397
+
398
+ q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
399
+ q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
400
+
401
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
402
+ k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
403
+
404
+ k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
405
+ k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
406
+
407
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
408
+
409
+ cos, sin = position_embeddings
410
+ if self.config.rope_interleave: # support using interleaved weights for efficiency
411
+ q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
412
+ else:
413
+ q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
414
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
415
+
416
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
417
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
418
+
419
+ if past_key_value is not None:
420
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
421
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
422
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
423
+
424
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
425
+ value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
426
+
427
+ attention_interface: Callable = eager_attention_forward
428
+ if self.config._attn_implementation != "eager":
429
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
430
+ logger.warning_once(
431
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
432
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
433
+ )
434
+ else:
435
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
436
+
437
+ attn_output, attn_weights = attention_interface(
438
+ self,
439
+ query_states,
440
+ key_states,
441
+ value_states,
442
+ attention_mask,
443
+ dropout=0.0 if not self.training else self.attention_dropout,
444
+ scaling=self.scaling,
445
+ **kwargs,
446
+ )
447
+
448
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
449
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
450
+
451
+ attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
452
+ attn_output = self.o_proj(attn_output)
453
+ return attn_output, attn_weights
454
+
455
+
456
+ class DeepseekV3DecoderLayer(nn.Module):
457
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
458
+ super().__init__()
459
+ self.hidden_size = config.hidden_size
460
+
461
+ self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
462
+
463
+ if layer_idx >= config.first_k_dense_replace:
464
+ self.mlp = DeepseekV3MoE(config)
465
+ else:
466
+ self.mlp = DeepseekV3MLP(config)
467
+
468
+ self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
469
+ self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
470
+
471
+ def forward(
472
+ self,
473
+ hidden_states: torch.Tensor,
474
+ attention_mask: Optional[torch.Tensor] = None,
475
+ position_ids: Optional[torch.LongTensor] = None,
476
+ past_key_value: Optional[Cache] = None,
477
+ output_attentions: Optional[bool] = False,
478
+ use_cache: Optional[bool] = False,
479
+ cache_position: Optional[torch.LongTensor] = None,
480
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
481
+ **kwargs: Unpack[FlashAttentionKwargs],
482
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
483
+ residual = hidden_states
484
+
485
+ hidden_states = self.input_layernorm(hidden_states)
486
+
487
+ # Self Attention
488
+ hidden_states, self_attn_weights = self.self_attn(
489
+ hidden_states=hidden_states,
490
+ attention_mask=attention_mask,
491
+ position_ids=position_ids,
492
+ past_key_value=past_key_value,
493
+ output_attentions=output_attentions,
494
+ use_cache=use_cache,
495
+ cache_position=cache_position,
496
+ position_embeddings=position_embeddings,
497
+ **kwargs,
498
+ )
499
+ hidden_states = residual + hidden_states
500
+
501
+ # Fully Connected
502
+ residual = hidden_states
503
+ hidden_states = self.post_attention_layernorm(hidden_states)
504
+ hidden_states = self.mlp(hidden_states)
505
+ hidden_states = residual + hidden_states
506
+
507
+ outputs = (hidden_states,)
508
+ if output_attentions:
509
+ outputs += (self_attn_weights,)
510
+
511
+ return outputs
512
+
513
+
514
+ DEEPSEEK_V3_START_DOCSTRING = r"""
515
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
516
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
517
+ etc.)
518
+
519
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
520
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
521
+ and behavior.
522
+
523
+ Parameters:
524
+ config ([`DeepseekV3Config`]):
525
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
526
+ load the weights associated with the model, only the configuration. Check out the
527
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
528
+ """
529
+
530
+
531
+ @add_start_docstrings(
532
+ "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
533
+ DEEPSEEK_V3_START_DOCSTRING,
534
+ )
535
+ class DeepseekV3PreTrainedModel(PreTrainedModel):
536
+ config_class = DeepseekV3Config
537
+ base_model_prefix = "model"
538
+ supports_gradient_checkpointing = True
539
+ _no_split_modules = ["DeepseekV3DecoderLayer"]
540
+ _skip_keys_device_placement = ["past_key_values"]
541
+ _supports_flash_attn_2 = True
542
+ _supports_sdpa = True
543
+ _supports_flex_attn = True
544
+ _supports_cache_class = True
545
+ _supports_quantized_cache = True
546
+ _supports_static_cache = True
547
+ _supports_attention_backend = True
548
+
549
+ def _init_weights(self, module):
550
+ std = self.config.initializer_range
551
+ if isinstance(module, nn.Linear):
552
+ module.weight.data.normal_(mean=0.0, std=std)
553
+ if module.bias is not None:
554
+ module.bias.data.zero_()
555
+ elif isinstance(module, nn.Embedding):
556
+ module.weight.data.normal_(mean=0.0, std=std)
557
+ if module.padding_idx is not None:
558
+ module.weight.data[module.padding_idx].zero_()
559
+ elif isinstance(module, DeepseekV3TopkRouter):
560
+ module.weight.data.normal_(mean=0.0, std=std)
561
+ elif isinstance(module, nn.Parameter):
562
+ module.weight.data.normal_(mean=0.0, std=std)
563
+
564
+
565
+ DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
566
+ Args:
567
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
568
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
569
+ it.
570
+
571
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
572
+ [`PreTrainedTokenizer.__call__`] for details.
573
+
574
+ [What are input IDs?](../glossary#input-ids)
575
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
576
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
577
+
578
+ - 1 for tokens that are **not masked**,
579
+ - 0 for tokens that are **masked**.
580
+
581
+ [What are attention masks?](../glossary#attention-mask)
582
+
583
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
584
+ [`PreTrainedTokenizer.__call__`] for details.
585
+
586
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
587
+ `past_key_values`).
588
+
589
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
590
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
591
+ information on the default strategy.
592
+
593
+ - 1 indicates the head is **not masked**,
594
+ - 0 indicates the head is **masked**.
595
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
596
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
597
+ config.n_positions - 1]`.
598
+
599
+ [What are position IDs?](../glossary#position-ids)
600
+ past_key_values (`Cache`, *optional*):
601
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
602
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
603
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
604
+
605
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
606
+
607
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
608
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
609
+ of shape `(batch_size, sequence_length)`.
610
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
611
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
612
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
613
+ model's internal embedding lookup matrix.
614
+ use_cache (`bool`, *optional*):
615
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
616
+ `past_key_values`).
617
+ output_attentions (`bool`, *optional*):
618
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
619
+ tensors for more detail.
620
+ output_hidden_states (`bool`, *optional*):
621
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
622
+ more detail.
623
+ return_dict (`bool`, *optional*):
624
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
625
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
626
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
627
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
628
+ the complete sequence length.
629
+ """
630
+
631
+
632
+ @add_start_docstrings(
633
+ "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
634
+ DEEPSEEK_V3_START_DOCSTRING,
635
+ )
636
+ class DeepseekV3Model(DeepseekV3PreTrainedModel):
637
+ """
638
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]
639
+
640
+ Args:
641
+ config: DeepseekV3Config
642
+ """
643
+
644
+ _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]
645
+
646
+ def __init__(self, config: DeepseekV3Config):
647
+ super().__init__(config)
648
+ self.padding_idx = config.pad_token_id
649
+ self.vocab_size = config.vocab_size
650
+
651
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
652
+ self.layers = nn.ModuleList(
653
+ [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
654
+ )
655
+ self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
656
+ self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
657
+ self.gradient_checkpointing = False
658
+
659
+ # Initialize weights and apply final processing
660
+ self.post_init()
661
+
662
+ def get_input_embeddings(self):
663
+ return self.embed_tokens
664
+
665
+ def set_input_embeddings(self, value):
666
+ self.embed_tokens = value
667
+
668
+ @can_return_tuple
669
+ @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
670
+ def forward(
671
+ self,
672
+ input_ids: Optional[torch.LongTensor] = None,
673
+ attention_mask: Optional[torch.Tensor] = None,
674
+ position_ids: Optional[torch.LongTensor] = None,
675
+ past_key_values: Optional[Cache] = None,
676
+ inputs_embeds: Optional[torch.FloatTensor] = None,
677
+ use_cache: Optional[bool] = None,
678
+ output_attentions: Optional[bool] = None,
679
+ output_hidden_states: Optional[bool] = None,
680
+ cache_position: Optional[torch.LongTensor] = None,
681
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
682
+ ) -> BaseModelOutputWithPast:
683
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
684
+ output_hidden_states = (
685
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
686
+ )
687
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
688
+
689
+ if (input_ids is None) ^ (inputs_embeds is not None):
690
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
691
+
692
+ if self.gradient_checkpointing and self.training and use_cache:
693
+ logger.warning_once(
694
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
695
+ )
696
+ use_cache = False
697
+
698
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
699
+ if not isinstance(past_key_values, (type(None), Cache)):
700
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
701
+
702
+ if inputs_embeds is None:
703
+ inputs_embeds = self.embed_tokens(input_ids)
704
+
705
+ if use_cache and past_key_values is None:
706
+ past_key_values = DynamicCache()
707
+
708
+ if cache_position is None:
709
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
710
+ cache_position = torch.arange(
711
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
712
+ )
713
+
714
+ if position_ids is None:
715
+ position_ids = cache_position.unsqueeze(0)
716
+
717
+ causal_mask = self._update_causal_mask(
718
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
719
+ )
720
+
721
+ hidden_states = inputs_embeds
722
+
723
+ # create position embeddings to be shared across the decoder layers
724
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
725
+
726
+ # decoder layers
727
+ all_hidden_states = () if output_hidden_states else None
728
+ all_self_attns = () if output_attentions else None
729
+
730
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
731
+ if output_hidden_states:
732
+ all_hidden_states += (hidden_states,)
733
+
734
+ if self.gradient_checkpointing and self.training:
735
+ layer_outputs = self._gradient_checkpointing_func(
736
+ partial(decoder_layer.__call__, **flash_attn_kwargs),
737
+ hidden_states,
738
+ causal_mask,
739
+ position_ids,
740
+ past_key_values,
741
+ output_attentions,
742
+ use_cache,
743
+ cache_position,
744
+ position_embeddings,
745
+ )
746
+ else:
747
+ layer_outputs = decoder_layer(
748
+ hidden_states,
749
+ attention_mask=causal_mask,
750
+ position_ids=position_ids,
751
+ past_key_value=past_key_values,
752
+ output_attentions=output_attentions,
753
+ use_cache=use_cache,
754
+ cache_position=cache_position,
755
+ position_embeddings=position_embeddings,
756
+ **flash_attn_kwargs,
757
+ )
758
+
759
+ hidden_states = layer_outputs[0]
760
+
761
+ if output_attentions:
762
+ all_self_attns += (layer_outputs[1],)
763
+
764
+ hidden_states = self.norm(hidden_states)
765
+
766
+ # add hidden states from the last decoder layer
767
+ if output_hidden_states:
768
+ all_hidden_states += (hidden_states,)
769
+
770
+ return BaseModelOutputWithPast(
771
+ last_hidden_state=hidden_states,
772
+ past_key_values=past_key_values if use_cache else None,
773
+ hidden_states=all_hidden_states,
774
+ attentions=all_self_attns,
775
+ )
776
+
777
+ def _update_causal_mask(
778
+ self,
779
+ attention_mask: torch.Tensor,
780
+ input_tensor: torch.Tensor,
781
+ cache_position: torch.Tensor,
782
+ past_key_values: Cache,
783
+ output_attentions: bool = False,
784
+ ):
785
+ if self.config._attn_implementation == "flash_attention_2":
786
+ if attention_mask is not None and (attention_mask == 0.0).any():
787
+ return attention_mask
788
+ return None
789
+ if self.config._attn_implementation == "flex_attention":
790
+ if isinstance(attention_mask, torch.Tensor):
791
+ attention_mask = make_flex_block_causal_mask(attention_mask)
792
+ if isinstance(attention_mask, BlockMask):
793
+ return attention_mask
794
+
795
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
796
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
797
+ # to infer the attention mask.
798
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
799
+ using_static_cache = isinstance(past_key_values, StaticCache)
800
+
801
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
802
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
803
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
804
+ attention_mask,
805
+ inputs_embeds=input_tensor,
806
+ past_key_values_length=past_seen_tokens,
807
+ is_training=self.training,
808
+ ):
809
+ return None
810
+
811
+ dtype, device = input_tensor.dtype, input_tensor.device
812
+ sequence_length = input_tensor.shape[1]
813
+ if using_static_cache:
814
+ target_length = past_key_values.get_max_cache_shape()
815
+ else:
816
+ target_length = (
817
+ attention_mask.shape[-1]
818
+ if isinstance(attention_mask, torch.Tensor)
819
+ else past_seen_tokens + sequence_length + 1
820
+ )
821
+
822
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
823
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
824
+ attention_mask,
825
+ sequence_length=sequence_length,
826
+ target_length=target_length,
827
+ dtype=dtype,
828
+ device=device,
829
+ cache_position=cache_position,
830
+ batch_size=input_tensor.shape[0],
831
+ )
832
+
833
+ if (
834
+ self.config._attn_implementation == "sdpa"
835
+ and attention_mask is not None
836
+ and attention_mask.device.type in ["cuda", "xpu"]
837
+ and not output_attentions
838
+ ):
839
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
840
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
841
+ # Details: https://github.com/pytorch/pytorch/issues/110213
842
+ min_dtype = torch.finfo(dtype).min
843
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
844
+
845
+ return causal_mask
846
+
847
+ @staticmethod
848
+ def _prepare_4d_causal_attention_mask_with_cache_position(
849
+ attention_mask: torch.Tensor,
850
+ sequence_length: int,
851
+ target_length: int,
852
+ dtype: torch.dtype,
853
+ device: torch.device,
854
+ cache_position: torch.Tensor,
855
+ batch_size: int,
856
+ **kwargs,
857
+ ):
858
+ """
859
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
860
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
861
+
862
+ Args:
863
+ attention_mask (`torch.Tensor`):
864
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
865
+ `(batch_size, 1, query_length, key_value_length)`.
866
+ sequence_length (`int`):
867
+ The sequence length being processed.
868
+ target_length (`int`):
869
+ The target length: when generating with static cache, the mask should be as long as the static cache,
870
+ to account for the 0 padding, the part of the cache that is not filled yet.
871
+ dtype (`torch.dtype`):
872
+ The dtype to use for the 4D attention mask.
873
+ device (`torch.device`):
874
+ The device to place the 4D attention mask on.
875
+ cache_position (`torch.Tensor`):
876
+ Indices depicting the position of the input sequence tokens in the sequence.
877
+ batch_size (`torch.Tensor`):
878
+ Batch size.
879
+ """
880
+ if attention_mask is not None and attention_mask.dim() == 4:
881
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
882
+ causal_mask = attention_mask
883
+ else:
884
+ min_dtype = torch.finfo(dtype).min
885
+ causal_mask = torch.full(
886
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
887
+ )
888
+ if sequence_length != 1:
889
+ causal_mask = torch.triu(causal_mask, diagonal=1)
890
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
891
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
892
+ if attention_mask is not None:
893
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
894
+ mask_length = attention_mask.shape[-1]
895
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
896
+ causal_mask.device
897
+ )
898
+ padding_mask = padding_mask == 0
899
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
900
+ padding_mask, min_dtype
901
+ )
902
+
903
+ return causal_mask
904
+
905
+
906
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
907
+
908
+
909
+ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
910
+ _tied_weights_keys = ["lm_head.weight"]
911
+ _tp_plan = {"lm_head": "colwise_rep"}
912
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
913
+
914
+ def __init__(self, config):
915
+ super().__init__(config)
916
+ self.model = DeepseekV3Model(config)
917
+ self.vocab_size = config.vocab_size
918
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
919
+
920
+ # Initialize weights and apply final processing
921
+ self.post_init()
922
+
923
+ def get_input_embeddings(self):
924
+ return self.model.embed_tokens
925
+
926
+ def set_input_embeddings(self, value):
927
+ self.model.embed_tokens = value
928
+
929
+ def get_output_embeddings(self):
930
+ return self.lm_head
931
+
932
+ def set_output_embeddings(self, new_embeddings):
933
+ self.lm_head = new_embeddings
934
+
935
+ def set_decoder(self, decoder):
936
+ self.model = decoder
937
+
938
+ def get_decoder(self):
939
+ return self.model
940
+
941
+ @can_return_tuple
942
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
943
+ @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
944
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
945
+ def forward(
946
+ self,
947
+ input_ids: Optional[torch.LongTensor] = None,
948
+ attention_mask: Optional[torch.Tensor] = None,
949
+ position_ids: Optional[torch.LongTensor] = None,
950
+ past_key_values: Optional[Cache] = None,
951
+ inputs_embeds: Optional[torch.FloatTensor] = None,
952
+ labels: Optional[torch.LongTensor] = None,
953
+ use_cache: Optional[bool] = None,
954
+ output_attentions: Optional[bool] = None,
955
+ output_hidden_states: Optional[bool] = None,
956
+ cache_position: Optional[torch.LongTensor] = None,
957
+ logits_to_keep: Union[int, torch.Tensor] = 0,
958
+ **kwargs: Unpack[KwargsForCausalLM],
959
+ ) -> CausalLMOutputWithPast:
960
+ r"""
961
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
962
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
963
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
964
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
965
+
966
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
967
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
968
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
969
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
970
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
971
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
972
+
973
+ Returns:
974
+
975
+ Example:
976
+
977
+ ```python
978
+ >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
979
+
980
+ >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
981
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
982
+
983
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
984
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
985
+
986
+ >>> # Generate
987
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
988
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
989
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
990
+ ```"""
991
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
992
+ output_hidden_states = (
993
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
994
+ )
995
+
996
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
997
+ outputs: BaseModelOutputWithPast = self.model(
998
+ input_ids=input_ids,
999
+ attention_mask=attention_mask,
1000
+ position_ids=position_ids,
1001
+ past_key_values=past_key_values,
1002
+ inputs_embeds=inputs_embeds,
1003
+ use_cache=use_cache,
1004
+ output_attentions=output_attentions,
1005
+ output_hidden_states=output_hidden_states,
1006
+ cache_position=cache_position,
1007
+ **kwargs,
1008
+ )
1009
+
1010
+ hidden_states = outputs.last_hidden_state
1011
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1012
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1013
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1014
+
1015
+ loss = None
1016
+ if labels is not None:
1017
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1018
+
1019
+ return CausalLMOutputWithPast(
1020
+ loss=loss,
1021
+ logits=logits,
1022
+ past_key_values=outputs.past_key_values,
1023
+ hidden_states=outputs.hidden_states,
1024
+ attentions=outputs.attentions,
1025
+ )
1026
+
1027
+
1028
+ __all__ = ["DeepseekV3PreTrainedModel", "DeepseekV3Model", "DeepseekV3ForCausalLM"]
mtp-1-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e708907b5c5a584e0d81ebd2858ecf9f0f22798616a61fc273f0d39eac9512c0
3
+ size 687105960
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<|begin▁of▁sentence|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|end▁of▁sentence|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": true,
22
+ "model_max_length": 131072,
23
+ "pad_token": {
24
+ "__type": "AddedToken",
25
+ "content": "<|▁pad▁|>",
26
+ "lstrip": false,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ "sp_model_kwargs": {},
32
+ "unk_token": null,
33
+ "tokenizer_class": "LlamaTokenizerFast"
34
+ }