fangfangfang123 committed on
Commit
580d962
·
verified ·
1 Parent(s): 10dbb72

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. LICENSE +27 -0
  3. README.md +374 -0
  4. chat_template.jinja +103 -0
  5. config.json +49 -0
  6. configuration.json +1 -0
  7. configuration_deepseek.py +247 -0
  8. docs/deploy_guidance.md +29 -0
  9. figures/joyai-logo.png +3 -0
  10. model-1-of-40.safetensors +3 -0
  11. model-10-of-40.safetensors +3 -0
  12. model-11-of-40.safetensors +3 -0
  13. model-12-of-40.safetensors +3 -0
  14. model-13-of-40.safetensors +3 -0
  15. model-14-of-40.safetensors +3 -0
  16. model-15-of-40.safetensors +3 -0
  17. model-16-of-40.safetensors +3 -0
  18. model-17-of-40.safetensors +3 -0
  19. model-18-of-40.safetensors +3 -0
  20. model-19-of-40.safetensors +3 -0
  21. model-21-of-40.safetensors +3 -0
  22. model-22-of-40.safetensors +3 -0
  23. model-23-of-40.safetensors +3 -0
  24. model-24-of-40.safetensors +3 -0
  25. model-25-of-40.safetensors +3 -0
  26. model-27-of-40.safetensors +3 -0
  27. model-28-of-40.safetensors +3 -0
  28. model-29-of-40.safetensors +3 -0
  29. model-3-of-40.safetensors +3 -0
  30. model-30-of-40.safetensors +3 -0
  31. model-31-of-40.safetensors +3 -0
  32. model-32-of-40.safetensors +3 -0
  33. model-33-of-40.safetensors +3 -0
  34. model-34-of-40.safetensors +3 -0
  35. model-35-of-40.safetensors +3 -0
  36. model-36-of-40.safetensors +3 -0
  37. model-37-of-40.safetensors +3 -0
  38. model-38-of-40.safetensors +3 -0
  39. model-39-of-40.safetensors +3 -0
  40. model-4-of-40.safetensors +3 -0
  41. model-40-of-40.safetensors +3 -0
  42. model-5-of-40.safetensors +3 -0
  43. model-6-of-40.safetensors +3 -0
  44. model-8-of-40.safetensors +3 -0
  45. model-9-of-40.safetensors +3 -0
  46. model-non-layer.safetensors +3 -0
  47. model.safetensors.index.json +0 -0
  48. modeling_deepseek.py +1028 -0
  49. mtp-1-of-1.safetensors +3 -0
  50. tokenizer.json +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/joyai-logo.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,27 @@
1
+ Copyright (c) 2026 JD AI
2
+
3
+ We offer you a license similar to the MIT License. In the event that the Software (or any derivative works
4
+ thereof) is incorporated into 1) any of your commercial products or services; or 2) any of your products or
5
+ services that either have more than 100 million monthly active users or generate more than 20 million US dollars
6
+ (or equivalent in other currencies) in monthly revenue, you are required to conspicuously display "JoyAI-LLM Flash"
7
+ on the user interface of such product or service.
8
+
9
+ ================================================================================
10
+
11
+ Permission is hereby granted, free of charge, to any person obtaining a copy
12
+ of this software and associated documentation files (the “Software”), to deal
13
+ in the Software without restriction, including without limitation the rights
14
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15
+ copies of the Software, and to permit persons to whom the Software is
16
+ furnished to do so, subject to the following conditions:
17
+
18
+ The above copyright notice and this permission notice shall be included in all
19
+ copies or substantial portions of the Software.
20
+
21
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,374 @@
1
+ <div align="center">
2
+ <picture>
3
+ <img src="figures/joyai-logo.png" width="30%" alt="JoyAI-LLM Flash">
4
+ </picture>
5
+ </div>
6
+ <hr>
7
+
8
+ <div align="center" style="line-height: 1;">
9
+ <a href="https://huggingface.co/jdopensource" target="_blank"><img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-JD-ffc107?color=ffc107&logoColor=white"/></a>
10
+ <a href="LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Modified_MIT-f5de53?&color=f5de53"/></a>
11
+ </div>
12
+
13
+
14
+
15
+
16
+ ## 1. Model Introduction
17
+
18
+ JoyAI-LLM Flash is a state-of-the-art medium-sized instruct language model with 3 billion activated parameters and 48 billion total parameters. It was pretrained on 20 trillion text tokens using the Muon optimizer, followed by large-scale supervised fine-tuning (SFT), direct preference optimization (DPO), and reinforcement learning (RL) across diverse environments. It achieves strong performance on frontier knowledge, reasoning, coding, and agentic tasks.
19
+
20
+ ### Key Features
21
+
22
+ - Fiber Bundle RL: introduces geometric manifold theory into reinforcement learning through a new technique called FiberPO, designed to address the growing scale and heterogeneity of agents.
23
+ - Training-Inference Collaboration: applies the Muon optimizer with dense multi-token prediction (MTP) and new optimization techniques that resolve instabilities during scale-up, delivering 1.3× to 1.7× the throughput of the non-MTP version.
24
+ - Agentic Intelligence: designed for tool use, reasoning, and autonomous problem-solving.
25
+
26
+ ## 2. Model Summary
27
+
28
+ | | |
29
+ | :-----------------------------------------: | :----------------------: |
30
+ | **Architecture** | Mixture-of-Experts (MoE) |
31
+ | **Total Parameters** | 48B |
32
+ | **Activated Parameters** | 3B |
33
+ | **Number of Layers** (Dense layer included) | 40 |
34
+ | **Number of Dense Layers** | 1 |
35
+ | **Attention Hidden Dimension** | 2048 |
36
+ | **MoE Hidden Dimension** (per Expert) | 768 |
37
+ | **Number of Attention Heads** | 32 |
38
+ | **Number of Experts** | 256 |
39
+ | **Selected Experts per Token** | 8 |
40
+ | **Number of Shared Experts** | 1 |
41
+ | **Vocabulary Size** | 129K |
42
+ | **Context Length** | 128K |
43
+ | **Attention Mechanism** | MLA |
44
+ | **Activation Function** | SwiGLU |
45
+
46
+
47
+
48
+ ## 3. Evaluation Results
49
+
50
+ <table>
51
+ <thead>
52
+ <tr>
53
+ <th align="center">Benchmark</th>
54
+ <th align="center"><sup>JoyAI-LLM Flash</sup></th>
55
+ <th align="center"><sup>Qwen3-30B-A3B-Instruct-2507</sup></th>
56
+ <th align="center"><sup>GLM-4.7-Flash<br>(Non-thinking)</sup></th>
57
+ </tr>
58
+ </thead>
59
+ <tbody>
60
+
61
+
62
+ <tr>
63
+ <td align="center" colspan="4"><strong>Knowledge &amp; Alignment</strong></td>
64
+ </tr>
65
+ <tr>
66
+ <td align="center" style="vertical-align: middle">MMLU</td>
67
+ <td align="center" style="vertical-align: middle"><strong>89.50</strong></td>
68
+ <td align="center" style="vertical-align: middle">86.87</td>
69
+ <td align="center" style="vertical-align: middle">80.53</td>
70
+ </tr>
71
+ <tr>
72
+ <td align="center" style="vertical-align: middle">MMLU-Pro</td>
73
+ <td align="center" style="vertical-align: middle"><strong>81.02</strong></td>
74
+ <td align="center" style="vertical-align: middle">73.88</td>
75
+ <td align="center" style="vertical-align: middle">63.62</td>
76
+ </tr>
77
+ <tr>
78
+ <td align="center" style="vertical-align: middle">CMMLU</td>
79
+ <td align="center" style="vertical-align: middle"><strong>87.03</strong></td>
80
+ <td align="center" style="vertical-align: middle">85.88</td>
81
+ <td align="center" style="vertical-align: middle">75.85</td>
82
+ </tr>
83
+ <tr>
84
+ <td align="center" style="vertical-align: middle">GPQA-Diamond</td>
85
+ <td align="center" style="vertical-align: middle"><strong>74.43</strong></td>
86
+ <td align="center" style="vertical-align: middle">68.69</td>
87
+ <td align="center" style="vertical-align: middle">39.90</td>
88
+ </tr>
89
+ <tr>
90
+ <td align="center" style="vertical-align: middle">SuperGPQA</td>
91
+ <td align="center" style="vertical-align: middle"><strong>55.00</strong></td>
92
+ <td align="center" style="vertical-align: middle">52.00</td>
93
+ <td align="center" style="vertical-align: middle">32.00</td>
94
+ </tr>
95
+ <tr>
96
+ <td align="center" style="vertical-align: middle">LiveBench</td>
97
+ <td align="center" style="vertical-align: middle"><strong>72.90</strong></td>
98
+ <td align="center" style="vertical-align: middle">59.70</td>
99
+ <td align="center" style="vertical-align: middle">43.10</td>
100
+ </tr>
101
+ <tr>
102
+ <td align="center" style="vertical-align: middle">IFEval</td>
103
+ <td align="center" style="vertical-align: middle"><strong>86.69</strong></td>
104
+ <td align="center" style="vertical-align: middle">83.18</td>
105
+ <td align="center" style="vertical-align: middle">82.44</td>
106
+ </tr>
107
+ <tr>
108
+ <td align="center" style="vertical-align: middle">AlignBench</td>
109
+ <td align="center" style="vertical-align: middle"><strong>8.24</strong></td>
110
+ <td align="center" style="vertical-align: middle">8.07</td>
111
+ <td align="center" style="vertical-align: middle">6.85</td>
112
+ </tr>
113
+ <tr>
114
+ <td align="center" style="vertical-align: middle">HellaSwag</td>
115
+ <td align="center" style="vertical-align: middle"><strong>91.79</strong></td>
116
+ <td align="center" style="vertical-align: middle">89.90</td>
117
+ <td align="center" style="vertical-align: middle">60.84</td>
118
+ </tr>
119
+
120
+ <tr>
121
+ <td align="center" colspan="4"><strong>Coding</strong></td>
122
+ </tr>
123
+ <tr>
124
+ <td align="center" style="vertical-align: middle">HumanEval</td>
125
+ <td align="center" style="vertical-align: middle"><strong>96.34</strong></td>
126
+ <td align="center" style="vertical-align: middle">95.12</td>
127
+ <td align="center" style="vertical-align: middle">74.39</td>
128
+ </tr>
129
+ <tr>
130
+ <td align="center" style="vertical-align: middle">LiveCodeBench</td>
131
+ <td align="center" style="vertical-align: middle"><strong>65.60</strong></td>
132
+ <td align="center" style="vertical-align: middle">39.71</td>
133
+ <td align="center" style="vertical-align: middle">27.43</td>
134
+ </tr>
135
+ <tr>
136
+ <td align="center" style="vertical-align: middle">SciCode</td>
137
+ <td align="center" style="vertical-align: middle"><strong>3.08/22.92</strong></td>
138
+ <td align="center" style="vertical-align: middle"><strong>3.08/22.92</strong></td>
139
+ <td align="center" style="vertical-align: middle">3.08/15.11</td>
140
+ </tr>
141
+ <tr>
142
+ <td align="center" colspan="4"><strong>Mathematics</strong></td>
143
+ </tr>
144
+ <tr>
145
+ <td align="center" style="vertical-align: middle">GSM8K</td>
146
+ <td align="center" style="vertical-align: middle"><strong>95.83</strong></td>
147
+ <td align="center" style="vertical-align: middle">79.83</td>
148
+ <td align="center" style="vertical-align: middle">81.88</td>
149
+ </tr>
150
+ <tr>
151
+ <td align="center" style="vertical-align: middle">AIME2025</td>
152
+ <td align="center" style="vertical-align: middle"><strong>65.83</strong></td>
153
+ <td align="center" style="vertical-align: middle">62.08</td>
154
+ <td align="center" style="vertical-align: middle">24.17</td>
155
+ </tr>
156
+ <tr>
157
+ <td align="center" style="vertical-align: middle">MATH 500</td>
158
+ <td align="center" style="vertical-align: middle"><strong>97.10</strong></td>
159
+ <td align="center" style="vertical-align: middle">89.80</td>
160
+ <td align="center" style="vertical-align: middle">90.90</td>
161
+ </tr>
162
+
163
+ <tr>
164
+ <td align="center" colspan="4"><strong>Agentic</strong></td>
165
+ </tr>
166
+ <tr>
167
+ <td align="center" style="vertical-align: middle">SWE-bench Verified</td>
168
+ <td align="center" style="vertical-align: middle"><strong>60.60</strong></td>
169
+ <td align="center" style="vertical-align: middle">24.44</td>
170
+ <td align="center" style="vertical-align: middle">51.60</td>
171
+ </tr>
172
+ <tr>
173
+ <td align="center" style="vertical-align: middle">Tau2-Retail</td>
174
+ <td align="center" style="vertical-align: middle"><strong>67.55</strong></td>
175
+ <td align="center" style="vertical-align: middle">53.51</td>
176
+ <td align="center" style="vertical-align: middle">62.28</td>
177
+ </tr>
178
+ <tr>
179
+ <td align="center" style="vertical-align: middle">Tau2-Airline</td>
180
+ <td align="center" style="vertical-align: middle"><strong>54.00</strong></td>
181
+ <td align="center" style="vertical-align: middle">32.00</td>
182
+ <td align="center" style="vertical-align: middle">52.00</td>
183
+ </tr>
184
+ <tr>
185
+ <td align="center" style="vertical-align: middle">Tau2-Telecom</td>
186
+ <td align="center" style="vertical-align: middle">79.83</td>
187
+ <td align="center" style="vertical-align: middle">4.39</td>
188
+ <td align="center" style="vertical-align: middle"><strong>88.60</strong></td>
189
+ </tr>
190
+
191
+ <tr>
192
+ <td align="center" colspan="4"><strong>Long Context</strong></td>
193
+ </tr>
194
+ <tr>
195
+ <td align="center" style="vertical-align: middle">RULER</td>
196
+ <td align="center" style="vertical-align: middle"><strong>95.60</strong></td>
197
+ <td align="center" style="vertical-align: middle">89.66</td>
198
+ <td align="center" style="vertical-align: middle">56.12</td>
199
+ </tr>
200
+ </tbody>
201
+ </table>
202
+
203
+
204
+ ## 4. Deployment
205
+
206
+ > [!Note]
207
+ > The JoyAI-LLM Flash API is available at https://docs.jdcloud.com/cn/jdaip/chat, with OpenAI- and Anthropic-compatible endpoints.
208
+ > We currently recommend running JoyAI-LLM Flash on the following inference engines:
209
+
210
+ * vLLM
211
+ * SGLang
212
+
213
+ The minimum version requirement for `transformers` is `4.57.1`.
214
+
215
+ Deployment examples can be found in the [Model Deployment Guide](docs/deploy_guidance.md).
216
+
217
+
218
+
219
+ ## 5. Model Usage
220
+
221
+ The usage demos below demonstrate how to call our official API.
222
+
223
+ For third-party APIs deployed with vLLM or SGLang, please note that:
224
+
225
+ > [!Note] Recommended sampling parameters: `temperature=0.6`, `top_p=1.0`
226
+
227
+ ### Chat Completion
228
+
229
+ This simple chat completion script shows how to call the JoyAI-LLM Flash API.
230
+
231
+ ```python
232
+ from openai import OpenAI
233
+
234
+ client = OpenAI(base_url="http://IP:PORT/v1", api_key="EMPTY")
235
+
236
+
237
+ def simple_chat(client: OpenAI):
238
+     messages = [
239
+         {
240
+             "role": "user",
241
+             "content": [
242
+                 {
243
+                     "type": "text",
244
+                     "text": "which one is bigger, 9.11 or 9.9? think carefully.",
245
+                 }
246
+             ],
247
+         },
248
+     ]
249
+     model_name = client.models.list().data[0].id
250
+     response = client.chat.completions.create(
251
+         model=model_name, messages=messages, stream=False, max_tokens=4096
252
+     )
253
+     print(f"response: {response.choices[0].message.content}")
254
+
255
+
256
+ if __name__ == "__main__":
257
+     simple_chat(client)
258
+ ```
259
+
260
+
261
+ ### Tool Call Completion
262
+
263
+ This simple tool call completion script shows how to call the JoyAI-LLM Flash API.
264
+
265
+ ```python
266
+ import json
267
+
268
+ from openai import OpenAI
269
+
270
+ client = OpenAI(base_url="http://IP:PORT/v1", api_key="EMPTY")
271
+
272
+
273
+ def my_calculator(expression: str) -> str:
274
+     return str(eval(expression))  # demo only: eval is unsafe on untrusted input
275
+
276
+
277
+ def rewrite(text: str) -> str:
278
+     return str(text)
279
+
280
+
281
+ def simple_tool_call(client: OpenAI):
282
+     messages = [
283
+         {
284
+             "role": "user",
285
+             "content": [
286
+                 {
287
+                     "type": "text",
288
+                     "text": "use my functions to compute the results for the equations: 6+1",
289
+                 },
290
+             ],
291
+         },
292
+     ]
293
+     tools = [
294
+         {
295
+             "type": "function",
296
+             "function": {
297
+                 "name": "my_calculator",
298
+                 "description": "A calculator that can evaluate a mathematical equation and compute its results.",
299
+                 "parameters": {
300
+                     "type": "object",
301
+                     "properties": {
302
+                         "expression": {
303
+                             "type": "string",
304
+                             "description": "The mathematical expression to evaluate.",
305
+                         },
306
+                     },
307
+                     "required": ["expression"],
308
+                 },
309
+             },
310
+         },
311
+         {
312
+             "type": "function",
313
+             "function": {
314
+                 "name": "rewrite",
315
+                 "description": "Rewrite a given text for improved clarity",
316
+                 "parameters": {
317
+                     "type": "object",
318
+                     "properties": {
319
+                         "text": {
320
+                             "type": "string",
321
+                             "description": "The input text to rewrite",
322
+                         }
323
+                     },
324
+                 },
325
+             },
326
+         },
327
+     ]
328
+     model_name = client.models.list().data[0].id
329
+     response = client.chat.completions.create(
330
+         model=model_name,
331
+         messages=messages,
332
+         temperature=1.0,
333
+         max_tokens=1024,
334
+         tools=tools,
335
+         tool_choice="auto",
336
+     )
337
+     tool_calls = response.choices[0].message.tool_calls
338
+
339
+     results = []
340
+     for tool_call in tool_calls:
341
+         function_name = tool_call.function.name
342
+         function_args = json.loads(tool_call.function.arguments)
343
+         # Dispatch by name so every tool call gets exactly one result.
344
+         tool_fn = {"my_calculator": my_calculator, "rewrite": rewrite}[function_name]
345
+         results.append(tool_fn(**function_args))
346
+     messages.append({"role": "assistant", "tool_calls": tool_calls})
347
+     for tool_call, result in zip(tool_calls, results):
348
+         messages.append(
349
+             {
350
+                 "role": "tool",
351
+                 "tool_call_id": tool_call.id,
352
+                 "name": tool_call.function.name,
353
+                 "content": result,
354
+             }
355
+         )
356
+     response = client.chat.completions.create(
357
+         model=model_name,
358
+         messages=messages,
359
+         temperature=1.0,
360
+         max_tokens=1024,
361
+     )
362
+     print(response.choices[0].message.content)
363
+
364
+
365
+ if __name__ == "__main__":
366
+     simple_tool_call(client)
367
+
368
+ ```
369
+
370
+ ---
371
+
372
+ ## 6. License
373
+
374
+ Both the code repository and the model weights are released under the [Modified MIT License](LICENSE).
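
Reviewer sketch (not part of the commit): the README's tool-call demo passes a model-generated string to `eval` inside `my_calculator`. If only basic arithmetic is needed, a stricter stand-in could look roughly like the following. The names `safe_calculator` and `_eval_node` are hypothetical, and the operator whitelist is an assumption about what the demo needs.

```python
import ast
import operator

# Whitelisted operators; any other syntax is rejected.
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def _eval_node(node: ast.AST):
    """Recursively evaluate a parsed arithmetic expression."""
    if (
        isinstance(node, ast.Constant)
        and isinstance(node.value, (int, float))
        and not isinstance(node.value, bool)
    ):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
        return _BIN_OPS[type(node.op)](_eval_node(node.left), _eval_node(node.right))
    if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
        return _UNARY_OPS[type(node.op)](_eval_node(node.operand))
    raise ValueError(f"unsupported expression element: {type(node).__name__}")


def safe_calculator(expression: str) -> str:
    """Hypothetical replacement for the demo's my_calculator (no eval)."""
    tree = ast.parse(expression, mode="eval")
    return str(_eval_node(tree.body))
```

Because names, calls, and attribute access are never evaluated, an input such as `__import__('os')` raises `ValueError` instead of executing.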
chat_template.jinja ADDED
@@ -0,0 +1,103 @@
1
+ {%- macro render_extra_keys(json_dict, handled_keys) -%}
2
+ {%- if json_dict is mapping -%}
3
+ {%- for json_key in json_dict if json_key not in handled_keys -%}
4
+ {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) -%}
5
+ {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' -}}
6
+ {%- else -%}
7
+ {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' -}}
8
+ {%- endif -%}
9
+ {%- endfor -%}
10
+ {%- endif -%}
11
+ {%- endmacro -%}
12
+
13
+ {%- if not add_generation_prompt is defined -%}{%- set add_generation_prompt = false -%}{%- endif -%}
14
+
15
+ {%- set ns = namespace(system_prompt='', is_first_sp=true, is_last_user=false) -%}
16
+ {%- set default_system = "You are JoyAI , a large language model trained by JD(京东)that can interact with a computer to solve tasks. Answer as concisely as possible." -%}
17
+ {%- set ns.system_prompt = default_system -%}
18
+
19
+ {%- for message in messages -%}
20
+ {%- if message['role'] == 'system' -%}
21
+ {%- if ns.is_first_sp -%}
22
+ {%- set ns.system_prompt = message['content'] -%}
23
+ {%- set ns.is_first_sp = false -%}
24
+ {%- else -%}
25
+ {%- set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] -%}
26
+ {%- endif -%}
27
+ {%- endif -%}
28
+ {%- endfor -%}
29
+
30
+ {{- bos_token -}}{{- ns.system_prompt -}}
31
+ {%- if tools is iterable and tools | length > 0 -%}
32
+ {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }}
33
+ {{- "<tools>" }}
34
+ {%- for tool in tools %}
35
+ {%- if tool.function is defined %}
36
+ {%- set tool = tool.function %}
37
+ {%- endif %}
38
+ {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
39
+ {%- if tool.description is defined %}
40
+ {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
41
+ {%- endif %}
42
+ {{- '\n<parameters>' }}
43
+ {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
44
+ {%- for param_name, param_fields in tool.parameters.properties|items %}
45
+ {{- '\n<parameter>' }}
46
+ {{- '\n<name>' ~ param_name ~ '</name>' }}
47
+ {%- if param_fields.type is defined %}
48
+ {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
49
+ {%- endif %}
50
+ {%- if param_fields.description is defined %}
51
+ {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
52
+ {%- endif %}
53
+ {%- set handled_keys = ['name', 'type', 'description'] %}
54
+ {{- render_extra_keys(param_fields, handled_keys) }}
55
+ {{- '\n</parameter>' }}
56
+ {%- endfor %}
57
+ {%- endif %}
58
+ {% set handled_keys = ['type', 'properties'] %}
59
+ {{- render_extra_keys(tool.parameters, handled_keys) }}
60
+ {{- '\n</parameters>' }}
61
+ {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
62
+ {{- render_extra_keys(tool, handled_keys) }}
63
+ {{- '\n</function>' }}
64
+ {%- endfor %}
65
+ {{- "\n</tools>" }}
66
+ {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
67
+ {%- endif %}
68
+ {%- for message in messages -%}
69
+ {%- if message['role'] == 'user' -%}
70
+ {%- set ns.is_last_user = true -%}
71
+ {{- '<|User|>' + message['content'] -}}
72
+ {%- elif message['role'] == 'assistant' -%}
73
+ {%- if ns.is_last_user -%}
74
+ {{ '<|Assistant|>' }}
75
+ {%- endif -%}
76
+ {%- set ns.is_last_user = false -%}
77
+ {%- set content = message.get('content') | default('', true) -%}
78
+ {{ '<|end_of_thought|>' + content }}
79
+ {%- if message['tool_calls'] is defined and message['tool_calls'] is not none -%}
80
+ {%- for tool in message['tool_calls'] -%}
81
+ {%- if tool.function is defined %}{% set tool = tool.function %}{% endif -%}
82
+ {{- '\n<tool_call>\n<function=' + tool.name + '>\n' -}}
83
+ {%- if tool.arguments is defined -%}
84
+ {%- if tool.arguments is string -%}{%- set args_data = tool.arguments | from_json -%}{%- else -%}{%- set args_data = tool.arguments -%}{%- endif -%}
85
+ {%- for args_name, args_value in args_data.items() -%}
86
+ {{- '<parameter=' + args_name + '>\n' -}}
87
+ {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string -%}
88
+ {{- args_value -}}{{- '\n</parameter>\n' -}}
89
+ {%- endfor -%}
90
+ {%- endif -%}
91
+ {{- '</function>\n</tool_call>' -}}
92
+ {%- endfor -%}
93
+ {%- endif -%}
94
+ {{ '<|end▁of▁sentence|>' }}
95
+ {%- elif message['role'] == 'tool' -%}
96
+ {%- set ns.is_last_user = true -%}
97
+ {{ '\n<tool_response>\n' + message['content'] + '\n</tool_response>' }}
98
+ {%- endif -%}
99
+ {%- endfor -%}
100
+
101
+ {%- if add_generation_prompt -%}
102
+ {{ '<|Assistant|>' }}{{ '<|end_of_thought|>' }}
103
+ {%- endif -%}
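
Reviewer sketch (not part of the commit): the message loop of the template above can be mirrored in plain Python to see what prompt string it produces. This is a simplified reading under stated assumptions: every `content` is a plain string, no `tools` are passed, and assistant `tool_calls` are not rendered. `render_prompt` is a hypothetical helper, `bos_token` and `default_system` are supplied by the caller, and the tag constants are copied from the template as displayed here.

```python
from typing import Dict, List

# Special tags copied from the template text shown above.
USER_TAG = "<|User|>"
ASSISTANT_TAG = "<|Assistant|>"
END_OF_THOUGHT = "<|end_of_thought|>"
END_OF_SENTENCE = "<|end▁of▁sentence|>"


def render_prompt(
    messages: List[Dict[str, str]],
    bos_token: str,
    default_system: str,
    add_generation_prompt: bool = False,
) -> str:
    """Hypothetical, simplified mirror of the template (no tools, no tool_calls)."""
    # The first system message replaces the default; later ones are appended.
    system_prompt = default_system
    is_first_sp = True
    for message in messages:
        if message["role"] == "system":
            if is_first_sp:
                system_prompt = message["content"]
                is_first_sp = False
            else:
                system_prompt += "\n\n" + message["content"]

    parts = [bos_token, system_prompt]
    is_last_user = False
    for message in messages:
        role = message["role"]
        if role == "user":
            is_last_user = True
            parts.append(USER_TAG + message["content"])
        elif role == "assistant":
            # The assistant tag is emitted only right after a user or tool turn.
            if is_last_user:
                parts.append(ASSISTANT_TAG)
            is_last_user = False
            parts.append(END_OF_THOUGHT + (message.get("content") or ""))
            parts.append(END_OF_SENTENCE)
        elif role == "tool":
            is_last_user = True
            parts.append("\n<tool_response>\n" + message["content"] + "\n</tool_response>")

    if add_generation_prompt:
        parts.append(ASSISTANT_TAG + END_OF_THOUGHT)
    return "".join(parts)
```

For a single user turn with `add_generation_prompt=True`, the result is the BOS token, the system prompt, the user tag and text, then the assistant tag immediately followed by the end-of-thought tag, so generation starts after the thinking marker.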
config.json ADDED
@@ -0,0 +1,49 @@
1
+ {
2
+ "architectures": [
3
+ "DeepseekV3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_deepseek.DeepseekV3Config",
9
+ "AutoModel": "modeling_deepseek.DeepseekV3Model",
10
+ "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
11
+ },
12
+ "bos_token_id": 0,
13
+ "eos_token_id": 1,
14
+ "ep_size": 1,
15
+ "first_k_dense_replace": 1,
16
+ "hidden_act": "silu",
17
+ "hidden_size": 2048,
18
+ "initializer_range": 0.02,
19
+ "intermediate_size": 7168,
20
+ "kv_lora_rank": 512,
21
+ "max_position_embeddings": 131072,
22
+ "model_type": "joyai_llm_flash",
23
+ "moe_intermediate_size": 768,
24
+ "moe_layer_freq": 1,
25
+ "n_group": 1,
26
+ "n_routed_experts": 256,
27
+ "n_shared_experts": 1,
28
+ "norm_topk_prob": true,
29
+ "num_attention_heads": 32,
30
+ "num_experts_per_tok": 8,
31
+ "num_hidden_layers": 40,
32
+ "num_key_value_heads": 32,
33
+ "num_nextn_predict_layers": 1,
34
+ "q_lora_rank": 1536,
35
+ "qk_nope_head_dim": 128,
36
+ "qk_rope_head_dim": 64,
37
+ "rms_norm_eps": 1e-06,
38
+ "rope_theta": 32000000,
39
+ "routed_scaling_factor": 2.5,
40
+ "scoring_func": "sigmoid",
41
+ "tie_word_embeddings": false,
42
+ "topk_group": 1,
43
+ "topk_method": "noaux_tc",
44
+ "torch_dtype": "bfloat16",
45
+ "transformers_version": "4.44.2",
46
+ "use_cache": true,
47
+ "v_head_dim": 128,
48
+ "vocab_size": 129280
49
+ }
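
Reviewer sketch (not part of the commit): the config values above can be cross-checked against the README's "48B total" figure with simple arithmetic. The sketch assumes each expert is a bias-free SwiGLU MLP with three projection matrices (gate, up, down) and that the shared expert has the same size as a routed expert; neither is confirmed by the lines shown in this view. `estimate_expert_params` is a hypothetical helper.

```python
def estimate_expert_params(cfg: dict) -> dict:
    """Rough expert-only parameter count; attention, embeddings, the dense
    layer, and the MTP module are deliberately not counted."""
    # Assumption: gate, up, and down projections, no biases.
    per_expert = 3 * cfg["hidden_size"] * cfg["moe_intermediate_size"]
    moe_layers = cfg["num_hidden_layers"] - cfg["first_k_dense_replace"]
    routed_total = per_expert * cfg["n_routed_experts"] * moe_layers
    # Experts touched per token: the selected routed experts plus the shared ones.
    active_per_token = (
        per_expert * (cfg["num_experts_per_tok"] + cfg["n_shared_experts"]) * moe_layers
    )
    return {
        "per_expert": per_expert,
        "moe_layers": moe_layers,
        "routed_total": routed_total,
        "active_per_token": active_per_token,
    }


# Values copied from config.json in this commit.
config = {
    "hidden_size": 2048,
    "moe_intermediate_size": 768,
    "n_routed_experts": 256,
    "n_shared_experts": 1,
    "num_experts_per_tok": 8,
    "num_hidden_layers": 40,
    "first_k_dense_replace": 1,
}
estimate = estimate_expert_params(config)
print(estimate)
```

Under these assumptions the routed experts alone come to about 47.1B parameters across 39 MoE layers, which is in line with the stated 48B total, and the experts used per token come to about 1.66B, leaving the rest of the stated 3B activated parameters to components not counted here.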
configuration.json ADDED
@@ -0,0 +1 @@
1
+ {"framework":"Pytorch","task":"text-generation"}
configuration_deepseek.py ADDED
@@ -0,0 +1,247 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 bzantium and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on the DeepSeekV3 implementations from the DeepSeek AI team. (https://huggingface.co/deepseek-ai/DeepSeek-V3)
5
+
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """DeepSeekV3 model configuration"""
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.modeling_rope_utils import rope_config_validation
21
+
22
+
23
+ DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
24
+
25
+
26
+ class DeepseekV3Config(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek
29
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
30
+ defaults will yield a similar configuration to that of the DeepSeek-V3.
31
+ e.g. [bzantium/tiny-deepseek-v3](https://huggingface.co/bzantium/tiny-deepseek-v3)
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 129280):
38
+ Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
39
+ `input_ids` passed when calling [`DeepseekV3Model`]
40
+ hidden_size (`int`, *optional*, defaults to 7168):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 18432):
43
+ Dimension of the MLP representations.
44
+ moe_intermediate_size (`int`, *optional*, defaults to 2048):
45
+ Dimension of the MoE representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 61):
47
+ Number of hidden layers in the Transformer decoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 128):
49
+ Number of attention heads for each attention layer in the Transformer decoder.
50
+ num_key_value_heads (`int`, *optional*, defaults to 128):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details, check out [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ n_shared_experts (`int`, *optional*, defaults to 1):
59
+ Number of shared experts.
60
+ n_routed_experts (`int`, *optional*, defaults to 256):
61
+ Number of routed experts.
62
+ routed_scaling_factor (`float`, *optional*, defaults to 2.5):
63
+ Scaling factor for routed experts.
64
+ kv_lora_rank (`int`, *optional*, defaults to 512):
65
+ Rank of the LoRA matrices for key and value projections.
66
+ q_lora_rank (`int`, *optional*, defaults to 1536):
67
+ Rank of the LoRA matrices for query projections.
68
+ qk_rope_head_dim (`int`, *optional*, defaults to 64):
69
+ Dimension of the query/key heads that use rotary position embeddings.
70
+ v_head_dim (`int`, *optional*, defaults to 128):
71
+ Dimension of the value heads.
72
+ qk_nope_head_dim (`int`, *optional*, defaults to 128):
73
+ Dimension of the query/key heads that don't use rotary position embeddings.
74
+ n_group (`int`, *optional*, defaults to 8):
75
+ Number of groups for routed experts.
76
+ topk_group (`int`, *optional*, defaults to 4):
77
+ Number of groups selected for each token (the experts selected for a token are restricted to `topk_group` groups).
78
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
79
+ Number of selected experts, None means dense model.
80
+ first_k_dense_replace (`int`, *optional*, defaults to 3):
81
+ Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
82
+ \--k dense layers--/
83
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
84
+ Whether to normalize the weights of the routed experts.
85
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
86
+ The non-linear activation function (function or string) in the decoder.
87
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
88
+ The maximum sequence length that this model might ever be used with.
89
+ initializer_range (`float`, *optional*, defaults to 0.02):
90
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
91
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
92
+ The epsilon used by the rms normalization layers.
93
+ use_cache (`bool`, *optional*, defaults to `True`):
94
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
95
+ relevant if `config.is_decoder=True`.
96
+ pad_token_id (`int`, *optional*):
97
+ Padding token id.
98
+ bos_token_id (`int`, *optional*, defaults to 0):
99
+ Beginning of stream token id.
100
+ eos_token_id (`int`, *optional*, defaults to 1):
101
+ End of stream token id.
102
+ pretraining_tp (`int`, *optional*, defaults to 1):
103
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
104
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
105
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
106
+ issue](https://github.com/pytorch/pytorch/issues/76232).
107
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
108
+ Whether to tie the input and output word embeddings.
109
+ rope_theta (`float`, *optional*, defaults to 10000.0):
110
+ The base period of the RoPE embeddings.
111
+ rope_scaling (`Dict`, *optional*):
112
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
113
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
114
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
115
+ `max_position_embeddings` to the expected new maximum.
116
+ rope_interleave (`bool`, *optional*, defaults to `True`):
117
+ Whether to interleave the rotary position embeddings.
118
+ attention_bias (`bool`, *optional*, defaults to `False`):
119
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
120
+ attention_dropout (`float`, *optional*, defaults to 0.0):
121
+ The dropout ratio for the attention probabilities.
122
+
123
+ ```python
124
+ >>> from transformers import DeepseekV3Model, DeepseekV3Config
125
+
126
+ >>> # Initializing a Deepseek-V3 style configuration
127
+ >>> configuration = DeepseekV3Config()
128
+
129
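+ >>> # Initializing a model (with random weights) from the configuration, then accessing its configuration
+ >>> model = DeepseekV3Model(configuration)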
130
+ >>> configuration = model.config
131
+ ```"""
132
+
133
+ model_type = "deepseek_v3"
134
+ keys_to_ignore_at_inference = ["past_key_values"]
135
+ base_model_tp_plan = { # TODO: only replicate attention layers when > first_k_dense_replace
136
+ "layers.*.mlp.experts.*.gate_proj": "local_colwise",
137
+ "layers.*.mlp.experts.*.up_proj": "local_colwise",
138
+ "layers.*.mlp.experts.*.down_proj": "local_rowwise",
139
+ "layers.*.mlp.experts.*": "local", # each expert is wrapped in a module list
140
+ "layers.*.mlp.shared_experts.gate_proj": "local_colwise",
141
+ "layers.*.mlp.shared_experts.up_proj": "local_colwise",
142
+ "layers.*.mlp.shared_experts.down_proj": "local_rowwise",
143
+ "layers.*.mlp.shared_experts": "local",
144
+ "layers.*.mlp.gate_proj": "local_colwise",
145
+ "layers.*.mlp.up_proj": "local_colwise",
146
+ "layers.*.mlp.down_proj": "local_rowwise",
147
+ "layers.*.mlp": "gather", # This is the only moment where results are gathered
148
+ }
149
+ base_model_pp_plan = {
150
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
151
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
152
+ "norm": (["hidden_states"], ["hidden_states"]),
153
+ }
154
+
155
+ def __init__(
156
+ self,
157
+ vocab_size=129280,
158
+ hidden_size=7168,
159
+ intermediate_size=18432,
160
+ moe_intermediate_size=2048,
161
+ num_hidden_layers=61,
162
+ num_attention_heads=128,
163
+ num_key_value_heads=128,
164
+ n_shared_experts=1,
165
+ n_routed_experts=256,
166
+ routed_scaling_factor=2.5,
167
+ kv_lora_rank=512,
168
+ q_lora_rank=1536,
169
+ qk_rope_head_dim=64,
170
+ v_head_dim=128,
171
+ qk_nope_head_dim=128,
172
+ n_group=8,
173
+ topk_group=4,
174
+ num_experts_per_tok=8,
175
+ first_k_dense_replace=3,
176
+ norm_topk_prob=True,
177
+ hidden_act="silu",
178
+ max_position_embeddings=4096,
179
+ initializer_range=0.02,
180
+ rms_norm_eps=1e-6,
181
+ use_cache=True,
182
+ pad_token_id=None,
183
+ bos_token_id=0,
184
+ eos_token_id=1,
185
+ pretraining_tp=1,
186
+ tie_word_embeddings=False,
187
+ rope_theta=10000.0,
188
+ rope_scaling=None,
189
+ rope_interleave=True,
190
+ attention_bias=False,
191
+ attention_dropout=0.0,
192
+ **kwargs,
193
+ ):
194
+ self.vocab_size = vocab_size
195
+ self.max_position_embeddings = max_position_embeddings
196
+ self.hidden_size = hidden_size
197
+ self.intermediate_size = intermediate_size
198
+ self.moe_intermediate_size = moe_intermediate_size
199
+ self.num_hidden_layers = num_hidden_layers
200
+ self.num_attention_heads = num_attention_heads
201
+ self.n_shared_experts = n_shared_experts
202
+ self.n_routed_experts = n_routed_experts
203
+ self.routed_scaling_factor = routed_scaling_factor
204
+ self.kv_lora_rank = kv_lora_rank
205
+ self.q_lora_rank = q_lora_rank
206
+ self.qk_rope_head_dim = qk_rope_head_dim
207
+ self.v_head_dim = v_head_dim
208
+ self.qk_nope_head_dim = qk_nope_head_dim
209
+ self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
210
+ self.head_dim = qk_rope_head_dim
211
+ self.n_group = n_group
212
+ self.topk_group = topk_group
213
+ self.num_experts_per_tok = num_experts_per_tok
214
+ self.first_k_dense_replace = first_k_dense_replace
215
+ self.norm_topk_prob = norm_topk_prob
216
+ self.rope_interleave = rope_interleave
217
+
218
+ # for backward compatibility
219
+ if num_key_value_heads is None:
220
+ num_key_value_heads = num_attention_heads
221
+
222
+ self.num_key_value_heads = num_key_value_heads
223
+ self.hidden_act = hidden_act
224
+ self.initializer_range = initializer_range
225
+ self.rms_norm_eps = rms_norm_eps
226
+ self.pretraining_tp = pretraining_tp
227
+ self.use_cache = use_cache
228
+ self.rope_theta = rope_theta
229
+ self.rope_scaling = rope_scaling
230
+ self.attention_bias = attention_bias
231
+ self.attention_dropout = attention_dropout
232
+ # Validate the correctness of rotary position embeddings parameters
233
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
234
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
235
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
236
+ rope_config_validation(self)
237
+
238
+ super().__init__(
239
+ pad_token_id=pad_token_id,
240
+ bos_token_id=bos_token_id,
241
+ eos_token_id=eos_token_id,
242
+ tie_word_embeddings=tie_word_embeddings,
243
+ **kwargs,
244
+ )
245
+
246
+
247
+ __all__ = ["DeepseekV3Config"]
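The routing parameters documented in this configuration (`n_routed_experts`, `n_group`, `topk_group`, `num_experts_per_tok`) interact in a way that is easier to see on a small example. The following is a minimal pure-Python sketch of group-limited top-k selection for a single token, loosely following `get_topk_indices` in `modeling_deepseek.py`. The function name `group_limited_topk` and the toy scores are illustrative assumptions, not part of the repository. The sketch also omits the `e_score_correction_bias` term and drops experts outside the kept groups instead of zeroing their scores.

```python
from typing import List


def group_limited_topk(scores: List[float], n_group: int, topk_group: int, top_k: int) -> List[int]:
    """Return the sorted indices of the experts selected for one token.

    Hypothetical helper for illustration only; the real model works on
    batched tensors and adds a per-expert correction bias before selection.
    """
    n_experts = len(scores)
    assert n_experts % n_group == 0, "experts must split evenly into groups"
    group_size = n_experts // n_group

    # Score each group by the sum of its two best experts.
    group_scores = []
    for g in range(n_group):
        members = scores[g * group_size:(g + 1) * group_size]
        group_scores.append(sum(sorted(members, reverse=True)[:2]))

    # Keep only the `topk_group` highest-scoring groups.
    kept_groups = set(
        sorted(range(n_group), key=lambda g: group_scores[g], reverse=True)[:topk_group]
    )

    # Experts outside the kept groups are no longer candidates.
    candidates = [i for i in range(n_experts) if i // group_size in kept_groups]

    # Pick the `top_k` best remaining experts.
    best = sorted(candidates, key=lambda i: scores[i], reverse=True)[:top_k]
    return sorted(best)


# Toy setup: 8 experts in 4 groups of 2, keep 2 groups, pick 3 experts.
toy_scores = [0.9, 0.1, 0.2, 0.3, 0.8, 0.7, 0.05, 0.75]
selected = group_limited_topk(toy_scores, n_group=4, topk_group=2, top_k=3)
print(selected)
```

In this toy case, expert 7 has the third-highest score overall but is not selected, because its group is not among the two kept groups. With the defaults above (256 experts, 8 groups, 4 kept groups, 8 experts per token), the same restriction means each token only draws experts from 4 of the 8 groups.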
docs/deploy_guidance.md ADDED
@@ -0,0 +1,29 @@
1
+ # JoyAI-LLM Flash Deployment Guide
2
+
3
+ > [!Note]
4
+ > This guide offers a selection of deployment command examples for JoyAI-LLM Flash. They may not be the optimal configuration. Given the rapid evolution of inference engines, we recommend referring to their official documentation for the latest updates to ensure peak performance.
5
+
6
+ > Support for JoyAI-LLM Flash’s dense MTP architecture is currently being integrated into vLLM and SGLang. Until the corresponding PRs are merged into a stable release, please use the nightly Docker images to access these features.
7
+
8
+ ## vLLM Deployment
9
+
10
+ Here is an example of serving this model on a single H200 node with TP8 via vLLM:
11
+ ```bash
12
+ vllm serve ${MODEL_PATH} --tp 8 --trust-remote-code \
13
+ --tool-call-parser qwen3_coder --enable-auto-tool-choice \
14
+ --speculative-config $'{"method": "mtp", "num_speculative_tokens": 3}'
15
+ ```
16
+ **Key notes:**
17
+ - `--tool-call-parser qwen3_coder`: Required to enable tool calling.
18
+
19
+ ## SGLang Deployment
20
+
21
+ Similarly, here is an example of running with TP8 on a single H200 node via SGLang:
22
+ ```bash
23
+ python3 -m sglang.launch_server --model-path ${MODEL_PATH} --tp-size 8 --trust-remote-code \
24
+ --tool-call-parser qwen3_coder \
25
+ --speculative-algorithm EAGLE --speculative-draft-model-path ${MTP_MODEL_PATH} \
26
+ --speculative-num-steps 2 --speculative-eagle-topk 2 --speculative-num-draft-tokens 3
27
+ ```
28
+ **Key notes:**
29
+ - `--tool-call-parser qwen3_coder`: Required when enabling tool usage.
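Once either server is running, it can be queried through its OpenAI-compatible HTTP API. The sketch below builds a chat request with one tool attached and shows how it could be posted to `/v1/chat/completions` using only the Python standard library. Everything specific in it is an assumption for illustration: the base URL `http://localhost:8000` (vLLM's default port; SGLang defaults to a different one), the model name (by default the path passed to the server), and the hypothetical `get_weather` tool.

```python
import json
import urllib.request
from typing import Any, Dict, List, Optional


def build_chat_request(
    model: str,
    user_message: str,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Build the JSON body for an OpenAI-style chat completion request."""
    payload: Dict[str, Any] = {
        "model": model,
        "messages": [{"role": "user", "content": user_message}],
    }
    if tools:
        # Let the model decide whether to call a tool.
        payload["tools"] = tools
        payload["tool_choice"] = "auto"
    return payload


def send_chat_request(base_url: str, payload: Dict[str, Any], timeout: float = 60.0) -> Dict[str, Any]:
    """POST the payload to the server and return the decoded JSON response."""
    request = urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


# Hypothetical tool definition in the OpenAI function-calling schema.
WEATHER_TOOL: Dict[str, Any] = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# "/path/to/model" is a placeholder for the value used as ${MODEL_PATH}.
payload = build_chat_request("/path/to/model", "What is the weather in Beijing?", [WEATHER_TOOL])

# Requires a running server, so it is left commented out here:
# response = send_chat_request("http://localhost:8000", payload)
# print(response["choices"][0]["message"])
```

If the server was started with the tool-call parser enabled, a tool invocation should appear in the `tool_calls` field of the returned message rather than as plain text.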
figures/joyai-logo.png ADDED

Git LFS Details

  • SHA256: 4ea9d6a20a7707ca8dc427d6dcb5db6e2489f7730d5bffea26d8db20b1c54365
  • Pointer size: 131 Bytes
  • Size of remote file: 250 kB
model-1-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00342c45cd62e28fe183e2529ce61e2cecf0d8ea5451b2a8fe4137ae5e50e901
3
+ size 140785016
model-10-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25566dca71af5a8ad118ecff34fe03acda6d2483b9127cd26211be2082d5d6c8
3
+ size 2479205264
model-11-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:106ca52367ca7d80b9f99553e3ad01b820ec12f28be1d2fc57233f2ce33a5199
3
+ size 2479206048
model-12-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc1a396967e8b7cfafe9399c031a2e869319b4bf40750a58f91b81037a36f0f2
3
+ size 2479206048
model-13-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b813cd997285699eb10401e7b80c87773900b6c9ca48a305f228824dde553fe3
3
+ size 2479206048
model-14-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b78f08465829cee2cf4d9a06912e33306c76a1251ac0f6637a2a505bef3376f6
3
+ size 2479206048
model-15-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:415653c237c5189bcf725a587b7e7db5c15c858d9748c11fa9ee9a97c77ebf3e
3
+ size 2479206048
model-16-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edc2ec545b27884512149627d4c68c6e80a47429ab3f125e14f889aeb5c4555c
3
+ size 2479206048
model-17-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:849209aeddc1e424dfd58c9b7bc85bb7fc26ecd79e35d0c3e0c1246ce7ea6cd8
3
+ size 2479206048
model-18-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37a1c9d67e32905eee90a59028f843b3e2df75b7fdf310347467525ba8af0778
3
+ size 2479206048
model-19-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71940345d604e462c154b15eba80ce6013a9b47daaae628133c7a7b561ec0947
3
+ size 2479206048
model-21-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36db259d412d9278343b91f5c0844c8a3c7992ceee11ebfe562494fb3934a012
3
+ size 2479206048
model-22-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a63e252bf2374910cf4c8a26d1c2ffe0ea88348dae23fed3cbe99c82443213b
3
+ size 2479206048
model-23-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:490edf4128ec42e523730148c73c29445b8eea7a2d81d2e3296bc957b5d5dad7
3
+ size 2479206048
model-24-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c82fda4cc8b9570898f355b1ca8f640262e269710d0efd914d387825b82e90f4
3
+ size 2479206048
model-25-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:371774aa4bdfbefc9c62e4dd7a289e2cc15fe8ac8dd01f193bc3951c395149d1
3
+ size 2479206048
model-27-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0061eca323f71053806bdd9d78e098fc28e8442b3cd27581407f56b2507e3a80
3
+ size 2479206048
model-28-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e81ab12d6b7ab57a024daba11ccd75eaf4b3579b68aa7fc5e043a7c216df845a
3
+ size 2479206048
model-29-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fc269095eba7b9ffdbd51986cc5d805d73f32333b759a502f09ff55e797789a
3
+ size 2479206048
model-3-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:734577bcc726ead7e0a9e8cad31edddb3d4ee2cc34fa5acecd37daea8332fc58
3
+ size 2479205264
model-30-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:423416965094dc7d45422a5ca9d285cab7b12836bddd7a4c4f704ce996da3a00
3
+ size 2479206048
model-31-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e405c1eebce9eaa75db83d8db6a108bd8d28b5e1e76903790dd20ae13297999e
3
+ size 2479206048
model-32-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f1b25c9b64915ba066ab35a007928d48ed3ae75fabbcc45a269b3638268c999
3
+ size 2479206048
model-33-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f927f6e3246b676dc31fb4f227de92073ee54ea790042f7077d3d5ea5732d90f
3
+ size 2479206048
model-34-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3582ea2db95a41584f8ceb30209fa8415e7fd254dc43be695f9b67dd28931ef2
3
+ size 2479206048
model-35-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:feeffd5917defddd4543b8747bd33da9ee2cabfa3ff57634ebf1718e9ed46563
3
+ size 2479206048
model-36-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:526404ba27ab66b8dcc455e1ca2234fab4843a890d7f6513e37eeebbfe2aaaf7
3
+ size 2479206048
model-37-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d8a43b4d41be94f6f691c254f79d8994611e9229fedd23dfcd4fcc39e0853de
3
+ size 2479206048
model-38-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51467503269ccea044ec96c7f94aff0efa523d586a8736dda70414fca29c5a03
3
+ size 2479206048
model-39-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:deb76b58136b1144e79b0de587d991aed8c691fa4efcf1e54d118ec9b489c22d
3
+ size 2479206048
model-4-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb292fd660dfc54cf9afe89938d431bf58d8634576c8cc52bc874d4b68fadbd9
3
+ size 2479205264
model-40-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b09ac0b99a72696162f673ae20ad74b8ae5231fdf04f663d84c4ba981e83cf8
3
+ size 2479206048
model-5-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9c7bc3ad5832be99e39cdbfdc3f843d5c0d6edb9de0c0d8d34524d049b1d945
3
+ size 2479205264
model-6-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eccfb66a683dd836293360c8c8441c466209708fd0220b04f7511814d75aed7c
3
+ size 2479205264
model-8-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42c68787a8622f322c03190dda75801e30f65aa26997df147d9f7dec4fd264b4
3
+ size 2479205264
model-9-of-40.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1bf720fd443166d2e2a2f1a524035959febadfa079c7971c5a777642183008d6
3
+ size 2479205264
model-non-layer.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd760d732c11c23778a0dbf2280b62431d77d1f4ebc4f01f111cf716786981f0
3
+ size 1059066184
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_deepseek.py ADDED
@@ -0,0 +1,1028 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_deepseek_v3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ import math
8
+ from functools import partial
9
+ from typing import Callable, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
17
+ from transformers.generation import GenerationMixin
18
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
19
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
20
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
21
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
22
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
23
+ from transformers.processing_utils import Unpack
24
+ from transformers.utils import (
25
+ LossKwargs,
26
+ add_start_docstrings,
27
+ add_start_docstrings_to_model_forward,
28
+ can_return_tuple,
29
+ is_torch_flex_attn_available,
30
+ logging,
31
+ replace_return_docstrings,
32
+ )
33
+ from transformers.utils.deprecation import deprecate_kwarg
34
+ from .configuration_deepseek import DeepseekV3Config
35
+
36
+
37
+ if is_torch_flex_attn_available():
38
+ from torch.nn.attention.flex_attention import BlockMask
39
+
40
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+ _CONFIG_FOR_DOC = "DeepseekV3Config"
45
+
46
+
47
+ class DeepseekV3RMSNorm(nn.Module):
48
+ def __init__(self, hidden_size, eps=1e-6):
49
+ """
50
+ DeepseekV3RMSNorm is equivalent to T5LayerNorm
51
+ """
52
+ super().__init__()
53
+ self.weight = nn.Parameter(torch.ones(hidden_size))
54
+ self.variance_epsilon = eps
55
+
56
+ def forward(self, hidden_states):
57
+ input_dtype = hidden_states.dtype
58
+ hidden_states = hidden_states.to(torch.float32)
59
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
60
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
61
+ return self.weight * hidden_states.to(input_dtype)
62
+
63
+ def extra_repr(self):
64
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
65
+
66
+
67
+ class DeepseekV3RotaryEmbedding(nn.Module):
68
+ def __init__(self, config: DeepseekV3Config, device=None):
69
+ super().__init__()
70
+ # BC: "rope_type" was originally "type"
71
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
72
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
73
+ else:
74
+ self.rope_type = "default"
75
+ self.max_seq_len_cached = config.max_position_embeddings
76
+ self.original_max_seq_len = config.max_position_embeddings
77
+
78
+ self.config = config
79
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
80
+
81
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
82
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
83
+ self.original_inv_freq = self.inv_freq
84
+
85
+ @torch.no_grad()
86
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
87
+ def forward(self, x, position_ids):
88
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
89
+ position_ids_expanded = position_ids[:, None, :].float()
90
+
91
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
92
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
93
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
94
+ emb = torch.cat((freqs, freqs), dim=-1)
95
+ cos = emb.cos() * self.attention_scaling
96
+ sin = emb.sin() * self.attention_scaling
97
+
98
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
99
+
100
+
101
+ class DeepseekV3MLP(nn.Module):
102
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
103
+ super().__init__()
104
+ self.config = config
105
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
106
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
107
+
108
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
109
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
110
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
111
+ self.act_fn = ACT2FN[config.hidden_act]
112
+
113
+ def forward(self, x):
114
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
115
+ return down_proj
116
+
117
+
118
+ class DeepseekV3TopkRouter(nn.Module):
119
+ def __init__(self, config):
120
+ super().__init__()
121
+ self.config = config
122
+ self.top_k = config.num_experts_per_tok
123
+ self.n_routed_experts = config.n_routed_experts
124
+ self.routed_scaling_factor = config.routed_scaling_factor
125
+ self.n_group = config.n_group
126
+ self.topk_group = config.topk_group
127
+ self.norm_topk_prob = config.norm_topk_prob
128
+
129
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
130
+ self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts)))
131
+
132
+ @torch.no_grad()
133
+ def get_topk_indices(self, scores):
134
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
135
+ group_scores = (
136
+ scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
137
+ .topk(2, dim=-1)[0]
138
+ .sum(dim=-1)
139
+ )
140
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
141
+ group_mask = torch.zeros_like(group_scores)
142
+ group_mask.scatter_(1, group_idx, 1)
143
+ score_mask = (
144
+ group_mask.unsqueeze(-1)
145
+ .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
146
+ .reshape(-1, self.n_routed_experts)
147
+ )
148
+ scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
149
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
150
+ return topk_indices
151
+
152
+ def forward(self, hidden_states):
153
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
154
+ router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
155
+ scores = router_logits.sigmoid()
156
+ topk_indices = self.get_topk_indices(scores)
157
+ topk_weights = scores.gather(1, topk_indices)
158
+ if self.norm_topk_prob:
159
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
160
+ topk_weights /= denominator
161
+ topk_weights = topk_weights * self.routed_scaling_factor
162
+ return topk_indices, topk_weights
163
+
164
+
165
+ class DeepseekV3MoE(nn.Module):
166
+ """
167
+ A mixture-of-experts module containing shared experts.
168
+ """
169
+
170
+ def __init__(self, config):
171
+ super().__init__()
172
+ self.config = config
173
+ self.experts = nn.ModuleList(
174
+ [
175
+ DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
176
+ for _ in range(config.n_routed_experts)
177
+ ]
178
+ )
179
+ self.gate = DeepseekV3TopkRouter(config)
180
+ self.shared_experts = DeepseekV3MLP(
181
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
182
+ )
183
+
184
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
185
+ r"""
186
+ CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
187
+ to not have to do a loop here (deepseek has 256 experts soooo yeah).
188
+ """
189
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
190
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
191
+ expert_mask = expert_mask.permute(2, 0, 1)
192
+
193
+ for expert_idx in range(len(self.experts)):
194
+ expert = self.experts[expert_idx]
195
+ mask = expert_mask[expert_idx]
196
+ token_indices, weight_indices = torch.where(mask)
197
+
198
+ if token_indices.numel() > 0:
199
+ expert_weights = topk_weights[token_indices, weight_indices]
200
+ expert_input = hidden_states[token_indices]
201
+ expert_output = expert(expert_input)
202
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
203
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
204
+
205
+ # in the original DeepSeek implementation, the outputs of the experts are gathered once we leave this module,
205
+ # thus the MoE module is itself an IsolatedParallel module
206
+ # and all experts are "local", meaning we shard but we don't gather
208
+ return final_hidden_states.type(hidden_states.dtype)
209
+
210
+ def forward(self, hidden_states):
211
+ residuals = hidden_states
212
+ orig_shape = hidden_states.shape
213
+ topk_indices, topk_weights = self.gate(hidden_states)
214
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
215
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
216
+ hidden_states = hidden_states + self.shared_experts(residuals)
217
+ return hidden_states
218
+
219
+
220
+ def rotate_half(x):
221
+ """Rotates half the hidden dims of the input."""
222
+ x1 = x[..., : x.shape[-1] // 2]
223
+ x2 = x[..., x.shape[-1] // 2 :]
224
+ return torch.cat((-x2, x1), dim=-1)
225
+
226
+
227
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
228
+ """Applies Rotary Position Embedding to the query and key tensors.
229
+
230
+ Args:
231
+ q (`torch.Tensor`): The query tensor.
232
+ k (`torch.Tensor`): The key tensor.
233
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
234
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
235
+ position_ids (`torch.Tensor`, *optional*):
236
+ Deprecated and unused.
237
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
238
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
239
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
240
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
241
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
242
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
243
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
244
+ Returns:
245
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
246
+ """
247
+ cos = cos.unsqueeze(unsqueeze_dim)
248
+ sin = sin.unsqueeze(unsqueeze_dim)
249
+ q_embed = (q * cos) + (rotate_half(q) * sin)
250
+ k_embed = (k * cos) + (rotate_half(k) * sin)
251
+ return q_embed, k_embed
252
+
253
+
254
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
255
+ """
256
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
257
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
258
+ """
259
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
260
+ if n_rep == 1:
261
+ return hidden_states
262
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
263
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
264
+
265
+
266
+ def eager_attention_forward(
267
+ module: nn.Module,
268
+ query: torch.Tensor,
269
+ key: torch.Tensor,
270
+ value: torch.Tensor,
271
+ attention_mask: Optional[torch.Tensor],
272
+ scaling: float,
273
+ dropout: float = 0.0,
274
+ **kwargs,
275
+ ):
276
+ key_states = repeat_kv(key, module.num_key_value_groups)
277
+ value_states = repeat_kv(value, module.num_key_value_groups)
278
+
279
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
280
+ if attention_mask is not None:
281
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
282
+ attn_weights = attn_weights + causal_mask
283
+
284
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
285
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
286
+ attn_output = torch.matmul(attn_weights, value_states)
287
+ attn_output = attn_output.transpose(1, 2).contiguous()
288
+
289
+ return attn_output, attn_weights
290
+
291
+
292
+ def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
293
+ r"""
294
+ TODO let's just use the original freqs_cis computation to not have the view
295
+ transpose + reshape! This is not optimized!
296
+ Applies Rotary Position Embedding to the query and key tensors.
297
+
298
+ Args:
299
+ q (`torch.Tensor`): The query tensor.
300
+ k (`torch.Tensor`): The key tensor.
301
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
302
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
303
+ position_ids (`torch.Tensor`):
304
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
305
+ used to pass offset position ids when working with a KV-cache.
306
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
307
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
308
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
309
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
310
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
311
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
312
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
313
+ Returns:
314
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
315
+ """
316
+ cos = cos.unsqueeze(unsqueeze_dim)
317
+ sin = sin.unsqueeze(unsqueeze_dim)
318
+
319
+ b, h, s, d = q.shape
320
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
321
+
322
+ b, h, s, d = k.shape
323
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
324
+
325
+ q_embed = (q * cos) + (rotate_half(q) * sin)
326
+ k_embed = (k * cos) + (rotate_half(k) * sin)
327
+ return q_embed, k_embed
328
+
329
+
330
+ def yarn_get_mscale(scale=1, mscale=1):
331
+ if scale <= 1:
332
+ return 1.0
333
+ return 0.1 * mscale * math.log(scale) + 1.0
334
+
335
+
336
+ class DeepseekV3Attention(nn.Module):
337
+ """Multi-head latent attention: low-rank query and key/value projections plus a separate rotary (RoPE) sub-head."""
338
+
339
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
340
+ super().__init__()
341
+ self.config = config
342
+ self.layer_idx = layer_idx
343
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
344
+ self.attention_dropout = config.attention_dropout
345
+ self.num_heads = config.num_attention_heads
346
+ self.rope_theta = config.rope_theta
347
+ self.q_lora_rank = config.q_lora_rank
348
+ self.qk_rope_head_dim = config.qk_rope_head_dim
349
+ self.kv_lora_rank = config.kv_lora_rank
350
+ self.v_head_dim = config.v_head_dim
351
+ self.qk_nope_head_dim = config.qk_nope_head_dim
352
+ self.qk_head_dim = config.qk_head_dim
353
+
354
+ self.is_causal = True
355
+ self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
356
+ self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
357
+ self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
358
+
359
+ self.kv_a_proj_with_mqa = nn.Linear(
360
+ config.hidden_size,
361
+ self.kv_lora_rank + self.qk_rope_head_dim,
362
+ bias=config.attention_bias,
363
+ )
364
+ self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
365
+ self.kv_b_proj = nn.Linear(
366
+ self.kv_lora_rank,
367
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
368
+ bias=False,
369
+ )
370
+
371
+ self.o_proj = nn.Linear(
372
+ self.num_heads * self.v_head_dim,
373
+ config.hidden_size,
374
+ bias=config.attention_bias,
375
+ )
376
+
377
+ self.scaling = self.qk_head_dim ** (-0.5)
378
+ if self.config.rope_scaling is not None:
379
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
380
+ scaling_factor = self.config.rope_scaling["factor"]
381
+ if mscale_all_dim:
382
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
383
+ self.scaling = self.scaling * mscale * mscale
384
+
385
+ def forward(
386
+ self,
387
+ hidden_states: torch.Tensor,
388
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
389
+ attention_mask: Optional[torch.Tensor],
390
+ past_key_value: Optional[Cache] = None,
391
+ cache_position: Optional[torch.LongTensor] = None,
392
+ **kwargs: Unpack[FlashAttentionKwargs],
393
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
394
+ batch_size, seq_length = hidden_states.shape[:-1]
395
+ query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
396
+ key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
397
+
398
+ q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
399
+ q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
400
+
401
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
402
+ k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
403
+
404
+ k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
405
+ k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
406
+
407
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
408
+
409
+ cos, sin = position_embeddings
410
+ if self.config.rope_interleave: # support using interleaved weights for efficiency
411
+ q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
412
+ else:
413
+ q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
414
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
415
+
416
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
417
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
418
+
419
+ if past_key_value is not None:
420
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
421
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
422
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
423
+
424
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
425
+ value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
426
+
427
+ attention_interface: Callable = eager_attention_forward
428
+ if self.config._attn_implementation != "eager":
429
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
430
+ logger.warning_once(
431
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
432
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
433
+ )
434
+ else:
435
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
436
+
437
+ attn_output, attn_weights = attention_interface(
438
+ self,
439
+ query_states,
440
+ key_states,
441
+ value_states,
442
+ attention_mask,
443
+ dropout=0.0 if not self.training else self.attention_dropout,
444
+ scaling=self.scaling,
445
+ **kwargs,
446
+ )
447
+
448
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
449
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
450
+
451
+ attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
452
+ attn_output = self.o_proj(attn_output)
453
+ return attn_output, attn_weights
454
+
455
+
456
+ class DeepseekV3DecoderLayer(nn.Module):
457
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
458
+ super().__init__()
459
+ self.hidden_size = config.hidden_size
460
+
461
+ self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
462
+
463
+ if layer_idx >= config.first_k_dense_replace:
464
+ self.mlp = DeepseekV3MoE(config)
465
+ else:
466
+ self.mlp = DeepseekV3MLP(config)
467
+
468
+ self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
469
+ self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
470
+
471
+ def forward(
472
+ self,
473
+ hidden_states: torch.Tensor,
474
+ attention_mask: Optional[torch.Tensor] = None,
475
+ position_ids: Optional[torch.LongTensor] = None,
476
+ past_key_value: Optional[Cache] = None,
477
+ output_attentions: Optional[bool] = False,
478
+ use_cache: Optional[bool] = False,
479
+ cache_position: Optional[torch.LongTensor] = None,
480
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
481
+ **kwargs: Unpack[FlashAttentionKwargs],
482
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
483
+ residual = hidden_states
484
+
485
+ hidden_states = self.input_layernorm(hidden_states)
486
+
487
+ # Self Attention
488
+ hidden_states, self_attn_weights = self.self_attn(
489
+ hidden_states=hidden_states,
490
+ attention_mask=attention_mask,
491
+ position_ids=position_ids,
492
+ past_key_value=past_key_value,
493
+ output_attentions=output_attentions,
494
+ use_cache=use_cache,
495
+ cache_position=cache_position,
496
+ position_embeddings=position_embeddings,
497
+ **kwargs,
498
+ )
499
+ hidden_states = residual + hidden_states
500
+
501
+ # Fully Connected
502
+ residual = hidden_states
503
+ hidden_states = self.post_attention_layernorm(hidden_states)
504
+ hidden_states = self.mlp(hidden_states)
505
+ hidden_states = residual + hidden_states
506
+
507
+ outputs = (hidden_states,)
508
+ if output_attentions:
509
+ outputs += (self_attn_weights,)
510
+
511
+ return outputs
512
+
513
+
514
+ DEEPSEEK_V3_START_DOCSTRING = r"""
515
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
516
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
517
+ etc.)
518
+
519
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
520
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
521
+ and behavior.
522
+
523
+ Parameters:
524
+ config ([`DeepseekV3Config`]):
525
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
526
+ load the weights associated with the model, only the configuration. Check out the
527
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
528
+ """
529
+
530
+
531
+ @add_start_docstrings(
532
+ "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
533
+ DEEPSEEK_V3_START_DOCSTRING,
534
+ )
535
+ class DeepseekV3PreTrainedModel(PreTrainedModel):
536
+ config_class = DeepseekV3Config
537
+ base_model_prefix = "model"
538
+ supports_gradient_checkpointing = True
539
+ _no_split_modules = ["DeepseekV3DecoderLayer"]
540
+ _skip_keys_device_placement = ["past_key_values"]
541
+ _supports_flash_attn_2 = True
542
+ _supports_sdpa = True
543
+ _supports_flex_attn = True
544
+ _supports_cache_class = True
545
+ _supports_quantized_cache = True
546
+ _supports_static_cache = True
547
+ _supports_attention_backend = True
548
+
549
+ def _init_weights(self, module):
550
+ std = self.config.initializer_range
551
+ if isinstance(module, nn.Linear):
552
+ module.weight.data.normal_(mean=0.0, std=std)
553
+ if module.bias is not None:
554
+ module.bias.data.zero_()
555
+ elif isinstance(module, nn.Embedding):
556
+ module.weight.data.normal_(mean=0.0, std=std)
557
+ if module.padding_idx is not None:
558
+ module.weight.data[module.padding_idx].zero_()
559
+ elif isinstance(module, DeepseekV3TopkRouter):
560
+ module.weight.data.normal_(mean=0.0, std=std)
561
+ elif isinstance(module, nn.Parameter):
562
+ module.data.normal_(mean=0.0, std=std)
563
+
564
+
565
+ DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
566
+ Args:
567
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
568
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
569
+ it.
570
+
571
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
572
+ [`PreTrainedTokenizer.__call__`] for details.
573
+
574
+ [What are input IDs?](../glossary#input-ids)
575
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
576
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
577
+
578
+ - 1 for tokens that are **not masked**,
579
+ - 0 for tokens that are **masked**.
580
+
581
+ [What are attention masks?](../glossary#attention-mask)
582
+
583
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
584
+ [`PreTrainedTokenizer.__call__`] for details.
585
+
586
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
587
+ `past_key_values`).
588
+
589
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
590
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
591
+ information on the default strategy.
592
+
593
+ - 1 indicates the head is **not masked**,
594
+ - 0 indicates the head is **masked**.
595
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
596
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
597
+ config.n_positions - 1]`.
598
+
599
+ [What are position IDs?](../glossary#position-ids)
600
+ past_key_values (`Cache`, *optional*):
601
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
602
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
603
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
604
+
605
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
606
+
607
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
608
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
609
+ of shape `(batch_size, sequence_length)`.
610
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
611
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
612
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
613
+ model's internal embedding lookup matrix.
614
+ use_cache (`bool`, *optional*):
615
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
616
+ `past_key_values`).
617
+ output_attentions (`bool`, *optional*):
618
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
619
+ tensors for more detail.
620
+ output_hidden_states (`bool`, *optional*):
621
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
622
+ more detail.
623
+ return_dict (`bool`, *optional*):
624
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
625
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
626
+ Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
627
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
628
+ the complete sequence length.
629
+ """
630
+
631
+
632
+ @add_start_docstrings(
633
+ "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
634
+ DEEPSEEK_V3_START_DOCSTRING,
635
+ )
636
+ class DeepseekV3Model(DeepseekV3PreTrainedModel):
637
+ """
638
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]
639
+
640
+ Args:
641
+ config: DeepseekV3Config
642
+ """
643
+
644
+ _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]
645
+
646
+ def __init__(self, config: DeepseekV3Config):
647
+ super().__init__(config)
648
+ self.padding_idx = config.pad_token_id
649
+ self.vocab_size = config.vocab_size
650
+
651
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
652
+ self.layers = nn.ModuleList(
653
+ [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
654
+ )
655
+ self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
656
+ self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
657
+ self.gradient_checkpointing = False
658
+
659
+ # Initialize weights and apply final processing
660
+ self.post_init()
661
+
662
+ def get_input_embeddings(self):
663
+ return self.embed_tokens
664
+
665
+ def set_input_embeddings(self, value):
666
+ self.embed_tokens = value
667
+
668
+ @can_return_tuple
669
+ @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
670
+ def forward(
671
+ self,
672
+ input_ids: Optional[torch.LongTensor] = None,
673
+ attention_mask: Optional[torch.Tensor] = None,
674
+ position_ids: Optional[torch.LongTensor] = None,
675
+ past_key_values: Optional[Cache] = None,
676
+ inputs_embeds: Optional[torch.FloatTensor] = None,
677
+ use_cache: Optional[bool] = None,
678
+ output_attentions: Optional[bool] = None,
679
+ output_hidden_states: Optional[bool] = None,
680
+ cache_position: Optional[torch.LongTensor] = None,
681
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
682
+ ) -> BaseModelOutputWithPast:
683
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
684
+ output_hidden_states = (
685
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
686
+ )
687
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
688
+
689
+ if (input_ids is None) ^ (inputs_embeds is not None):
690
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
691
+
692
+ if self.gradient_checkpointing and self.training and use_cache:
693
+ logger.warning_once(
694
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
695
+ )
696
+ use_cache = False
697
+
698
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
699
+ if not isinstance(past_key_values, (type(None), Cache)):
700
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
701
+
702
+ if inputs_embeds is None:
703
+ inputs_embeds = self.embed_tokens(input_ids)
704
+
705
+ if use_cache and past_key_values is None:
706
+ past_key_values = DynamicCache()
707
+
708
+ if cache_position is None:
709
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
710
+ cache_position = torch.arange(
711
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
712
+ )
713
+
714
+ if position_ids is None:
715
+ position_ids = cache_position.unsqueeze(0)
716
+
717
+ causal_mask = self._update_causal_mask(
718
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
719
+ )
720
+
721
+ hidden_states = inputs_embeds
722
+
723
+ # create position embeddings to be shared across the decoder layers
724
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
725
+
726
+ # decoder layers
727
+ all_hidden_states = () if output_hidden_states else None
728
+ all_self_attns = () if output_attentions else None
729
+
730
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
731
+ if output_hidden_states:
732
+ all_hidden_states += (hidden_states,)
733
+
734
+ if self.gradient_checkpointing and self.training:
735
+ layer_outputs = self._gradient_checkpointing_func(
736
+ partial(decoder_layer.__call__, **flash_attn_kwargs),
737
+ hidden_states,
738
+ causal_mask,
739
+ position_ids,
740
+ past_key_values,
741
+ output_attentions,
742
+ use_cache,
743
+ cache_position,
744
+ position_embeddings,
745
+ )
746
+ else:
747
+ layer_outputs = decoder_layer(
748
+ hidden_states,
749
+ attention_mask=causal_mask,
750
+ position_ids=position_ids,
751
+ past_key_value=past_key_values,
752
+ output_attentions=output_attentions,
753
+ use_cache=use_cache,
754
+ cache_position=cache_position,
755
+ position_embeddings=position_embeddings,
756
+ **flash_attn_kwargs,
757
+ )
758
+
759
+ hidden_states = layer_outputs[0]
760
+
761
+ if output_attentions:
762
+ all_self_attns += (layer_outputs[1],)
763
+
764
+ hidden_states = self.norm(hidden_states)
765
+
766
+ # add hidden states from the last decoder layer
767
+ if output_hidden_states:
768
+ all_hidden_states += (hidden_states,)
769
+
770
+ return BaseModelOutputWithPast(
771
+ last_hidden_state=hidden_states,
772
+ past_key_values=past_key_values if use_cache else None,
773
+ hidden_states=all_hidden_states,
774
+ attentions=all_self_attns,
775
+ )
776
+
777
+ def _update_causal_mask(
778
+ self,
779
+ attention_mask: torch.Tensor,
780
+ input_tensor: torch.Tensor,
781
+ cache_position: torch.Tensor,
782
+ past_key_values: Cache,
783
+ output_attentions: bool = False,
784
+ ):
785
+ if self.config._attn_implementation == "flash_attention_2":
786
+ if attention_mask is not None and (attention_mask == 0.0).any():
787
+ return attention_mask
788
+ return None
789
+ if self.config._attn_implementation == "flex_attention":
790
+ if isinstance(attention_mask, torch.Tensor):
791
+ attention_mask = make_flex_block_causal_mask(attention_mask)
792
+ if isinstance(attention_mask, BlockMask):
793
+ return attention_mask
794
+
795
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
796
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
797
+ # to infer the attention mask.
798
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
799
+ using_static_cache = isinstance(past_key_values, StaticCache)
800
+
801
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
802
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
803
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
804
+ attention_mask,
805
+ inputs_embeds=input_tensor,
806
+ past_key_values_length=past_seen_tokens,
807
+ is_training=self.training,
808
+ ):
809
+ return None
810
+
811
+ dtype, device = input_tensor.dtype, input_tensor.device
812
+ sequence_length = input_tensor.shape[1]
813
+ if using_static_cache:
814
+ target_length = past_key_values.get_max_cache_shape()
815
+ else:
816
+ target_length = (
817
+ attention_mask.shape[-1]
818
+ if isinstance(attention_mask, torch.Tensor)
819
+ else past_seen_tokens + sequence_length + 1
820
+ )
821
+
822
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
823
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
824
+ attention_mask,
825
+ sequence_length=sequence_length,
826
+ target_length=target_length,
827
+ dtype=dtype,
828
+ device=device,
829
+ cache_position=cache_position,
830
+ batch_size=input_tensor.shape[0],
831
+ )
832
+
833
+ if (
834
+ self.config._attn_implementation == "sdpa"
835
+ and attention_mask is not None
836
+ and attention_mask.device.type in ["cuda", "xpu"]
837
+ and not output_attentions
838
+ ):
839
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
840
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
841
+ # Details: https://github.com/pytorch/pytorch/issues/110213
842
+ min_dtype = torch.finfo(dtype).min
843
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
844
+
845
+ return causal_mask
846
+
847
+ @staticmethod
848
+ def _prepare_4d_causal_attention_mask_with_cache_position(
849
+ attention_mask: torch.Tensor,
850
+ sequence_length: int,
851
+ target_length: int,
852
+ dtype: torch.dtype,
853
+ device: torch.device,
854
+ cache_position: torch.Tensor,
855
+ batch_size: int,
856
+ **kwargs,
857
+ ):
858
+ """
859
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
860
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
861
+
862
+ Args:
863
+ attention_mask (`torch.Tensor`):
864
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
865
+ `(batch_size, 1, query_length, key_value_length)`.
866
+ sequence_length (`int`):
867
+ The sequence length being processed.
868
+ target_length (`int`):
869
+ The target length: when generating with static cache, the mask should be as long as the static cache,
870
+ to account for the 0 padding, the part of the cache that is not filled yet.
871
+ dtype (`torch.dtype`):
872
+ The dtype to use for the 4D attention mask.
873
+ device (`torch.device`):
874
+ The device to place the 4D attention mask on.
875
+ cache_position (`torch.Tensor`):
876
+ Indices depicting the position of the input sequence tokens in the sequence.
877
+ batch_size (`int`):
878
+ Batch size.
879
+ """
880
+ if attention_mask is not None and attention_mask.dim() == 4:
881
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
882
+ causal_mask = attention_mask
883
+ else:
884
+ min_dtype = torch.finfo(dtype).min
885
+ causal_mask = torch.full(
886
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
887
+ )
888
+ if sequence_length != 1:
889
+ causal_mask = torch.triu(causal_mask, diagonal=1)
890
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
891
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
892
+ if attention_mask is not None:
893
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
894
+ mask_length = attention_mask.shape[-1]
895
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
896
+ causal_mask.device
897
+ )
898
+ padding_mask = padding_mask == 0
899
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
900
+ padding_mask, min_dtype
901
+ )
902
+
903
+ return causal_mask
904
+
905
+
906
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
907
+
908
+
909
+ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
910
+ _tied_weights_keys = ["lm_head.weight"]
911
+ _tp_plan = {"lm_head": "colwise_rep"}
912
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
913
+
914
+ def __init__(self, config):
915
+ super().__init__(config)
916
+ self.model = DeepseekV3Model(config)
917
+ self.vocab_size = config.vocab_size
918
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
919
+
920
+ # Initialize weights and apply final processing
921
+ self.post_init()
922
+
923
+ def get_input_embeddings(self):
924
+ return self.model.embed_tokens
925
+
926
+ def set_input_embeddings(self, value):
927
+ self.model.embed_tokens = value
928
+
929
+ def get_output_embeddings(self):
930
+ return self.lm_head
931
+
932
+ def set_output_embeddings(self, new_embeddings):
933
+ self.lm_head = new_embeddings
934
+
935
+ def set_decoder(self, decoder):
936
+ self.model = decoder
937
+
938
+ def get_decoder(self):
939
+ return self.model
940
+
941
+ @can_return_tuple
942
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
943
+ @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
944
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
945
+ def forward(
946
+ self,
947
+ input_ids: Optional[torch.LongTensor] = None,
948
+ attention_mask: Optional[torch.Tensor] = None,
949
+ position_ids: Optional[torch.LongTensor] = None,
950
+ past_key_values: Optional[Cache] = None,
951
+ inputs_embeds: Optional[torch.FloatTensor] = None,
952
+ labels: Optional[torch.LongTensor] = None,
953
+ use_cache: Optional[bool] = None,
954
+ output_attentions: Optional[bool] = None,
955
+ output_hidden_states: Optional[bool] = None,
956
+ cache_position: Optional[torch.LongTensor] = None,
957
+ logits_to_keep: Union[int, torch.Tensor] = 0,
958
+ **kwargs: Unpack[KwargsForCausalLM],
959
+ ) -> CausalLMOutputWithPast:
960
+ r"""
961
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
962
+ Labels for computing the causal language modeling loss. Indices should either be in `[0, ...,
963
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
964
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
965
+
966
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
967
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
968
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
969
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
970
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
971
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
972
+
973
+ Returns:
974
+
975
+ Example:
976
+
977
+ ```python
978
+ >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
979
+
980
+ >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
981
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
982
+
983
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
984
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
985
+
986
+ >>> # Generate
987
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
988
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
989
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
990
+ ```"""
991
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
992
+ output_hidden_states = (
993
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
994
+ )
995
+
996
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
997
+ outputs: BaseModelOutputWithPast = self.model(
998
+ input_ids=input_ids,
999
+ attention_mask=attention_mask,
1000
+ position_ids=position_ids,
1001
+ past_key_values=past_key_values,
1002
+ inputs_embeds=inputs_embeds,
1003
+ use_cache=use_cache,
1004
+ output_attentions=output_attentions,
1005
+ output_hidden_states=output_hidden_states,
1006
+ cache_position=cache_position,
1007
+ **kwargs,
1008
+ )
1009
+
1010
+ hidden_states = outputs.last_hidden_state
1011
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1012
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1013
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1014
+
1015
+ loss = None
1016
+ if labels is not None:
1017
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1018
+
1019
+ return CausalLMOutputWithPast(
1020
+ loss=loss,
1021
+ logits=logits,
1022
+ past_key_values=outputs.past_key_values,
1023
+ hidden_states=outputs.hidden_states,
1024
+ attentions=outputs.attentions,
1025
+ )
1026
+
1027
+
1028
+ __all__ = ["DeepseekV3PreTrainedModel", "DeepseekV3Model", "DeepseekV3ForCausalLM"]
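Note on `modeling_deepseek.py`: two small pieces of the attention code above are easy to misread, so here is a minimal, dependency-free sketch of them. It assumes only the formulas visible in the diff. `attention_scaling` and `deinterleave` are hypothetical helper names introduced for illustration; they are not part of the file, and the numeric arguments are made up rather than taken from this repository's `config.json`. The first helper mirrors how `DeepseekV3Attention.__init__` multiplies the softmax scale by `mscale ** 2` when `rope_scaling` is set and `mscale_all_dim` is non-zero. The second reproduces, on a plain list, the reordering that `apply_rotary_pos_emb_interleave` performs on the last axis before calling `rotate_half`.

```python
import math


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Same formula as yarn_get_mscale in the diff: no correction when scale <= 1.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def attention_scaling(qk_head_dim: int, factor=None, mscale_all_dim: float = 0.0) -> float:
    # Hypothetical helper mirroring the end of DeepseekV3Attention.__init__:
    # start from 1/sqrt(qk_head_dim), then multiply by mscale**2 when a
    # rope-scaling factor and a non-zero mscale_all_dim are both given.
    scaling = qk_head_dim ** (-0.5)
    if factor is not None and mscale_all_dim:
        mscale = yarn_get_mscale(factor, mscale_all_dim)
        scaling = scaling * mscale * mscale
    return scaling


def deinterleave(last_axis: list) -> list:
    # List version of view(..., d // 2, 2).transpose(4, 3).reshape(..., d)
    # applied to one head vector: even-indexed entries first, then odd-indexed.
    return last_axis[0::2] + last_axis[1::2]


# Illustrative values only; they are not read from this repository's config.json.
print(attention_scaling(192))
print(attention_scaling(192, factor=40.0, mscale_all_dim=1.0))
print(deinterleave([0, 1, 2, 3, 4, 5]))  # [0, 2, 4, 1, 3, 5]
```

The reordering matters because `rotate_half` pairs element `i` with element `i + d/2`, whereas interleaved weights store each rotation pair in adjacent positions; moving the even-indexed entries to the front puts the pairs where `rotate_half` expects them.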
mtp-1-of-1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e708907b5c5a584e0d81ebd2858ecf9f0f22798616a61fc273f0d39eac9512c0
3
+ size 687105960
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff