JunHowie committed on
Commit
0a20608
·
verified ·
1 Parent(s): be59b05

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .mdl +0 -0
  3. .msc +0 -0
  4. .mv +1 -0
  5. LICENSE +27 -0
  6. README.md +743 -0
  7. THIRD_PARTY_NOTICES.md +43 -0
  8. chat_template.jinja +112 -0
  9. config.json +192 -0
  10. configuration.json +1 -0
  11. configuration_deepseek.py +214 -0
  12. configuration_kimi_k25.py +123 -0
  13. docs/deploy_guidance.md +82 -0
  14. figures/demo_video.mp4 +3 -0
  15. figures/kimi-logo.png +0 -0
  16. generation_config.json +7 -0
  17. kimi_k25_processor.py +165 -0
  18. kimi_k25_vision_processing.py +251 -0
  19. media_utils.py +368 -0
  20. model-00067-of-00160.safetensors +3 -0
  21. model-00068-of-00160.safetensors +3 -0
  22. model-00069-of-00160.safetensors +3 -0
  23. model-00071-of-00160.safetensors +3 -0
  24. model-00074-of-00160.safetensors +3 -0
  25. model-00108-of-00160.safetensors +3 -0
  26. model-00112-of-00160.safetensors +3 -0
  27. model-00113-of-00160.safetensors +3 -0
  28. model-00114-of-00160.safetensors +3 -0
  29. model-00115-of-00160.safetensors +3 -0
  30. model-00116-of-00160.safetensors +3 -0
  31. model-00117-of-00160.safetensors +3 -0
  32. model-00118-of-00160.safetensors +3 -0
  33. model-00119-of-00160.safetensors +3 -0
  34. model-00120-of-00160.safetensors +3 -0
  35. model-00121-of-00160.safetensors +3 -0
  36. model-00122-of-00160.safetensors +3 -0
  37. model-00123-of-00160.safetensors +3 -0
  38. model-00124-of-00160.safetensors +3 -0
  39. model-00125-of-00160.safetensors +3 -0
  40. model-00126-of-00160.safetensors +3 -0
  41. model-00127-of-00160.safetensors +3 -0
  42. model-00128-of-00160.safetensors +3 -0
  43. model-00160-of-00160.safetensors +3 -0
  44. model.safetensors.index.json +3 -0
  45. modeling_deepseek.py +1808 -0
  46. modeling_kimi_k25.py +1248 -0
  47. preprocessor_config.json +30 -0
  48. tiktoken.model +3 -0
  49. tokenization_kimi.py +351 -0
  50. tokenizer_config.json +216 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/demo_video.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ model.safetensors.index.json filter=lfs diff=lfs merge=lfs -text
.mdl ADDED
Binary file (46 Bytes). View file
 
.msc ADDED
Binary file (15.8 kB). View file
 
.mv ADDED
@@ -0,0 +1 @@
 
 
1
+ Revision:master,CreatedAt:1769742814
LICENSE ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Modified MIT License
2
+
3
+ Copyright (c) 2026 Moonshot AI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the “Software”), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
23
+ Our only modification part is that, if the Software (or any derivative works
24
+ thereof) is used for any of your commercial products or services that have
25
+ more than 100 million monthly active users, or more than 20 million US dollars
26
+ (or equivalent in other currencies) in monthly revenue, you shall prominently
27
+ display "Kimi K2.5" on the user interface of such product or service.
README.md ADDED
@@ -0,0 +1,743 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ license_name: modified-mit
4
+ pipeline_tag: image-text-to-text
5
+ tags:
6
+ - vLLM
7
+ - sglang
8
+ - Int4
9
+ base_model:
10
+ - moonshotai/Kimi-K2.5
11
+ base_model_relation: quantized
12
+ ---
13
+ # Kimi-K2.5-E304
14
+ Base model: [moonshotai/Kimi-K2.5](https://www.modelscope.cn/models/moonshotai/Kimi-K2.5)
15
+
16
+ This repo trims roughly 20% of the experts (384 → 304).
17
+
18
+ The model format and serving setup (vLLM/SGLang versions and launch commands) match the original release.
19
+
20
+ ### 【Logs】
21
+ ```
22
+ 2026-01-30
23
+ 1. Initial commit
24
+ ```
25
+
26
+ ### 【Model Files】
27
+ | File Size | Last Updated |
28
+ |-----------|--------------|
29
+ | `444 GiB` | `2026-01-30` |
30
+
31
+ ### 【Model Download】
32
+ ```python
33
+ from modelscope import snapshot_download
34
+ snapshot_download('tclf90/Kimi-K2.5-E304', cache_dir="your_local_path")
35
+ ```
36
+
37
+ ### 【Overview】
38
+ <div align="center">
39
+ <picture>
40
+ <img src="figures/kimi-logo.png" width="30%" alt="Kimi K2.5">
41
+ </picture>
42
+ </div>
43
+ <hr>
44
+ <div align="center" style="line-height:1">
45
+ <a href="https://www.kimi.com" target="_blank"><img alt="Chat" src="https://img.shields.io/badge/🤖%20Chat-Kimi%20K2.5-ff6b6b?color=1783ff&logoColor=white"/></a>
46
+ <a href="https://www.moonshot.ai" target="_blank"><img alt="Homepage" src="https://img.shields.io/badge/Homepage-Moonshot%20AI-white?logo=Kimi&logoColor=white"/></a>
47
+ </div>
48
+
49
+ <div align="center" style="line-height: 1;">
50
+ <a href="https://huggingface.co/moonshotai" target="_blank"><img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Moonshot%20AI-ffc107?color=ffc107&logoColor=white"/></a>
51
+ <a href="https://twitter.com/kimi_moonshot" target="_blank"><img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-Kimi.ai-white?logo=x&logoColor=white"/></a>
52
+ <a href="https://discord.gg/TYU2fdJykW" target="_blank"><img alt="Discord" src="https://img.shields.io/badge/Discord-Kimi.ai-white?logo=discord&logoColor=white"/></a>
53
+ </div>
54
+ <div align="center" style="line-height: 1;">
55
+ <a href="https://huggingface.co/moonshotai/Kimi-K2.5/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Modified_MIT-f5de53?&color=f5de53"/></a>
56
+ </div>
57
+ <p align="center">
58
+ <b>📰&nbsp;&nbsp;<a href="https://www.kimi.com/blog/kimi-k2-5.html">Tech Blog</a></b>
59
+ </p>
60
+
61
+ ## 1. Model Introduction
62
+
63
+ Kimi K2.5 is an open-source, native multimodal agentic model built through continual pretraining on approximately 15 trillion mixed visual and text tokens atop Kimi-K2-Base. It seamlessly integrates vision and language understanding with advanced agentic capabilities, instant and thinking modes, as well as conversational and agentic paradigms.
64
+
65
+ ### Key Features
66
+ - **Native Multimodality**: Pre-trained on vision–language tokens, K2.5 excels in visual knowledge, cross-modal reasoning, and agentic tool use grounded in visual inputs.
67
+ - **Coding with Vision**: K2.5 generates code from visual specifications (UI designs, video workflows) and autonomously orchestrates tools for visual data processing.
68
+ - **Agent Swarm**: K2.5 transitions from single-agent scaling to a self-directed, coordinated swarm-like execution scheme. It decomposes complex tasks into parallel sub-tasks executed by dynamically instantiated, domain-specific agents.
69
+
70
+ ## 2. Model Summary
71
+
72
+ <div align="center">
73
+
74
+
75
+ | | |
76
+ |:---:|:---:|
77
+ | **Architecture** | Mixture-of-Experts (MoE) |
78
+ | **Total Parameters** | 1T |
79
+ | **Activated Parameters** | 32B |
80
+ | **Number of Layers** (Dense layer included) | 61 |
81
+ | **Number of Dense Layers** | 1 |
82
+ | **Attention Hidden Dimension** | 7168 |
83
+ | **MoE Hidden Dimension** (per Expert) | 2048 |
84
+ | **Number of Attention Heads** | 64 |
85
+ | **Number of Experts** | 384 |
86
+ | **Selected Experts per Token** | 8 |
87
+ | **Number of Shared Experts** | 1 |
88
+ | **Vocabulary Size** | 160K |
89
+ | **Context Length** | 256K |
90
+ | **Attention Mechanism** | MLA |
91
+ | **Activation Function** | SwiGLU |
92
+ | **Vision Encoder** | MoonViT |
93
+ | **Parameters of Vision Encoder** | 400M |
94
+ </div>
95
+
96
+ ## 3. Evaluation Results
97
+
98
+
99
+
100
+ <div align="center">
101
+ <table>
102
+ <thead>
103
+ <tr>
104
+ <th align="center">Benchmark</th>
105
+ <th align="center"><sup>Kimi K2.5<br><sup>(Thinking)</sup></sup></th>
106
+ <th align="center"><sup>GPT-5.2 <br><sup>(xhigh)</sup></sup></th>
107
+ <th align="center"><sup>Claude 4.5 Opus <br><sup>(Extended Thinking)</sup></sup></th>
108
+ <th align="center"><sup>Gemini 3 Pro <br><sup>(High Thinking Level)</sup></sup></th>
109
+ <th align="center"><sup>DeepSeek V3.2 <br><sup>(Thinking)</sup></sup></th>
110
+ <th align="center"><sup>Qwen3-VL-<br>235B-A22B-<br>Thinking</sup></th>
111
+ </tr>
112
+ </thead>
113
+ <tbody>
114
+ <tr>
115
+ <td align="center" colspan=8><strong>Reasoning &amp; Knowledge</strong></td>
116
+ </tr>
117
+ <tr>
118
+ <td align="center" style="vertical-align: middle">HLE-Full</td>
119
+ <td align="center" style="vertical-align: middle">30.1</td>
120
+ <td align="center" style="vertical-align: middle">34.5</td>
121
+ <td align="center" style="vertical-align: middle">30.8</td>
122
+ <td align="center" style="vertical-align: middle">37.5</td>
123
+ <td align="center" style="vertical-align: middle">25.1<sup>†</sup></td>
124
+ <td align="center" style="vertical-align: middle">-</td>
125
+ </tr>
126
+ <tr>
127
+ <td align="center" style="vertical-align: middle">HLE-Full<br>(w/ tools)</td>
128
+ <td align="center" style="vertical-align: middle">50.2</td>
129
+ <td align="center" style="vertical-align: middle">45.5</td>
130
+ <td align="center" style="vertical-align: middle">43.2</td>
131
+ <td align="center" style="vertical-align: middle">45.8</td>
132
+ <td align="center" style="vertical-align: middle">40.8<sup>†</sup></td>
133
+ <td align="center" style="vertical-align: middle">-</td>
134
+ </tr>
135
+ <tr>
136
+ <td align="center" style="vertical-align: middle">AIME 2025</td>
137
+ <td align="center" style="vertical-align: middle">96.1</td>
138
+ <td align="center" style="vertical-align: middle">100</td>
139
+ <td align="center" style="vertical-align: middle">92.8</td>
140
+ <td align="center" style="vertical-align: middle">95.0</td>
141
+ <td align="center" style="vertical-align: middle">93.1</td>
142
+ <td align="center" style="vertical-align: middle">-</td>
143
+ </tr>
144
+ <tr>
145
+ <td align="center" style="vertical-align: middle">HMMT 2025 (Feb)</td>
146
+ <td align="center" style="vertical-align: middle">95.4</td>
147
+ <td align="center" style="vertical-align: middle">99.4</td>
148
+ <td align="center" style="vertical-align: middle">92.9*</td>
149
+ <td align="center" style="vertical-align: middle">97.3*</td>
150
+ <td align="center" style="vertical-align: middle">92.5</td>
151
+ <td align="center" style="vertical-align: middle">-</td>
152
+ </tr>
153
+ <tr>
154
+ <td align="center" style="vertical-align: middle">IMO-AnswerBench</td>
155
+ <td align="center" style="vertical-align: middle">81.8</td>
156
+ <td align="center" style="vertical-align: middle">86.3</td>
157
+ <td align="center" style="vertical-align: middle">78.5*</td>
158
+ <td align="center" style="vertical-align: middle">83.1*</td>
159
+ <td align="center" style="vertical-align: middle">78.3</td>
160
+ <td align="center" style="vertical-align: middle">-</td>
161
+ </tr>
162
+ <tr>
163
+ <td align="center" style="vertical-align: middle">GPQA-Diamond</td>
164
+ <td align="center" style="vertical-align: middle">87.6</td>
165
+ <td align="center" style="vertical-align: middle">92.4</td>
166
+ <td align="center" style="vertical-align: middle">87.0</td>
167
+ <td align="center" style="vertical-align: middle">91.9</td>
168
+ <td align="center" style="vertical-align: middle">82.4</td>
169
+ <td align="center" style="vertical-align: middle">-</td>
170
+ </tr>
171
+ <tr>
172
+ <td align="center" style="vertical-align: middle">MMLU-Pro</td>
173
+ <td align="center" style="vertical-align: middle">87.1</td>
174
+ <td align="center" style="vertical-align: middle">86.7*</td>
175
+ <td align="center" style="vertical-align: middle">89.3*</td>
176
+ <td align="center" style="vertical-align: middle">90.1</td>
177
+ <td align="center" style="vertical-align: middle">85.0</td>
178
+ <td align="center" style="vertical-align: middle">-</td>
179
+ </tr>
180
+ <tr>
181
+ <td align="center" colspan=8><strong>Image &amp; Video</strong></td>
182
+ </tr>
183
+ <tr>
184
+ <td align="center" style="vertical-align: middle">MMMU-Pro</td>
185
+ <td align="center" style="vertical-align: middle">78.5</td>
186
+ <td align="center" style="vertical-align: middle">79.5*</td>
187
+ <td align="center" style="vertical-align: middle">74.0</td>
188
+ <td align="center" style="vertical-align: middle">81.0</td>
189
+ <td align="center" style="vertical-align: middle">-</td>
190
+ <td align="center" style="vertical-align: middle">69.3</td>
191
+ </tr>
192
+ <tr>
193
+ <td align="center" style="vertical-align: middle">CharXiv (RQ)</td>
194
+ <td align="center" style="vertical-align: middle">77.5</td>
195
+ <td align="center" style="vertical-align: middle">82.1</td>
196
+ <td align="center" style="vertical-align: middle">67.2*</td>
197
+ <td align="center" style="vertical-align: middle">81.4</td>
198
+ <td align="center" style="vertical-align: middle">-</td>
199
+ <td align="center" style="vertical-align: middle">66.1</td>
200
+ </tr>
201
+ <tr>
202
+ <td align="center" style="vertical-align: middle">MathVision</td>
203
+ <td align="center" style="vertical-align: middle">84.2</td>
204
+ <td align="center" style="vertical-align: middle">83.0</td>
205
+ <td align="center" style="vertical-align: middle">77.1*</td>
206
+ <td align="center" style="vertical-align: middle">86.1*</td>
207
+ <td align="center" style="vertical-align: middle">-</td>
208
+ <td align="center" style="vertical-align: middle">74.6</td>
209
+ </tr>
210
+ <tr>
211
+ <td align="center" style="vertical-align: middle">MathVista (mini)</td>
212
+ <td align="center" style="vertical-align: middle">90.1</td>
213
+ <td align="center" style="vertical-align: middle">82.8*</td>
214
+ <td align="center" style="vertical-align: middle">80.2*</td>
215
+ <td align="center" style="vertical-align: middle">89.8*</td>
216
+ <td align="center" style="vertical-align: middle">-</td>
217
+ <td align="center" style="vertical-align: middle">85.8</td>
218
+ </tr>
219
+ <tr>
220
+ <td align="center" style="vertical-align: middle">ZeroBench</td>
221
+ <td align="center" style="vertical-align: middle">9</td>
222
+ <td align="center" style="vertical-align: middle">9*</td>
223
+ <td align="center" style="vertical-align: middle">3*</td>
224
+ <td align="center" style="vertical-align: middle">8*</td>
225
+ <td align="center" style="vertical-align: middle">-</td>
226
+ <td align="center" style="vertical-align: middle">4*</td>
227
+ </tr>
228
+ <tr>
229
+ <td align="center" style="vertical-align: middle">ZeroBench<br>(w/ tools)</td>
230
+ <td align="center" style="vertical-align: middle">11</td>
231
+ <td align="center" style="vertical-align: middle">7*</td>
232
+ <td align="center" style="vertical-align: middle">9*</td>
233
+ <td align="center" style="vertical-align: middle">12*</td>
234
+ <td align="center" style="vertical-align: middle">-</td>
235
+ <td align="center" style="vertical-align: middle">3*</td>
236
+ </tr>
237
+ <tr>
238
+ <td align="center" style="vertical-align: middle">OCRBench</td>
239
+ <td align="center" style="vertical-align: middle">92.3</td>
240
+ <td align="center" style="vertical-align: middle">80.7*</td>
241
+ <td align="center" style="vertical-align: middle">86.5*</td>
242
+ <td align="center" style="vertical-align: middle">90.3*</td>
243
+ <td align="center" style="vertical-align: middle">-</td>
244
+ <td align="center" style="vertical-align: middle">87.5</td>
245
+ </tr>
246
+ <tr>
247
+ <td align="center" style="vertical-align: middle">OmniDocBench 1.5</td>
248
+ <td align="center" style="vertical-align: middle">88.8</td>
249
+ <td align="center" style="vertical-align: middle">85.7</td>
250
+ <td align="center" style="vertical-align: middle">87.7*</td>
251
+ <td align="center" style="vertical-align: middle">88.5</td>
252
+ <td align="center" style="vertical-align: middle">-</td>
253
+ <td align="center" style="vertical-align: middle">82.0*</td>
254
+ </tr>
255
+ <tr>
256
+ <td align="center" style="vertical-align: middle">InfoVQA (val)</td>
257
+ <td align="center" style="vertical-align: middle">92.6</td>
258
+ <td align="center" style="vertical-align: middle">84*</td>
259
+ <td align="center" style="vertical-align: middle">76.9*</td>
260
+ <td align="center" style="vertical-align: middle">57.2*</td>
261
+ <td align="center" style="vertical-align: middle">-</td>
262
+ <td align="center" style="vertical-align: middle">89.5</td>
263
+ </tr>
264
+ <tr>
265
+ <td align="center" style="vertical-align: middle">SimpleVQA</td>
266
+ <td align="center" style="vertical-align: middle">71.2</td>
267
+ <td align="center" style="vertical-align: middle">55.8*</td>
268
+ <td align="center" style="vertical-align: middle">69.7*</td>
269
+ <td align="center" style="vertical-align: middle">69.7*</td>
270
+ <td align="center" style="vertical-align: middle">-</td>
271
+ <td align="center" style="vertical-align: middle">56.8*</td>
272
+ </tr>
273
+ <tr>
274
+ <td align="center" style="vertical-align: middle"><a href="https://github.com/MoonshotAI/WorldVQA">WorldVQA</a></td>
275
+ <td align="center" style="vertical-align: middle">46.3</td>
276
+ <td align="center" style="vertical-align: middle">28.0</td>
277
+ <td align="center" style="vertical-align: middle">36.8</td>
278
+ <td align="center" style="vertical-align: middle">47.4</td>
279
+ <td align="center" style="vertical-align: middle">-</td>
280
+ <td align="center" style="vertical-align: middle">23.5</td>
281
+ </tr>
282
+ <tr>
283
+ <td align="center" style="vertical-align: middle">VideoMMMU</td>
284
+ <td align="center" style="vertical-align: middle">86.6</td>
285
+ <td align="center" style="vertical-align: middle">85.9</td>
286
+ <td align="center" style="vertical-align: middle">84.4*</td>
287
+ <td align="center" style="vertical-align: middle">87.6</td>
288
+ <td align="center" style="vertical-align: middle">-</td>
289
+ <td align="center" style="vertical-align: middle">80.0</td>
290
+ </tr>
291
+ <tr>
292
+ <td align="center" style="vertical-align: middle">MMVU</td>
293
+ <td align="center" style="vertical-align: middle">80.4</td>
294
+ <td align="center" style="vertical-align: middle">80.8*</td>
295
+ <td align="center" style="vertical-align: middle">77.3</td>
296
+ <td align="center" style="vertical-align: middle">77.5</td>
297
+ <td align="center" style="vertical-align: middle">-</td>
298
+ <td align="center" style="vertical-align: middle">71.1</td>
299
+ </tr>
300
+ <tr>
301
+ <td align="center" style="vertical-align: middle">MotionBench</td>
302
+ <td align="center" style="vertical-align: middle">70.4</td>
303
+ <td align="center" style="vertical-align: middle">64.8</td>
304
+ <td align="center" style="vertical-align: middle">60.3</td>
305
+ <td align="center" style="vertical-align: middle">70.3</td>
306
+ <td align="center" style="vertical-align: middle">-</td>
307
+ <td align="center" style="vertical-align: middle">-</td>
308
+ </tr>
309
+ <tr>
310
+ <td align="center" style="vertical-align: middle">VideoMME</td>
311
+ <td align="center" style="vertical-align: middle">87.4</td>
312
+ <td align="center" style="vertical-align: middle">86.0*</td>
313
+ <td align="center" style="vertical-align: middle">-</td>
314
+ <td align="center" style="vertical-align: middle">88.4*</td>
315
+ <td align="center" style="vertical-align: middle">-</td>
316
+ <td align="center" style="vertical-align: middle">79.0</td>
317
+ </tr>
318
+ <tr>
319
+ <td align="center" style="vertical-align: middle">LongVideoBench</td>
320
+ <td align="center" style="vertical-align: middle">79.8</td>
321
+ <td align="center" style="vertical-align: middle">76.5*</td>
322
+ <td align="center" style="vertical-align: middle">67.2*</td>
323
+ <td align="center" style="vertical-align: middle">77.7*</td>
324
+ <td align="center" style="vertical-align: middle">-</td>
325
+ <td align="center" style="vertical-align: middle">65.6*</td>
326
+ </tr>
327
+ <tr>
328
+ <td align="center" style="vertical-align: middle">LVBench</td>
329
+ <td align="center" style="vertical-align: middle">75.9</td>
330
+ <td align="center" style="vertical-align: middle">-</td>
331
+ <td align="center" style="vertical-align: middle">-</td>
332
+ <td align="center" style="vertical-align: middle">73.5*</td>
333
+ <td align="center" style="vertical-align: middle">-</td>
334
+ <td align="center" style="vertical-align: middle">63.6</td>
335
+ </tr>
336
+ <tr>
337
+ <td align="center" colspan=8><strong>Coding</strong></td>
338
+ </tr>
339
+ <tr>
340
+ <td align="center" style="vertical-align: middle">SWE-Bench Verified</td>
341
+ <td align="center" style="vertical-align: middle">76.8</td>
342
+ <td align="center" style="vertical-align: middle">80.0</td>
343
+ <td align="center" style="vertical-align: middle">80.9</td>
344
+ <td align="center" style="vertical-align: middle">76.2</td>
345
+ <td align="center" style="vertical-align: middle">73.1</td>
346
+ <td align="center" style="vertical-align: middle">-</td>
347
+ </tr>
348
+ <tr>
349
+ <td align="center" style="vertical-align: middle">SWE-Bench Pro</td>
350
+ <td align="center" style="vertical-align: middle">50.7</td>
351
+ <td align="center" style="vertical-align: middle">55.6</td>
352
+ <td align="center" style="vertical-align: middle">55.4*</td>
353
+ <td align="center" style="vertical-align: middle">-</td>
354
+ <td align="center" style="vertical-align: middle">-</td>
355
+ <td align="center" style="vertical-align: middle">-</td>
356
+ </tr>
357
+ <tr>
358
+ <td align="center" style="vertical-align: middle">SWE-Bench Multilingual</td>
359
+ <td align="center" style="vertical-align: middle">73.0</td>
360
+ <td align="center" style="vertical-align: middle">72.0</td>
361
+ <td align="center" style="vertical-align: middle">77.5</td>
362
+ <td align="center" style="vertical-align: middle">65.0</td>
363
+ <td align="center" style="vertical-align: middle">70.2</td>
364
+ <td align="center" style="vertical-align: middle">-</td>
365
+ </tr>
366
+ <tr>
367
+ <td align="center" style="vertical-align: middle">Terminal Bench 2.0</td>
368
+ <td align="center" style="vertical-align: middle">50.8</td>
369
+ <td align="center" style="vertical-align: middle">54.0</td>
370
+ <td align="center" style="vertical-align: middle">59.3</td>
371
+ <td align="center" style="vertical-align: middle">54.2</td>
372
+ <td align="center" style="vertical-align: middle">46.4</td>
373
+ <td align="center" style="vertical-align: middle">-</td>
374
+ </tr>
375
+ <tr>
376
+ <td align="center" style="vertical-align: middle">PaperBench</td>
377
+ <td align="center" style="vertical-align: middle">63.5</td>
378
+ <td align="center" style="vertical-align: middle">63.7*</td>
379
+ <td align="center" style="vertical-align: middle">72.9*</td>
380
+ <td align="center" style="vertical-align: middle">-</td>
381
+ <td align="center" style="vertical-align: middle">47.1</td>
382
+ <td align="center" style="vertical-align: middle">-</td>
383
+ </tr>
384
+ <tr>
385
+ <td align="center" style="vertical-align: middle">CyberGym</td>
386
+ <td align="center" style="vertical-align: middle">41.3</td>
387
+ <td align="center" style="vertical-align: middle">-</td>
388
+ <td align="center" style="vertical-align: middle">50.6</td>
389
+ <td align="center" style="vertical-align: middle">39.9*</td>
390
+ <td align="center" style="vertical-align: middle">17.3*</td>
391
+ <td align="center" style="vertical-align: middle">-</td>
392
+ </tr>
393
+ <tr>
394
+ <td align="center" style="vertical-align: middle">SciCode</td>
395
+ <td align="center" style="vertical-align: middle">48.7</td>
396
+ <td align="center" style="vertical-align: middle">52.1</td>
397
+ <td align="center" style="vertical-align: middle">49.5</td>
398
+ <td align="center" style="vertical-align: middle">56.1</td>
399
+ <td align="center" style="vertical-align: middle">38.9</td>
400
+ <td align="center" style="vertical-align: middle">-</td>
401
+ </tr>
402
+ <tr>
403
+ <td align="center" style="vertical-align: middle">OJBench (cpp)</td>
404
+ <td align="center" style="vertical-align: middle">57.4</td>
405
+ <td align="center" style="vertical-align: middle">-</td>
406
+ <td align="center" style="vertical-align: middle">54.6*</td>
407
+ <td align="center" style="vertical-align: middle">68.5*</td>
408
+ <td align="center" style="vertical-align: middle">54.7*</td>
409
+ <td align="center" style="vertical-align: middle">-</td>
410
+ </tr>
411
+ <tr>
412
+ <td align="center" style="vertical-align: middle">LiveCodeBench (v6)</td>
413
+ <td align="center" style="vertical-align: middle">85.0</td>
414
+ <td align="center" style="vertical-align: middle">-</td>
415
+ <td align="center" style="vertical-align: middle">82.2*</td>
416
+ <td align="center" style="vertical-align: middle">87.4*</td>
417
+ <td align="center" style="vertical-align: middle">83.3</td>
418
+ <td align="center" style="vertical-align: middle">-</td>
419
+ </tr>
420
+ <tr>
421
+ <td align="center" colspan=8><strong>Long Context</strong></td>
422
+ </tr>
423
+ <tr>
424
+ <td align="center" style="vertical-align: middle">Longbench v2</td>
425
+ <td align="center" style="vertical-align: middle">61.0</td>
426
+ <td align="center" style="vertical-align: middle">54.5*</td>
427
+ <td align="center" style="vertical-align: middle">64.4*</td>
428
+ <td align="center" style="vertical-align: middle">68.2*</td>
429
+ <td align="center" style="vertical-align: middle">59.8*</td>
430
+ <td align="center" style="vertical-align: middle">-</td>
431
+ </tr>
432
+ <tr>
433
+ <td align="center" style="vertical-align: middle">AA-LCR</td>
434
+ <td align="center" style="vertical-align: middle">70.0</td>
435
+ <td align="center" style="vertical-align: middle">72.3*</td>
436
+ <td align="center" style="vertical-align: middle">71.3*</td>
437
+ <td align="center" style="vertical-align: middle">65.3*</td>
438
+ <td align="center" style="vertical-align: middle">64.3*</td>
439
+ <td align="center" style="vertical-align: middle">-</td>
440
+ <tr>
441
+ <td align="center" colspan=8><strong>Agentic Search</strong></td>
442
+ </tr>
443
+ <tr>
444
+ <td align="center" style="vertical-align: middle">BrowseComp</td>
445
+ <td align="center" style="vertical-align: middle">60.6</td>
446
+ <td align="center" style="vertical-align: middle" rowspan="2">65.8</td>
447
+ <td align="center" style="vertical-align: middle">37.0</td>
448
+ <td align="center" style="vertical-align: middle">37.8</td>
449
+ <td align="center" style="vertical-align: middle">51.4</td>
450
+ <td align="center" style="vertical-align: middle">-</td>
451
+ </tr>
452
+ <tr>
453
+ <td align="center" style="vertical-align: middle">BrowseComp<br>(w/ctx manage)</td>
454
+ <td align="center" style="vertical-align: middle">74.9</td>
455
+ <td align="center" style="vertical-align: middle">57.8</td>
456
+ <td align="center" style="vertical-align: middle">59.2</td>
457
+ <td align="center" style="vertical-align: middle">67.6</td>
458
+ <td align="center" style="vertical-align: middle">-</td>
459
+ </tr>
460
+ <tr>
461
+ <td align="center" style="vertical-align: middle">BrowseComp<br>(Agent Swarm)</td>
462
+ <td align="center" style="vertical-align: middle">78.4</td>
463
+ <td align="center" style="vertical-align: middle">-</td>
464
+ <td align="center" style="vertical-align: middle">-</td>
465
+ <td align="center" style="vertical-align: middle">-</td>
466
+ <td align="center" style="vertical-align: middle">-</td>
467
+ <td align="center" style="vertical-align: middle">-</td>
468
+ </tr>
469
+ <tr>
470
+ <td align="center" style="vertical-align: middle">WideSearch<br> (iter-f1)</td>
471
+ <td align="center" style="vertical-align: middle">72.7</td>
472
+ <td align="center" style="vertical-align: middle">-</td>
473
+ <td align="center" style="vertical-align: middle">76.2*</td>
474
+ <td align="center" style="vertical-align: middle">57.0</td>
475
+ <td align="center" style="vertical-align: middle">32.5*</td>
476
+ <td align="center" style="vertical-align: middle">-</td>
477
+ </tr>
478
+ <tr>
479
+ <td align="center" style="vertical-align: middle">WideSearch<br> (iter-f1 Agent Swarm)</td>
480
+ <td align="center" style="vertical-align: middle">79.0</td>
481
+ <td align="center" style="vertical-align: middle">-</td>
482
+ <td align="center" style="vertical-align: middle">-</td>
483
+ <td align="center" style="vertical-align: middle">-</td>
484
+ <td align="center" style="vertical-align: middle">-</td>
485
+ <td align="center" style="vertical-align: middle">-</td>
486
+ </tr>
487
+ <tr>
488
+ <td align="center" style="vertical-align: middle">DeepSearchQA</td>
489
+ <td align="center" style="vertical-align: middle">77.1</td>
490
+ <td align="center" style="vertical-align: middle">71.3*</td>
491
+ <td align="center" style="vertical-align: middle">76.1*</td>
492
+ <td align="center" style="vertical-align: middle">63.2*</td>
493
+ <td align="center" style="vertical-align: middle">60.9*</td>
494
+ <td align="center" style="vertical-align: middle">-</td>
495
+ </tr>
496
+ <tr>
497
+ <td align="center" style="vertical-align: middle">FinSearchCompT2&T3</td>
498
+ <td align="center" style="vertical-align: middle">67.8</td>
499
+ <td align="center" style="vertical-align: middle">-</td>
500
+ <td align="center" style="vertical-align: middle">66.2*</td>
501
+ <td align="center" style="vertical-align: middle">49.9</td>
502
+ <td align="center" style="vertical-align: middle">59.1*</td>
503
+ <td align="center" style="vertical-align: middle">-</td>
504
+ </tr>
505
+ <tr>
506
+ <td align="center" style="vertical-align: middle">Seal-0</td>
507
+ <td align="center" style="vertical-align: middle">57.4</td>
508
+ <td align="center" style="vertical-align: middle">45.0</td>
509
+ <td align="center" style="vertical-align: middle">47.7*</td>
510
+ <td align="center" style="vertical-align: middle">45.5*</td>
511
+ <td align="center" style="vertical-align: middle">49.5*</td>
512
+ <td align="center" style="vertical-align: middle">-</td>
513
+ </tr>
514
+ </tbody>
515
+ </table>
516
+ </div>
517
+
518
+ <details>
519
+ <summary><b>Footnotes</b></summary>
520
+
521
+ 1. General Testing Details
522
+ - We report results for Kimi K2.5 and DeepSeek-V3.2 with thinking mode enabled, Claude Opus 4.5 with extended thinking mode, GPT-5.2 with xhigh reasoning effort, and Gemini 3 Pro with a high thinking level. For vision benchmarks, we additionally report results for Qwen3-VL-235B-A22B-Thinking.
523
+ - Unless otherwise specified, all Kimi K2.5 experiments were conducted with temperature = 1.0, top-p = 0.95, and a context length of 256k tokens.
524
+ - Benchmarks without publicly available scores were re-evaluated under the same conditions used for Kimi K2.5 and are marked with an asterisk (*).
525
+ - We could not evaluate GPT-5.2 xhigh on all benchmarks due to service stability issues. For benchmarks that were not tested, we mark them as "-".
526
+ 2. Text and Reasoning
527
+ - HLE, AIME 2025, HMMT 2025 (Feb), and GPQA-Diamond were evaluated with a maximum completion budget of 96k tokens.
528
+ - Results for AIME and HMMT are averaged over 32 runs (avg@32); GPQA-Diamond over 8 runs (avg@8).
529
+ - For HLE, we report scores on the full set (text & image). Kimi K2.5 scores 31.5 (text) and 21.3 (image) without tools, and 51.8 (text) and 39.8 (image) with tools. The DeepSeek-V3.2 score corresponds to its text-only subset (marked with †). Hugging Face access was blocked to prevent potential data leakage. HLE with tools uses simple context management: once the context exceeds a threshold, only the latest round of tool messages is retained.
530
+ 3. Tool-Augmented / Agentic Search
531
+ - Kimi K2.5 was equipped with search, code-interpreter, and web-browsing tools for HLE with tools and all agentic search benchmarks.
532
+ - Except for BrowseComp (where K2.5 and DeepSeek-V3.2 used the discard-all strategy), no context management was applied, and tasks exceeding the supported context length were directly counted as failed.
533
+ - The test system prompts emphasize deep and proactive tool use, instructing models to reason carefully, leverage tools, and verify uncertain information. Full prompts will be provided in the technical report.
534
+ - Results for Seal-0 and WideSearch are averaged over four runs (avg@4).
535
+ 4. Vision Benchmarks
536
+ - Max-tokens = 64k, averaged over three runs (avg@3).
537
+ - ZeroBench (w/ tools) uses max-tokens-per-step = 24k and max-steps = 30 for multi-step reasoning.
538
+ - MMMU-Pro follows the official protocol, preserving input order and prepending images.
539
+ - GPT-5.2-xhigh had ~10% failure rate (no output despite 3 retries), treated as incorrect; reported scores likely underestimate true performance.
540
+ - WorldVQA, a benchmark designed to evaluate atomic vision-centric world knowledge. Access WorldVQA at https://github.com/MoonshotAI/WorldVQA.
541
+ - OmniDocBench Score is computed as (1 − normalized Levenshtein distance) × 100, where a higher score denotes superior accuracy.
542
+ 5. Coding Tasks
543
+ - Terminal-Bench 2.0 scores were obtained with the default agent framework (Terminus-2) and the provided JSON parser. In our implementation, we evaluated Terminal-Bench 2.0 under non-thinking mode. This choice was made because our current context management strategy for the thinking mode is incompatible with Terminus-2.
544
+ - For the SWE-Bench series of evaluations (including verified, multilingual, and pro), we used an internally developed evaluation framework. This framework includes a minimal set of tools—bash tool, createfile tool, insert tool, view tool, strreplace tool, and submit tool—along with tailored system prompts designed for the tasks. The highest scores were achieved under non-thinking mode.
545
+ - The score of Claude Opus 4.5 on CyberGym is reported under the non-thinking setting.
546
+ - All reported scores of coding tasks are averaged over 5 independent runs.
547
+ 6. Long-Context Benchmarks
548
+ - AA-LCR: scores averaged over three runs (avg@3).
549
+ - LongBench-V2: identical prompts and input contexts standardized to ~128k tokens.
550
+ 7. Agent Swarm
551
+ - BrowseComp (Swarm Mode): main agent max 15 steps; sub-agents max 100 steps.
552
+ - WideSearch (Swarm Mode): main and sub-agents max 100 steps.
553
+
554
+ </details>
555
+
556
+ ## 4. Native INT4 Quantization
557
+ Kimi-K2.5 adopts the same native int4 quantization method as [Kimi-K2-Thinking](https://huggingface.co/moonshotai/Kimi-K2-Thinking#4-native-int4-quantization).
558
+
559
+ ## 5. Deployment
560
+ > [!Note]
561
+ > You can access Kimi-K2.5's API on https://platform.moonshot.ai , we provide OpenAI/Anthropic-compatible API for you. To verify the deployment is correct, we also provide the [Kimi Vendor Verifier](https://kimi.com/blog/kimi-vendor-verifier.html).
562
+ Currently, Kimi-K2.5 is recommended to run on the following inference engines:
563
+ * vLLM
564
+ * SGLang
565
+ * KTransformers
566
+
567
+ Deployment examples can be found in the [Model Deployment Guide](docs/deploy_guidance.md).
568
+
569
+ ---
570
+ ## 6. Model Usage
571
+
572
+ The usage demos below demonstrate how to call our official API.
573
+
574
+ For third-party APIs deployed with vLLM or SGLang, please note that:
575
+ > [!Note]
576
+ > - Chat with video content is an experimental feature and is only supported in our official API for now
577
+ >
578
+ > - The recommended `temperature` will be `1.0` for Thinking mode and `0.6` for Instant mode.
579
+ >
580
+ > - The recommended `top_p` is `0.95`
581
+ >
582
+ > - To use instant mode, you need to pass `{'chat_template_kwargs': {"thinking": False}}` in `extra_body`.
583
+
584
+ ### Chat Completion
585
+
586
+ This is a simple chat completion script which shows how to call K2.5 API in Thinking and Instant modes.
587
+
588
+ ```python
589
+ import openai
590
+ import base64
591
+ import requests
592
+ def simple_chat(client: openai.OpenAI, model_name: str):
593
+ messages = [
594
+ {'role': 'system', 'content': 'You are Kimi, an AI assistant created by Moonshot AI.'},
595
+ {
596
+ 'role': 'user',
597
+ 'content': [
598
+ {'type': 'text', 'text': 'which one is bigger, 9.11 or 9.9? think carefully.'}
599
+ ],
600
+ },
601
+ ]
602
+ response = client.chat.completions.create(
603
+ model=model_name, messages=messages, stream=False, max_tokens=4096
604
+ )
605
+ print('===== Below is reasoning_content in Thinking Mode ======')
606
+ print(f'reasoning content: {response.choices[0].message.reasoning_content}')
607
+ print('===== Below is response in Thinking Mode ======')
608
+ print(f'response: {response.choices[0].message.content}')
609
+
610
+ # To use instant mode, pass {"thinking": {"type": "disabled"}}
611
+ response = client.chat.completions.create(
612
+ model=model_name,
613
+ messages=messages,
614
+ stream=False,
615
+ max_tokens=4096,
616
+ extra_body={'thinking': {'type': 'disabled'}}, # this is for official API
617
+ # extra_body= {'chat_template_kwargs': {"thinking": False}} # this is for vLLM/SGLang
618
+ )
619
+ print('===== Below is response in Instant Mode ======')
620
+ print(f'response: {response.choices[0].message.content}')
621
+ ```
622
+
623
+
624
+ ### Chat Completion with visual content
625
+
626
+ K2.5 supports Image and Video input.
627
+
628
+ The following example demonstrates how to call K2.5 API with image input:
629
+
630
+ ```python
631
+ import openai
632
+ import base64
633
+ import requests
634
+
635
+ def chat_with_image(client: openai.OpenAI, model_name: str):
636
+ url = 'https://huggingface.co/moonshotai/Kimi-K2.5/resolve/main/figures/kimi-logo.png'
637
+ image_base64 = base64.b64encode(requests.get(url).content).decode()
638
+ messages = [
639
+ {
640
+ 'role': 'user',
641
+ 'content': [
642
+ {'type': 'text', 'text': 'Describe this image in detail.'},
643
+ {
644
+ 'type': 'image_url',
645
+ 'image_url': {'url': f'data:image/png;base64,{image_base64}'},
646
+ },
647
+ ],
648
+ }
649
+ ]
650
+
651
+ response = client.chat.completions.create(
652
+ model=model_name, messages=messages, stream=False, max_tokens=8192
653
+ )
654
+ print('===== Below is reasoning_content in Thinking Mode ======')
655
+ print(f'reasoning content: {response.choices[0].message.reasoning_content}')
656
+ print('===== Below is response in Thinking Mode ======')
657
+ print(f'response: {response.choices[0].message.content}')
658
+
659
+ # Instant mode is also supported by passing {"thinking": {"type": "disabled"}}
660
+ response = client.chat.completions.create(
661
+ model=model_name,
662
+ messages=messages,
663
+ stream=False,
664
+ max_tokens=4096,
665
+ extra_body={'thinking': {'type': 'disabled'}}, # this is for official API
666
+ # extra_body= {'chat_template_kwargs': {"thinking": False}} # this is for vLLM/SGLang
667
+ )
668
+ print('===== Below is response in Instant Mode ======')
669
+ print(f'response: {response.choices[0].message.content}')
670
+
671
+ return response.choices[0].message.content
672
+ ```
673
+
674
+ The following example demonstrates how to call K2.5 API with video input:
675
+
676
+ ```python
677
+ import openai
678
+ import base64
679
+ import requests
680
+
681
+ def chat_with_video(client: openai.OpenAI, model_name:str):
682
+ url = 'https://huggingface.co/moonshotai/Kimi-K2.5/resolve/main/figures/demo_video.mp4'
683
+ video_base64 = base64.b64encode(requests.get(url).content).decode()
684
+ messages = [
685
+ {
686
+ "role": "user",
687
+ "content": [
688
+ {"type": "text","text": "Describe the video in detail."},
689
+ {
690
+ "type": "video_url",
691
+ "video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
692
+ },
693
+ ],
694
+ }
695
+ ]
696
+
697
+ response = client.chat.completions.create(model=model_name, messages=messages)
698
+ print('===== Below is reasoning_content in Thinking Mode ======')
699
+ print(f'reasoning content: {response.choices[0].message.reasoning_content}')
700
+ print('===== Below is response in Thinking Mode ======')
701
+ print(f'response: {response.choices[0].message.content}')
702
+
703
+ # Instant mode is also supported by passing {"thinking": {"type": "disabled"}}
704
+ response = client.chat.completions.create(
705
+ model=model_name,
706
+ messages=messages,
707
+ stream=False,
708
+ max_tokens=4096,
709
+ extra_body={'thinking': {'type': 'disabled'}}, # this is for official API
710
+ # extra_body= {'chat_template_kwargs': {"thinking": False}} # this is for vLLM/SGLang
711
+ )
712
+ print('===== Below is response in Instant Mode ======')
713
+ print(f'response: {response.choices[0].message.content}')
714
+ return response.choices[0].message.content
715
+ ```
716
+
717
+ ### Interleaved Thinking and Multi-Step Tool Call
718
+
719
+ K2.5 shares the same design of Interleaved Thinking and Multi-Step Tool Call as K2 Thinking. For usage example, please refer to the [K2 Thinking documentation](https://platform.moonshot.ai/docs/guide/use-kimi-k2-thinking-model#complete-example).
720
+
721
+
722
+ ### Coding Agent Framework
723
+
724
+ Kimi K2.5 works best with Kimi Code CLI as its agent framework — give it a try at https://www.kimi.com/code.
725
+
726
+
727
+ ---
728
+
729
+ ## 7. License
730
+
731
+ Both the code repository and the model weights are released under the [Modified MIT License](LICENSE).
732
+
733
+ ---
734
+
735
+ ## 8. Third Party Notices
736
+
737
+ See [THIRD PARTY NOTICES](THIRD_PARTY_NOTICES.md)
738
+
739
+ ---
740
+
741
+ ## 9. Contact Us
742
+
743
+ If you have any questions, please reach out at [support@moonshot.cn](mailto:support@moonshot.cn).
THIRD_PARTY_NOTICES.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # THIRD_PARTY_NOTICES
2
+
3
+ This file lists third-party software contained in Kimi-K2.5 along with their licenses, in compliance with the redistribution clauses of those licenses.
4
+
5
+ ---
6
+
7
+ ## 1. DeepSeek-V3
8
+
9
+ Our model architecture is DeepSeek-V3-like. Some of the modeling code is copied from the source repository.
10
+
11
+ - **Source Repository**
12
+ https://huggingface.co/deepseek-ai/DeepSeek-V3
13
+
14
+ - **Files / Directories Used**
15
+ - configuration_deepseek.py
16
+ - modeling_deepseek.py
17
+
18
+ - **License Type**
19
+ MIT License
20
+
21
+ - **Copyright Notice**
22
+ Copyright (c) 2023 DeepSeek
23
+
24
+ - **Full License Text**
25
+ ```
26
+ MIT License
27
+ Copyright (c) 2023 DeepSeek
28
+ Permission is hereby granted, free of charge, to any person obtaining a copy
29
+ of this software and associated documentation files (the "Software"), to deal
30
+ in the Software without restriction, including without limitation the rights
31
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
32
+ copies of the Software, and to permit persons to whom the Software is
33
+ furnished to do so, subject to the following conditions:
34
+ The above copyright notice and this permission notice shall be included in all
35
+ copies or substantial portions of the Software.
36
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
37
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
38
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
39
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
40
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
41
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
42
+ SOFTWARE.
43
+ ```
chat_template.jinja ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- macro render_content(msg) -%}
2
+ {%- set c = msg.get('content') -%}
3
+ {%- if c is string -%}
4
+ {{ c }}
5
+ {%- elif c is not none -%}
6
+ {% for content in c -%}
7
+ {% if content['type'] == 'image' or content['type'] == 'image_url' -%}
8
+ <|media_start|>image<|media_content|><|media_pad|><|media_end|>
9
+ {% elif content['type'] == 'video' or content['type']== 'video_url'-%}
10
+ <|kimi_k25_video_placeholder|>
11
+ {% else -%}
12
+ {{ content['text'] }}
13
+ {%- endif -%}
14
+ {%- endfor -%}
15
+ {%- endif -%}
16
+ {%- endmacro -%}
17
+
18
+ {% macro set_roles(message) -%}
19
+ {%- set role_name = message.get('name') or message['role'] -%}
20
+ {%- if message['role'] == 'user' -%}
21
+ <|im_user|>{{role_name}}<|im_middle|>
22
+ {%- elif message['role'] == 'assistant' -%}
23
+ <|im_assistant|>{{role_name}}<|im_middle|>
24
+ {%- else -%}
25
+ <|im_system|>{{role_name}}<|im_middle|>
26
+ {%- endif -%}
27
+ {%- endmacro -%}
28
+
29
+
30
+ {%- macro render_toolcalls(message) -%}
31
+ <|tool_calls_section_begin|>
32
+ {%- for tool_call in message['tool_calls'] -%}
33
+ {%- set formatted_id = tool_call['id'] -%}
34
+ <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>
35
+ {%- endfor -%}
36
+ <|tool_calls_section_end|>
37
+ {%- endmacro -%}
38
+
39
+
40
+ {# Find last non-tool-call assistant message #}
41
+ {%- set ns = namespace(last_non_tool_call_assistant_msg=-1) -%}
42
+ {%- for idx in range(messages|length-1, -1, -1) -%}
43
+ {%- if messages[idx]['role'] == 'assistant' and not messages[idx].get('tool_calls') -%}
44
+ {%- set ns.last_non_tool_call_assistant_msg = idx -%}
45
+ {%- break -%}
46
+ {%- endif -%}
47
+ {%- endfor -%}
48
+
49
+ {# split all messages into history & suffix, reasoning_content in suffix should be reserved.#}
50
+ {%- set hist_msgs = messages[:ns.last_non_tool_call_assistant_msg+1] -%}
51
+ {%- set suffix_msgs = messages[ns.last_non_tool_call_assistant_msg+1:] -%}
52
+
53
+ {%- if tools -%}
54
+ {%- if tools_ts_str -%}
55
+ <|im_system|>tool_declare<|im_middle|>{{ tools_ts_str }}<|im_end|>
56
+ {%- else -%}
57
+ <|im_system|>tool_declare<|im_middle|>{{ tools | tojson(separators=(',', ':')) }}<|im_end|>
58
+ {%- endif -%}
59
+ {%- endif -%}
60
+
61
+ {%- if messages|length == 0 or messages[0]['role'] != 'system' -%}
62
+ <|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|>
63
+ {%- endif -%}
64
+
65
+ {%- for message in hist_msgs -%}
66
+ {{set_roles(message)}}
67
+ {%- if message['role'] == 'assistant' -%}
68
+ <think></think>{{render_content(message)}}
69
+ {%- if message.get('tool_calls') -%}
70
+ {{render_toolcalls(message)}}
71
+ {%- endif -%}
72
+ {%- elif message['role'] == 'tool' -%}
73
+ {%- set tool_call_id = message.tool_call_id -%}
74
+ ## Return of {{ tool_call_id }}
75
+ {{render_content(message)}}
76
+ {%- elif message['content'] is not none -%}
77
+ {{render_content(message)}}
78
+ {%- endif -%}
79
+ <|im_end|>
80
+ {%- endfor -%}
81
+
82
+ {%- for message in suffix_msgs -%}
83
+ {{set_roles(message)}}
84
+ {%- if message['role'] == 'assistant' -%}
85
+ {%- if thinking is defined and thinking is false -%}
86
+ <think></think>{{render_content(message)}}
87
+ {%- else -%}
88
+ {%- set rc = message.get('reasoning_content', '') -%}
89
+ <think>{{rc}}</think>{{render_content(message)}}
90
+ {%- endif -%}
91
+ {%- if message.get('tool_calls') -%}
92
+ {{render_toolcalls(message)}}
93
+ {%- endif -%}
94
+ {%- elif message['role'] == 'tool' -%}
95
+ {%- set tool_call_id = message.tool_call_id -%}
96
+ ## Return of {{ tool_call_id }}
97
+ {{render_content(message)}}
98
+ {%- elif message['content'] is not none -%}
99
+ {{render_content(message)}}
100
+ {%- endif -%}
101
+ <|im_end|>
102
+ {%- endfor -%}
103
+
104
+
105
+ {%- if add_generation_prompt -%}
106
+ <|im_assistant|>assistant<|im_middle|>
107
+ {%- if thinking is defined and thinking is false -%}
108
+ <think></think>
109
+ {%- else -%}
110
+ <think>
111
+ {%- endif -%}
112
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "KimiK25ForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_kimi_k25.KimiK25Config",
7
+ "AutoModel": "modeling_kimi_k25.KimiK25ForConditionalGeneration",
8
+ "AutoModelForCausalLM": "modeling_kimi_k25.KimiK25ForConditionalGeneration"
9
+ },
10
+ "bos_token_id": 163584,
11
+ "dtype": "bfloat16",
12
+ "eos_token_id": 163585,
13
+ "ignore_index": -100,
14
+ "media_placeholder_token_id": 163605,
15
+ "model_type": "kimi_k25",
16
+ "pad_token_id": 163839,
17
+ "text_config": {
18
+ "_name_or_path": "",
19
+ "add_cross_attention": false,
20
+ "architectures": [
21
+ "DeepseekV3ForCausalLM"
22
+ ],
23
+ "attention_bias": false,
24
+ "attention_dropout": 0.0,
25
+ "auto_map": {
26
+ "AutoConfig": "configuration_deepseek.DeepseekV3Config",
27
+ "AutoModel": "modeling_deepseek.DeepseekV3Model",
28
+ "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
29
+ },
30
+ "aux_loss_alpha": 0.001,
31
+ "bad_words_ids": null,
32
+ "begin_suppress_tokens": null,
33
+ "bos_token_id": 163584,
34
+ "chunk_size_feed_forward": 0,
35
+ "cross_attention_hidden_size": null,
36
+ "decoder_start_token_id": null,
37
+ "diversity_penalty": 0.0,
38
+ "do_sample": false,
39
+ "dtype": "bfloat16",
40
+ "early_stopping": false,
41
+ "encoder_no_repeat_ngram_size": 0,
42
+ "eos_token_id": 163585,
43
+ "ep_size": 1,
44
+ "exponential_decay_length_penalty": null,
45
+ "finetuning_task": null,
46
+ "first_k_dense_replace": 1,
47
+ "forced_bos_token_id": null,
48
+ "forced_eos_token_id": null,
49
+ "hidden_act": "silu",
50
+ "hidden_size": 7168,
51
+ "id2label": {
52
+ "0": "LABEL_0",
53
+ "1": "LABEL_1"
54
+ },
55
+ "initializer_range": 0.02,
56
+ "intermediate_size": 18432,
57
+ "is_decoder": false,
58
+ "is_encoder_decoder": false,
59
+ "kv_lora_rank": 512,
60
+ "label2id": {
61
+ "LABEL_0": 0,
62
+ "LABEL_1": 1
63
+ },
64
+ "length_penalty": 1.0,
65
+ "max_length": 20,
66
+ "max_position_embeddings": 262144,
67
+ "min_length": 0,
68
+ "model_type": "kimi_k2",
69
+ "moe_intermediate_size": 2048,
70
+ "moe_layer_freq": 1,
71
+ "n_group": 1,
72
+ "n_routed_experts": 304,
73
+ "n_shared_experts": 1,
74
+ "no_repeat_ngram_size": 0,
75
+ "norm_topk_prob": true,
76
+ "num_attention_heads": 64,
77
+ "num_beam_groups": 1,
78
+ "num_beams": 1,
79
+ "num_experts_per_tok": 8,
80
+ "num_hidden_layers": 61,
81
+ "num_key_value_heads": 64,
82
+ "num_nextn_predict_layers": 0,
83
+ "num_return_sequences": 1,
84
+ "output_attentions": false,
85
+ "output_hidden_states": false,
86
+ "output_scores": false,
87
+ "pad_token_id": 163839,
88
+ "prefix": null,
89
+ "pretraining_tp": 1,
90
+ "problem_type": null,
91
+ "pruned_heads": {},
92
+ "q_lora_rank": 1536,
93
+ "qk_nope_head_dim": 128,
94
+ "qk_rope_head_dim": 64,
95
+ "quantization_config": {
96
+ "config_groups": {
97
+ "group_0": {
98
+ "input_activations": null,
99
+ "output_activations": null,
100
+ "targets": [
101
+ "Linear"
102
+ ],
103
+ "weights": {
104
+ "actorder": null,
105
+ "block_structure": null,
106
+ "dynamic": false,
107
+ "group_size": 32,
108
+ "num_bits": 4,
109
+ "observer": "minmax",
110
+ "observer_kwargs": {},
111
+ "strategy": "group",
112
+ "symmetric": true,
113
+ "type": "int"
114
+ }
115
+ }
116
+ },
117
+ "format": "pack-quantized",
118
+ "ignore": [
119
+ "lm_head",
120
+ "re:.*self_attn.*",
121
+ "re:.*shared_experts.*",
122
+ "re:.*mlp\\.(gate|up|gate_up|down)_proj.*"
123
+ ],
124
+ "kv_cache_scheme": null,
125
+ "quant_method": "compressed-tensors",
126
+ "quantization_status": "compressed"
127
+ },
128
+ "remove_invalid_values": false,
129
+ "repetition_penalty": 1.0,
130
+ "return_dict": true,
131
+ "return_dict_in_generate": false,
132
+ "rms_norm_eps": 1e-05,
133
+ "rope_scaling": {
134
+ "beta_fast": 32.0,
135
+ "beta_slow": 1.0,
136
+ "factor": 64.0,
137
+ "mscale": 1.0,
138
+ "mscale_all_dim": 1.0,
139
+ "original_max_position_embeddings": 4096,
140
+ "type": "yarn"
141
+ },
142
+ "rope_theta": 50000.0,
143
+ "routed_scaling_factor": 2.827,
144
+ "scoring_func": "sigmoid",
145
+ "sep_token_id": null,
146
+ "seq_aux": true,
147
+ "suppress_tokens": null,
148
+ "task_specific_params": null,
149
+ "temperature": 1.0,
150
+ "tf_legacy_loss": false,
151
+ "tie_encoder_decoder": false,
152
+ "tie_word_embeddings": false,
153
+ "tokenizer_class": null,
154
+ "top_k": 50,
155
+ "top_p": 1.0,
156
+ "topk_group": 1,
157
+ "topk_method": "noaux_tc",
158
+ "torchscript": false,
159
+ "transformers_version": "4.56.2",
160
+ "typical_p": 1.0,
161
+ "use_bfloat16": false,
162
+ "use_cache": true,
163
+ "v_head_dim": 128,
164
+ "vocab_size": 163840
165
+ },
166
+ "tie_word_embeddings": false,
167
+ "use_unified_vision_chunk": true,
168
+ "video_placeholder": "<|kimi_k25_video_placeholder|>",
169
+ "vision_config": {
170
+ "_attn_implementation": "flash_attention_2",
171
+ "init_pos_emb_height": 64,
172
+ "init_pos_emb_time": 4,
173
+ "init_pos_emb_width": 64,
174
+ "merge_kernel_size": [
175
+ 2,
176
+ 2
177
+ ],
178
+ "merge_type": "sd2_tpool",
179
+ "mm_hidden_size": 1152,
180
+ "mm_projector_type": "patchmerger",
181
+ "patch_size": 14,
182
+ "pos_emb_type": "divided_fixed",
183
+ "projector_hidden_act": "gelu",
184
+ "projector_ln_eps": 1e-05,
185
+ "text_hidden_size": 7168,
186
+ "video_attn_type": "spatial_temporal",
187
+ "vt_hidden_size": 1152,
188
+ "vt_intermediate_size": 4304,
189
+ "vt_num_attention_heads": 16,
190
+ "vt_num_hidden_layers": 27
191
+ }
192
+ }
configuration.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"framework":"Pytorch","task":"image-text-to-text"}
configuration_deepseek.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/configuration_deepseek.py
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+ logger = logging.get_logger(__name__)
7
+
8
+ DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
9
+
10
+
11
+ class DeepseekV3Config(PretrainedConfig):
12
+ r"""
13
+ This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
14
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
15
+ defaults will yield a similar configuration to that of the DeepSeek-V3.
16
+
17
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
18
+ documentation from [`PretrainedConfig`] for more information.
19
+
20
+
21
+ Args:
22
+ vocab_size (`int`, *optional*, defaults to 129280):
23
+ Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
24
+ `inputs_ids` passed when calling [`DeepseekV3Model`]
25
+ hidden_size (`int`, *optional*, defaults to 4096):
26
+ Dimension of the hidden representations.
27
+ intermediate_size (`int`, *optional*, defaults to 11008):
28
+ Dimension of the MLP representations.
29
+ moe_intermediate_size (`int`, *optional*, defaults to 1407):
30
+ Dimension of the MoE representations.
31
+ num_hidden_layers (`int`, *optional*, defaults to 32):
32
+ Number of hidden layers in the Transformer decoder.
33
+ num_nextn_predict_layers (`int`, *optional*, defaults to 1):
34
+ Number of nextn predict layers in the DeepSeekV3 Model.
35
+ num_attention_heads (`int`, *optional*, defaults to 32):
36
+ Number of attention heads for each attention layer in the Transformer decoder.
37
+ n_shared_experts (`int`, *optional*, defaults to None):
38
+ Number of shared experts, None means dense model.
39
+ n_routed_experts (`int`, *optional*, defaults to None):
40
+ Number of routed experts, None means dense model.
41
+ routed_scaling_factor (`float`, *optional*, defaults to 1.0):
42
+ Scaling factor or routed experts.
43
+ topk_method (`str`, *optional*, defaults to `gready`):
44
+ Topk method used in routed gate.
45
+ n_group (`int`, *optional*, defaults to None):
46
+ Number of groups for routed experts.
47
+ topk_group (`int`, *optional*, defaults to None):
48
+ Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
49
+ num_experts_per_tok (`int`, *optional*, defaults to None):
50
+ Number of selected experts, None means dense model.
51
+ moe_layer_freq (`int`, *optional*, defaults to 1):
52
+ The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
53
+ first_k_dense_replace (`int`, *optional*, defaults to 0):
54
+ Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
55
+ \--k dense layers--/
56
+ norm_topk_prob (`bool`, *optional*, defaults to False):
57
+ Whether to normalize the weights of the routed experts.
58
+ scoring_func (`str`, *optional*, defaults to 'softmax'):
59
+ Method of computing expert weights.
60
+ aux_loss_alpha (`float`, *optional*, defaults to 0.001):
61
+ Auxiliary loss weight coefficient.
62
+ seq_aux = (`bool`, *optional*, defaults to True):
63
+ Whether to compute the auxiliary loss for each individual sample.
64
+ num_key_value_heads (`int`, *optional*):
65
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
66
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
67
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
68
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
69
+ by meanpooling all the original heads within that group. For more details checkout [this
70
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
71
+ `num_attention_heads`.
72
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
73
+ The non-linear activation function (function or string) in the decoder.
74
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
75
+ The maximum sequence length that this model might ever be used with.
76
+ initializer_range (`float`, *optional*, defaults to 0.02):
77
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
78
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
79
+ The epsilon used by the rms normalization layers.
80
+ use_cache (`bool`, *optional*, defaults to `True`):
81
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
82
+ relevant if `config.is_decoder=True`.
83
+ pad_token_id (`int`, *optional*):
84
+ Padding token id.
85
+ bos_token_id (`int`, *optional*, defaults to 1):
86
+ Beginning of stream token id.
87
+ eos_token_id (`int`, *optional*, defaults to 2):
88
+ End of stream token id.
89
+ pretraining_tp (`int`, *optional*, defaults to 1):
90
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
91
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
92
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
93
+ issue](https://github.com/pytorch/pytorch/issues/76232).
94
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
95
+ Whether to tie weight embeddings
96
+ rope_theta (`float`, *optional*, defaults to 10000.0):
97
+ The base period of the RoPE embeddings.
98
+ rope_scaling (`Dict`, *optional*):
99
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
100
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
101
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
102
+ `max_position_embeddings` to the expected new maximum.
103
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
104
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
105
+ attention_dropout (`float`, *optional*, defaults to 0.0):
106
+ The dropout ratio for the attention probabilities.
107
+
108
+ ```python
109
+ >>> from transformers import DeepseekV3Model, DeepseekV3Config
110
+
111
+ >>> # Initializing a Deepseek-V3 style configuration
112
+ >>> configuration = DeepseekV3Config()
113
+
114
+ >>> # Accessing the model configuration
115
+ >>> configuration = model.config
116
+ ```"""
117
+
118
+ model_type = "deepseek_v3"
119
+ keys_to_ignore_at_inference = ["past_key_values"]
120
+
121
+ def __init__(
122
+ self,
123
+ vocab_size=129280,
124
+ hidden_size=7168,
125
+ intermediate_size=18432,
126
+ moe_intermediate_size=2048,
127
+ num_hidden_layers=61,
128
+ num_nextn_predict_layers=1,
129
+ num_attention_heads=128,
130
+ num_key_value_heads=128,
131
+ n_shared_experts=1,
132
+ n_routed_experts=256,
133
+ ep_size=1,
134
+ routed_scaling_factor=2.5,
135
+ kv_lora_rank=512,
136
+ q_lora_rank=1536,
137
+ qk_rope_head_dim=64,
138
+ v_head_dim=128,
139
+ qk_nope_head_dim=128,
140
+ topk_method='noaux_tc',
141
+ n_group=8,
142
+ topk_group=4,
143
+ num_experts_per_tok=8,
144
+ moe_layer_freq=1,
145
+ first_k_dense_replace=3,
146
+ norm_topk_prob=True,
147
+ scoring_func='sigmoid',
148
+ aux_loss_alpha=0.001,
149
+ seq_aux=True,
150
+ hidden_act="silu",
151
+ max_position_embeddings=4096,
152
+ initializer_range=0.02,
153
+ rms_norm_eps=1e-6,
154
+ use_cache=True,
155
+ pad_token_id=None,
156
+ bos_token_id=0,
157
+ eos_token_id=1,
158
+ pretraining_tp=1,
159
+ tie_word_embeddings=False,
160
+ rope_theta=10000.0,
161
+ rope_scaling=None,
162
+ attention_bias=False,
163
+ attention_dropout=0.0,
164
+ **kwargs,
165
+ ):
166
+ self.vocab_size = vocab_size
167
+ self.max_position_embeddings = max_position_embeddings
168
+ self.hidden_size = hidden_size
169
+ self.intermediate_size = intermediate_size
170
+ self.moe_intermediate_size = moe_intermediate_size
171
+ self.num_hidden_layers = num_hidden_layers
172
+ self.num_nextn_predict_layers = num_nextn_predict_layers
173
+ self.num_attention_heads = num_attention_heads
174
+ self.n_shared_experts = n_shared_experts
175
+ self.n_routed_experts = n_routed_experts
176
+ self.ep_size = ep_size
177
+ self.routed_scaling_factor = routed_scaling_factor
178
+ self.kv_lora_rank = kv_lora_rank
179
+ self.q_lora_rank = q_lora_rank
180
+ self.qk_rope_head_dim = qk_rope_head_dim
181
+ self.v_head_dim = v_head_dim
182
+ self.qk_nope_head_dim = qk_nope_head_dim
183
+ self.topk_method = topk_method
184
+ self.n_group = n_group
185
+ self.topk_group = topk_group
186
+ self.num_experts_per_tok = num_experts_per_tok
187
+ self.moe_layer_freq = moe_layer_freq
188
+ self.first_k_dense_replace = first_k_dense_replace
189
+ self.norm_topk_prob = norm_topk_prob
190
+ self.scoring_func = scoring_func
191
+ self.aux_loss_alpha = aux_loss_alpha
192
+ self.seq_aux = seq_aux
193
+ # for backward compatibility
194
+ if num_key_value_heads is None:
195
+ num_key_value_heads = num_attention_heads
196
+
197
+ self.num_key_value_heads = num_key_value_heads
198
+ self.hidden_act = hidden_act
199
+ self.initializer_range = initializer_range
200
+ self.rms_norm_eps = rms_norm_eps
201
+ self.pretraining_tp = pretraining_tp
202
+ self.use_cache = use_cache
203
+ self.rope_theta = rope_theta
204
+ self.rope_scaling = rope_scaling
205
+ self.attention_bias = attention_bias
206
+ self.attention_dropout = attention_dropout
207
+
208
+ super().__init__(
209
+ pad_token_id=pad_token_id,
210
+ bos_token_id=bos_token_id,
211
+ eos_token_id=eos_token_id,
212
+ tie_word_embeddings=tie_word_embeddings,
213
+ **kwargs,
214
+ )
configuration_kimi_k25.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+ try:
4
+ from configuration_deepseek import DeepseekV3Config
5
+ except ImportError:
6
+ from .configuration_deepseek import DeepseekV3Config
7
+
8
+
9
class KimiK25VisionConfig(PretrainedConfig):
    """Configuration for the Kimi-K2.5 vision tower and multimodal projector.

    Args:
        patch_size: Spatial patch size of the vision transformer.
        init_pos_emb_height: Initial position-embedding grid height (patches).
        init_pos_emb_width: Initial position-embedding grid width (patches).
        init_pos_emb_time: Initial temporal extent of the position embedding.
        pos_emb_type: Position-embedding scheme (e.g. ``'divided_fixed'``).
        vt_num_attention_heads: Attention heads in the vision tower.
        vt_num_hidden_layers: Transformer depth of the vision tower.
        vt_hidden_size: Hidden width of the vision tower.
        vt_intermediate_size: FFN width of the vision tower.
        merge_kernel_size: (h, w) kernel used when merging patches.
        video_attn_type: Video attention variant (e.g. ``'spatial_temporal'``).
        merge_type: Patch-merge strategy.
        _attn_implementation: Attention backend (e.g. ``'flash_attention_2'``).
        mm_projector_type: Multimodal projector architecture.
        mm_hidden_size: Projector input width; defaults to ``vt_hidden_size``.
        projector_hidden_act: Activation function used inside the projector.
        projector_ln_eps: LayerNorm epsilon used inside the projector.
        text_hidden_size: Hidden width of the text model the projector maps into.

    ``ignore_index``, ``media_placeholder_token_id``, ``use_unified_vision_chunk``
    and ``video_placeholder`` are accepted (so a full model config dict can be
    passed through unchanged) but are owned by :class:`KimiK25Config` and are
    not stored here.
    """

    def __init__(
            self,
            patch_size: int = 14,
            init_pos_emb_height: int = 64,
            init_pos_emb_width: int = 64,
            init_pos_emb_time: int = 4,
            pos_emb_type: str = 'divided_fixed',
            vt_num_attention_heads: int = 16,
            vt_num_hidden_layers: int = 27,
            vt_hidden_size: int = 1152,
            vt_intermediate_size: int = 4304,
            merge_kernel_size: tuple = (2, 2),
            video_attn_type: str = 'spatial_temporal',
            merge_type: str = 'sd2_tpool',
            _attn_implementation: str = 'flash_attention_2',
            # MM Projector parameters
            mm_projector_type: str = 'patchmerger',
            mm_hidden_size: int | None = None,
            projector_hidden_act: str = "gelu",
            projector_ln_eps: float = 1e-5,
            # Other parameters (owned by KimiK25Config, accepted for compatibility)
            ignore_index: int = -100,
            media_placeholder_token_id: int = 163605,
            pad_token_id: int = 0,
            use_unified_vision_chunk: bool = True,
            video_placeholder="<|kimi_k25_video_placeholder|>",
            text_hidden_size=7168,
            **vision_config_kwargs):

        # Vision tower config
        self.patch_size = patch_size
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        self.init_pos_emb_time = init_pos_emb_time
        self.pos_emb_type = pos_emb_type
        self.vt_num_attention_heads = vt_num_attention_heads
        self.vt_num_hidden_layers = vt_num_hidden_layers
        self.vt_hidden_size = vt_hidden_size
        self.vt_intermediate_size = vt_intermediate_size
        self.merge_kernel_size = merge_kernel_size
        self.video_attn_type = video_attn_type
        self.merge_type = merge_type
        self._attn_implementation = _attn_implementation

        # MM Projector config
        self.mm_projector_type = mm_projector_type
        # The projector consumes vision-tower features, so default to its width.
        self.mm_hidden_size = mm_hidden_size if mm_hidden_size is not None else vt_hidden_size
        self.projector_hidden_act = projector_hidden_act
        self.projector_ln_eps = projector_ln_eps
        self.text_hidden_size = text_hidden_size

        # Bug fix: PretrainedConfig.__init__ was never invoked, leaving base
        # attributes (return_dict, output_hidden_states, ...) unset and
        # silently dropping any extra kwargs.
        super().__init__(pad_token_id=pad_token_id, **vision_config_kwargs)
60
+
61
+
62
class KimiK25Config(PretrainedConfig):
    """Kimi-K2.5 model configuration.

    Composite configuration holding a DeepseekV3 text config and a
    KimiK25Vision vision config, plus the multimodal glue parameters.

    Args:
        text_config (dict | DeepseekV3Config): Configuration for the text model.
            A plain dict (as loaded from ``config.json``) is converted; ``None``
            materializes a default :class:`DeepseekV3Config`.
        vision_config (dict | KimiK25VisionConfig): Configuration for the vision
            tower and projector; same conversion rules as ``text_config``.
        ignore_index (int): The ignore index for the loss function.
        media_placeholder_token_id (int): The token ID used for media placeholders.
        pad_token_id (int): The token ID used for padding.
        use_unified_vision_chunk (bool): Whether image/video inputs share a
            unified vision-chunk pathway.
        video_placeholder (str): Temporal placeholder string later expanded
            into per-chunk video prompts.
    """

    model_type = "kimi_k25"

    def __init__(
        self,
        text_config: dict | DeepseekV3Config = None,
        vision_config: dict | KimiK25VisionConfig = None,
        # Other parameters
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        use_unified_vision_chunk: bool = True,
        video_placeholder="<|kimi_k25_video_placeholder|>",
        **kwargs,
    ):
        # Follow the HF composite-config convention: accept config objects,
        # plain dicts, or None (default sub-config).
        if text_config is None:
            text_config = DeepseekV3Config()
        elif isinstance(text_config, dict):
            text_config = DeepseekV3Config(**text_config)
        if vision_config is None:
            vision_config = KimiK25VisionConfig()
        elif isinstance(vision_config, dict):
            vision_config = KimiK25VisionConfig(**vision_config)
        self.text_config = text_config
        self.vision_config = vision_config
        # Other config
        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id
        self.use_unified_vision_chunk = use_unified_vision_chunk
        self.video_placeholder = video_placeholder
        # Surface the text model's quantization config at the top level so
        # loaders that only inspect the root config pick it up.
        if getattr(self.text_config, "quantization_config", None) is not None:
            self.quantization_config = self.text_config.quantization_config

        super().__init__(pad_token_id=pad_token_id, **kwargs)
docs/deploy_guidance.md ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Kimi-K2.5 Deployment Guide
2
+
3
+ > [!Note]
4
+ > This guide only provides some examples of deployment commands for Kimi-K2.5, which may not be the optimal configuration. Since inference engines are still being updated frequently, please continue to follow the guidance from their homepage if you want to achieve better inference performance.
5
+
6
+ > kimi_k2 reasoning parser and other related features have been merged into vLLM/sglang and will be available in the next release. For now, please use the nightly build Docker image.
7
+ ## vLLM Deployment
8
+
9
+ This model is available in nightly vLLM wheel:
10
+ ```
11
+ uv pip install -U vllm \
12
+ --torch-backend=auto \
13
+ --extra-index-url https://wheels.vllm.ai/nightly
14
+ ```
15
+
16
+ Here is the example to serve this model on a H200 single node with TP8 via vLLM:
17
+ ```bash
18
+ vllm serve $MODEL_PATH --tp 8 --trust-remote-code --tool-call-parser kimi_k2 --reasoning-parser kimi_k2
19
+ ```
20
+ **Key notes**
21
+ - `--tool-call-parser kimi_k2`: Required for enabling tool calling
22
+ - `--reasoning-parser kimi_k2`: Kimi-K2.5 enables thinking mode by default. Make sure to pass this for correct reasoning processing.
23
+
24
+ ## SGLang Deployment
25
+
26
+ This model is available in SGLang latest main:
27
+
28
+ ```
29
+ pip install "sglang @ git+https://github.com/sgl-project/sglang.git#subdirectory=python"
30
+ pip install nvidia-cudnn-cu12==9.16.0.29
31
+ ```
32
+
33
+ Similarly, here is the example for it to run with TP8 on H200 in a single node via SGLang:
34
+ ``` bash
35
+ sglang serve --model-path $MODEL_PATH --tp 8 --trust-remote-code --tool-call-parser kimi_k2 --reasoning-parser kimi_k2
36
+ ```
37
+ **Key parameter notes:**
38
+ - `--tool-call-parser kimi_k2`: Required when enabling tool usage.
39
+ - `--reasoning-parser kimi_k2`: Required for correctly processing reasoning content.
40
+
41
+ ## KTransformers Deployment
42
+ ### KTransformers+SGLang Inference Deployment
43
+ Launch with KTransformers + SGLang for CPU+GPU heterogeneous inference:
44
+
45
+ ```
46
+ python -m sglang.launch_server \
47
+ --model path/to/Kimi-K2.5/ \
48
+ --kt-amx-weight-path path/to/Kimi-K2.5/ \
49
+ --kt-cpuinfer 64 \
50
+ --kt-threadpool-count 2 \
51
+ --kt-num-gpu-experts 180 \
52
+ --kt-amx-method AMXINT4 \
53
+ --trust-remote-code \
54
+ --mem-fraction-static 0.98 \
55
+ --chunked-prefill-size 16384 \
56
+ --max-running-requests 48 \
57
+ --max-total-tokens 50000 \
58
+ --tensor-parallel-size 8 \
59
+ --enable-p2p-check \
60
+ --disable-shared-experts-fusion
61
+ ```
62
+
63
+ Achieves 640.12 tokens/s Prefill and 24.51 tokens/s Decode (48-way concurrency) on 8× NVIDIA L20 + 2× Intel 6454S.
64
+
65
+ More details: https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/Kimi-K2.5.md .
66
+
67
+ ### KTransformers+LLaMA-Factory Fine-tuning Deployment
68
+
69
+ You can use below command to run LoRA SFT with KT+llamafactory.
70
+
71
+ ```
72
+ # For LoRA SFT
73
+ USE_KT=1 llamafactory-cli train examples/train_lora/kimik2_lora_sft_kt.yaml
74
+ # For Chat with model after LoRA SFT
75
+ llamafactory-cli chat examples/inference/kimik2_lora_sft_kt.yaml
76
+ # For API with model after LoRA SFT
77
+ llamafactory-cli api examples/inference/kimik2_lora_sft_kt.yaml
78
+ ```
79
+
80
+ This achieves end-to-end LoRA SFT Throughput: 44.55 token/s on 2× NVIDIA 4090 + Intel 8488C with 1.97T RAM and 200G swap memory.
81
+
82
+ More details refer to https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/SFT_Installation_Guide_KimiK2.5.md .
figures/demo_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09b4d925aa0a7c712feef50765355f0625d8f6d46ea302fd98db9609e9070047
3
+ size 270100
figures/kimi-logo.png ADDED
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "max_length": 262144,
3
+ "eos_token_id": 163586,
4
+ "temperature": 1.0,
5
+ "top_k": 50,
6
+ "top_p": 0.95
7
+ }
kimi_k25_processor.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.feature_extraction_utils import BatchFeature
2
+ from transformers.processing_utils import ProcessorMixin
3
+ from transformers.utils import logging
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+
8
class KimiK25Processor(ProcessorMixin):
    r"""
    Constructs a KimiK25 processor which wraps a KimiK25 image processor and a tokenizer into a single processor.

    [`KimiK25Processor`] offers all the functionalities of [`KimiK25ImageProcessor`] and [`TikTokenTokenizer`]. See the
    [`~KimiK25Processor.__call__`] and [`~KimiK25Processor.decode`] for more information.

    Args:
        image_processor ([`KimiK25ImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`TikTokenTokenizer`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        **kwargs,
    ):
        super().__init__(image_processor,
                         tokenizer,
                         chat_template=chat_template)
        # Alias used by the media helpers below; same object as image_processor.
        self.media_processor = image_processor
        # A special temporal placeholder to be replaced by actual video placeholders
        self.video_placeholder = "<|kimi_k25_video_placeholder|>"

    def update_raw_text(self, text: str, video_prompts: list[str]) -> str:
        """Replace each video placeholder in *text* with the corresponding
        video's concatenated per-chunk prompt string.

        Raises:
            AssertionError: If the placeholder count does not match
                ``len(video_prompts)``.
        """
        video_count = text.count(self.video_placeholder)
        if video_count == 0:
            return text
        assert video_count == len(video_prompts)
        text_parts = text.split(self.video_placeholder)
        assert len(text_parts) == len(video_prompts) + 1
        # Interleave the text fragments with one prompt per video.
        text = "".join([
            text_parts[i] + video_prompts[i] for i in range(len(video_prompts))
        ])
        text += text_parts[-1]
        return text

    def preprocess_medias(
            self, medias: list[dict]) -> tuple[list[dict], list[str]]:
        """Expand each video into its chunk items.

        Returns:
            ``(updated_medias, video_prompts)`` — images pass through
            unchanged, videos are replaced by their chunk items, and
            ``video_prompts`` holds one concatenated chunk-prompt string per
            input video. (Bug fix: the return annotation previously claimed
            ``list[dict]`` although a 2-tuple was returned.)
        """
        updated_medias = []
        video_prompts = []
        for media in medias:
            if media['type'] == 'image':
                updated_medias.append(media)
            elif media['type'] == 'video':
                video_chunks = self.media_processor.split_video_chunks(
                    media['video'])
                updated_medias.extend(video_chunks)
                video_prompts.append("".join(
                    [vc['prompt'] for vc in video_chunks]))
            else:
                raise ValueError(f"unsupported media type: {media['type']}")
        return updated_medias, video_prompts

    def __call__(self,
                 messages: list[dict] = None,
                 medias: list[dict] = None,
                 text: str = None,
                 return_tensors: str = "pt",
                 **kwargs) -> BatchFeature:
        """
        Process multimodal inputs for Kimi-K2.5 model.

        This processor accepts ordered messages and extracts both media and text in a single pass.
        text will be automatically updated if video input detected in messages

        Args:
            messages: List of message dicts with 'role' and 'content' fields.
                If provided, medias and text will be extracted automatically.
            medias: Pre-extracted list of media dicts. If None, extracted from messages.
            text: Pre-formatted text string. If None, generated via apply_chat_template.
            return_tensors: Format of returned tensors ('pt', 'np', 'tf'). Default: 'pt'.
            **kwargs: Additional arguments passed to tokenizer.apply_chat_template.

        Returns:
            BatchFeature with fields: input_ids, attention_mask, pixel_values, grid_thws.
        """
        if messages is None and (medias is None or text is None):
            raise ValueError(
                "Provide either 'messages' or both 'medias' and 'text'")

        # The original "medias and text both given" branch duplicated the
        # general flow below; both paths are unified with identical behavior.
        if medias is None:
            medias = self._extract_medias_from_messages(messages)
        updated_medias, video_prompts = self.preprocess_medias(medias)
        preprocessed = self.media_processor.preprocess(
            updated_medias, return_tensors=return_tensors)

        # Generate text if not provided
        if text is None:
            text = self.tokenizer.apply_chat_template(messages, **kwargs)

        text = self.update_raw_text(text, video_prompts)

        text_inputs = self.tokenizer(text, return_tensors=return_tensors)
        return BatchFeature(data={**text_inputs, **preprocessed.data})

    @staticmethod
    def _extract_medias_from_messages(messages: list[dict]) -> list[dict]:
        """
        Extract media items from messages in a single pass.

        This is an optimized version that processes messages only once.
        Kept as internal method since external callers should use __call__.
        """
        medias = []
        for msg in messages:
            # Only user messages with non-empty content may carry media.
            if msg['role'] != 'user' or not msg.get('content'):
                continue

            for content_part in msg['content']:
                if not isinstance(content_part, dict):
                    continue

                content_type = content_part.get('type')
                if content_type in ['video_url', 'video']:
                    medias.append({
                        'type': 'video',
                        'video': content_part['video_url']['url'],
                        'first_frame_timestamp': 0.0
                    })
                elif content_type in ['image_url', 'image']:
                    # NOTE(review): parts with type 'image' are also read via
                    # the 'image_url' key — confirm that schema with callers.
                    medias.append({
                        'type': 'image',
                        'image': content_part['image_url'],
                    })
        return medias

    def apply_chat_template(self, messages, **kwargs):
        """Delegate chat templating to the wrapped tokenizer."""
        return self.tokenizer.apply_chat_template(messages, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """Delegate batch decoding to the wrapped tokenizer."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate decoding to the wrapped tokenizer."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        # Union of tokenizer and image-processor model inputs.
        return ['input_ids', 'attention_mask', 'pixel_values', 'grid_thws']
kimi_k25_vision_processing.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image processor class for Kimi-K2.5.
2
+ """
3
+
4
+ import json
5
+ from typing import Any, Dict, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from transformers.image_processing_utils import (BaseImageProcessor,
11
+ BatchFeature)
12
+ from transformers.utils import TensorType
13
+
14
+ from .media_utils import (MediaInput, VideoChunkInput, _to_tensor,
15
+ ensure_media_type, get_video_meta, image_to_np,
16
+ navit_patchify, navit_resize_image,
17
+ navit_resize_video, normalize,
18
+ real_sample_fps_and_max_num_frames, timestamp_as_str)
19
+
20
+ try:
21
+ from mecord import VideoReader
22
+ except ImportError:
23
+ VideoReader = None
24
+
25
+
26
def resampling(video_bytes: bytes,
               sample_indices: list[int],
               key_indices=None,
               frame_time_info=None,
               num_threads=4) -> list[Image.Image]:
    """Decode the frames at *sample_indices* from an in-memory video.

    Args:
        video_bytes: Raw (already base64-decoded) video container bytes.
        sample_indices: Frame indices to extract, in presentation order.
        key_indices: Optional keyframe index hints to speed up seeking.
        frame_time_info: Optional per-frame timing hints for the reader.
        num_threads: Decoder thread count.

    Returns:
        The extracted frames as PIL images. (Bug fix: the return annotation
        previously said ``str``.)

    Raises:
        ImportError: If the optional ``mecord`` dependency is missing.
    """
    if VideoReader is None:
        # mecord is an optional dependency; fail with a clear message instead
        # of "'NoneType' object is not callable".
        raise ImportError(
            "mecord is required for video decoding but is not installed")
    video = VideoReader(video_bytes,
                        num_threads=num_threads,
                        frame_time_info=frame_time_info,
                        key_indices=key_indices)
    # extract target frames
    frames = video[sample_indices]
    return [Image.fromarray(frame) for frame in frames]
39
+
40
+
41
class KimiK25VisionProcessor(BaseImageProcessor):
    """Vision pre-processor for Kimi-K2.5.

    Resizes and pads images / video chunks to patch-aligned sizes, normalizes
    pixel values, and patchifies them into the flattened NaViT layout
    (``pixel_values`` + ``grid_thws``) consumed by the vision tower.
    """

    model_type = "kimi_k25"

    def __init__(
        self,
        media_proc_cfg: dict,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Raw preprocessing config (patch size, token limits, normalization, ...).
        self.media_proc_cfg = media_proc_cfg
        # Number of consecutive sampled frames folded into one video chunk.
        self.num_frames_per_chunk = media_proc_cfg[
            'temporal_merge_kernel_size']

    def media_tokens_calculator(self, media: MediaInput) -> int:
        """Return the number of LLM tokens this media item will occupy."""
        media = ensure_media_type(media)
        ret = self.get_resize_config(media)
        return ret['num_tokens']

    @classmethod
    def make_chunk_prompt(cls, timestamp_text: str) -> str:
        """Build the media-token prompt that wraps a single video chunk."""
        return f"{timestamp_text}<|media_begin|>video<|media_content|><|media_pad|><|media_end|>"

    def split_video_chunks(self,
                           video_url: str | bytes) -> list[VideoChunkInput]:
        """Sample frames from a video and group them into timestamped chunks.

        Note: ``video_url`` should be a base64 data-URL string or raw bytes.
        Frames are sampled at ``sample_fps`` (capped by the source fps) and
        grouped in runs of ``temporal_merge_kernel_size``; each chunk carries
        a prompt timestamped with its first frame's presentation time.

        Returns:
            One ``VideoChunkInput`` per chunk. (Bug fix: the return
            annotation previously said ``list[list[Image.Image]]``.)
        """
        video_spec = get_video_meta(video_url)
        sample_fps = min(self.media_proc_cfg['sample_fps'], video_spec.fps)
        sampled_nframes = max(
            round(video_spec.num_frames * sample_fps / video_spec.fps), 1)
        frame_inds = np.linspace(0, video_spec.num_frames - 1,
                                 sampled_nframes).round().astype(int)
        frame_inds = frame_inds.tolist()
        sampled_frame_ids = []
        temporal_merge_kernel_size = self.media_proc_cfg[
            "temporal_merge_kernel_size"]
        num_chunks = 0
        chunk_timestamp = []
        for i in range(0, len(frame_inds), temporal_merge_kernel_size):
            sampled_frame_ids.extend(frame_inds[i:i +
                                                temporal_merge_kernel_size])
            # Each chunk is labeled with the timestamp of its first frame.
            start_time = frame_inds[i] / float(video_spec.fps)
            timestamp_text = timestamp_as_str(
                start_time, self.media_proc_cfg["timestamp_mode"])
            chunk_timestamp.append(timestamp_text)
            num_chunks += 1

        sampled_frames = resampling(video_url, sampled_frame_ids)
        chunks = []
        for chunk_id in range(num_chunks):
            chunk = sampled_frames[chunk_id *
                                   temporal_merge_kernel_size:(chunk_id + 1) *
                                   temporal_merge_kernel_size]
            chunks.append(
                VideoChunkInput(type="video_chunk",
                                video_chunk=chunk,
                                prompt=self.make_chunk_prompt(
                                    chunk_timestamp[chunk_id])))
        return chunks

    def get_resize_config(self, media_input: MediaInput) -> dict:
        """Compute the resize/pad/token plan for one media item.

        Returns the dict produced by ``navit_resize_image`` /
        ``navit_resize_video`` (new size, padding, token count).
        """
        if media_input['type'] == 'image':
            w, h = media_input['image'].size
            ret = navit_resize_image(
                w, h, self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                self.media_proc_cfg['in_patch_limit'],
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['fixed_output_tokens'])
            return ret
        elif media_input['type'] == 'video_chunk':
            frame = media_input['video_chunk'][0]
            width, height = frame.size
            num_frames = len(media_input["video_chunk"])
            # Chunk frames are already sampled; a nominal fps of 1.0 combined
            # with the infinite sample_fps below keeps every frame.
            fps = 1.0

            sample_fps, max_num_frames_each_video = real_sample_fps_and_max_num_frames(
                media_input["type"],
                self.media_proc_cfg['sample_fps'],
                self.media_proc_cfg['max_num_frames_each_video'],
            )

            in_patch_limit_each_frame = self.media_proc_cfg[
                'in_patch_limit_each_frame']
            if in_patch_limit_each_frame is None:
                in_patch_limit_each_frame = self.media_proc_cfg[
                    'in_patch_limit']

            ret = navit_resize_video(
                width,
                height,
                num_frames,
                fps,
                sample_fps,
                self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                in_patch_limit_each_frame,
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['in_patch_limit_video'],
                max_num_frames_each_video,
                self.media_proc_cfg['fixed_output_tokens'],
            )
            return ret
        else:
            raise ValueError("Unsupported type: {}".format(
                media_input['type']))

    def resize_image(self, image: Image.Image, new_width: int, new_height: int,
                     pad_width: int, pad_height: int) -> np.ndarray:
        """Resize *image* and zero-pad its bottom/right edges."""
        image_np = image_to_np(image, (new_width, new_height), "resize")
        # Pad only bottom and right so pixel coordinates stay anchored at (0, 0).
        image_np = np.pad(
            image_np,
            ((0, pad_height), (0, pad_width), (0, 0)),
            mode="constant",
            constant_values=0,
        )
        return image_np

    def preprocess(
        self,
        medias: list[MediaInput],
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """
        Preprocess atomic vision inputs (images/video chunks) into model-ready tensors.

        Args:
            medias: List of MediaInput.
            return_tensors: Desired output format ('pt', 'np', 'tf', or None).

        Returns:
            BatchFeature containing 'pixel_values' and 'grid_thws' tensors.
        """
        if not isinstance(medias, list):
            medias = [medias]
        if medias:
            pixel_values = []
            for item in medias:
                item = ensure_media_type(item)
                resize_config = self.get_resize_config(item)
                new_width, new_height, pad_width, pad_height = resize_config[
                    'new_width'], resize_config['new_height'], resize_config[
                        'pad_width'], resize_config['pad_height']
                if item['type'] == 'image':
                    image = item['image']
                    image_np = self.resize_image(image, new_width, new_height,
                                                 pad_width, pad_height)
                    # Images become a single-frame stack: (1, H, W, C).
                    pixel_values.append(np.expand_dims(image_np, axis=0))
                elif item['type'] == 'video_chunk':
                    pixels = []
                    for frame in item['video_chunk']:
                        frame_np = self.resize_image(frame, new_width,
                                                     new_height, pad_width,
                                                     pad_height)
                        pixels.append(frame_np)
                    pixel_values.append(np.stack(pixels, axis=0))
                else:
                    raise ValueError("Unsupported type: {}".format(
                        item['type']))
            normalized_pixel_values = []
            # Precompute 1/std so normalize() can multiply instead of divide.
            image_std_inv = 1.0 / np.array(self.media_proc_cfg['image_std'])
            image_mean = np.array(self.media_proc_cfg['image_mean'])
            for pixels in pixel_values:
                pixels = normalize(pixels, image_mean, image_std_inv)
                pixels_and_thw = navit_patchify(
                    pixels,
                    self.media_proc_cfg['patch_size'],
                )
                normalized_pixel_values.append(pixels_and_thw)

            # Flattened patches from all media are concatenated into one tensor;
            # grid_thws records each item's (t, h, w) patch grid.
            pixel_values = torch.cat([
                _to_tensor(pixel_value['pixel_values'])
                for pixel_value in normalized_pixel_values
            ])
            grid_thws = torch.cat([
                _to_tensor(pixel_value['grid_thw'],
                           dtype=torch.int64).unsqueeze(0)
                for pixel_value in normalized_pixel_values
            ])

            data = {
                'pixel_values': pixel_values,
                'grid_thws': grid_thws,
            }

        else:
            data = {}

        return BatchFeature(data=data, tensor_type=return_tensors)

    def __repr__(self):
        return f"KimiK25VisionProcessor(media_proc_cfg={self.media_proc_cfg})"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the processor config; the helper object is not serializable."""
        output = super().to_dict()
        output["media_proc_cfg"] = self.media_proc_cfg
        if "media_processor" in output:
            del output["media_processor"]
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        """Rebuild the processor from a serialized config dict."""
        config = config_dict.copy()
        media_proc_cfg = config.pop("media_proc_cfg", {})
        return cls(media_proc_cfg=media_proc_cfg, **config, **kwargs)

    def to_json_string(self):
        """JSON-serialize the config, converting array-likes to lists."""
        dictionary = self.to_dict()
        for key, value in dictionary.items():
            if hasattr(value, 'tolist'):
                dictionary[key] = value.tolist()
        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
media_utils.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import math
4
+ import os
5
+ from datetime import datetime, timezone
6
+ from typing import List, Literal, Optional, TypedDict
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+ from pydantic import BaseModel, Field
11
+
12
+ try:
13
+ from mecord import VideoReader
14
+ except ImportError:
15
+ VideoReader = None
16
+
17
+
18
class VideoSpec(BaseModel):
    """Static metadata describing a video stream."""

    # Bug fix: was `media_type: str = Literal['video']`, which made the
    # *default value* the typing.Literal object instead of the string 'video'.
    media_type: Literal['video'] = 'video'
    height: int = Field(..., gt=0, description="video frame height")
    width: int = Field(..., gt=0, description="video frame width")
    num_frames: int = Field(..., gt=0, description="num frames")
    fps: float = Field(..., gt=0, description="average fps")

    # optional, help to accelerate video reading
    # Bug fix: a default of None requires a nullable annotation under pydantic.
    key_indices: list[int] | None = Field(None, description="key indices")
    frame_time_info: dict | None = Field(None, description="frame time info")
28
+
29
+
30
class ImageInput(TypedDict):
    """A single still-image media payload."""
    type: Literal['image']
    image: Image.Image
33
+
34
+
35
class VideoChunkInput(TypedDict):
    """A short run of consecutive video frames plus its chunk prompt."""
    type: Literal['video_chunk']
    video_chunk: List[Image.Image]
    # Bug fix: TypedDict fields cannot declare default values (`= None` is
    # invalid in a TypedDict body); the key stays annotated as nullable.
    prompt: Optional[str]
39
+
40
+
41
+ MediaInput = ImageInput | VideoChunkInput
42
+
43
+
44
def get_video_meta(video_src: bytes | str | os.PathLike,
                   accurate: bool = True) -> VideoSpec:
    """Probe a video and return its metadata.

    Args:
        video_src: Raw video bytes, a filesystem path, or a
            ``data:video/mp4;base64,`` data URL.
        accurate: When True, fully initialize the reader for exact counts.

    Returns:
        A :class:`VideoSpec` with dimensions, frame count, fps and decoder
        hints. (Bug fix: the return annotation previously said ``dict``.)

    Raises:
        ImportError: If the optional ``mecord`` dependency is missing.
        AssertionError: If the probed metadata is invalid.
    """
    if VideoReader is None:
        # mecord is an optional dependency; fail loudly instead of calling None.
        raise ImportError(
            "mecord is required for video metadata probing but is not installed")
    if isinstance(video_src, os.PathLike):
        video_src = str(video_src)
    # if b64 string, decode to bytes
    if isinstance(video_src,
                  str) and video_src.startswith('data:video/mp4;base64,'):
        video_src = base64.b64decode(video_src.split(',')[1])
    video = VideoReader(video_src, auto_init=accurate, num_threads=1)
    assert video.num_frames > 0, "Invalid video format."
    assert video.original_width > 0 and video.original_height > 0, (
        "Invalid video format.")
    assert video.avg_fps > 0, "Invalid video format."
    return VideoSpec(media_type='video',
                     height=video.original_height,
                     width=video.original_width,
                     num_frames=video.num_frames,
                     fps=video.avg_fps,
                     key_indices=video.key_indices,
                     frame_time_info=video.frame_time_info)
65
+
66
+
67
def timestamp_as_str(timestamp: float,
                     timestamp_mode: str = "hh:mm:ss.fff") -> str:
    """Format *timestamp* (seconds) as a wall-clock style string.

    Supported modes: ``hh:mm:ss.fff``, ``mm:ss.fff`` and ``mm:ss``.
    """
    moment = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    # Truncated millisecond suffix of the fractional part.
    millis = f".{int((timestamp % 1) * 1000):03d}"
    if timestamp_mode == "hh:mm:ss.fff":
        return moment.strftime("%H:%M:%S") + millis
    if timestamp_mode == "mm:ss.fff":
        return moment.strftime("%M:%S") + millis
    if timestamp_mode == "mm:ss":
        return moment.strftime("%M:%S")
    raise ValueError(f"Invalid timestamp mode: {timestamp_mode}")
83
+
84
+
85
+ def navit_resize_image(
86
+ width: int,
87
+ height: int,
88
+ patch_size: int,
89
+ merge_kernel_size: int,
90
+ in_patch_limit: int,
91
+ patch_limit_on_one_side: int,
92
+ fixed_output_tokens: int | None,
93
+ ):
94
+ # Apply the patch limits.
95
+ s1 = math.sqrt(
96
+ in_patch_limit /
97
+ (max(1.0, width // patch_size) * max(1.0, height // patch_size)))
98
+ s2 = patch_limit_on_one_side * patch_size / width
99
+ s3 = patch_limit_on_one_side * patch_size / height
100
+ scale = min(1.0, s1, s2, s3)
101
+ new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale))
102
+ new_w = min(new_w, patch_limit_on_one_side * patch_size)
103
+ new_h = min(new_h, patch_limit_on_one_side * patch_size)
104
+
105
+ # Calculate the padding to make the height and width divisible by the merge kernel size and patch size.
106
+ factor = merge_kernel_size * patch_size
107
+
108
+ pad_height = (factor - new_h % factor) % factor
109
+ pad_width = (factor - new_w % factor) % factor
110
+
111
+ if fixed_output_tokens is not None:
112
+ num_tokens = fixed_output_tokens
113
+ else:
114
+ # Calculate new dimensions after padding and patching
115
+ token_height = (new_h + pad_height) // factor
116
+ token_width = (new_w + pad_width) // factor
117
+
118
+ assert token_height * merge_kernel_size <= patch_limit_on_one_side, (
119
+ f"token_height {token_height} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
120
+ )
121
+ assert token_width * merge_kernel_size <= patch_limit_on_one_side, (
122
+ f"token_width {token_width} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
123
+ )
124
+
125
+ num_tokens = token_height * token_width
126
+ return {
127
+ "num_tokens": num_tokens,
128
+ "new_width": new_w,
129
+ "new_height": new_h,
130
+ "pad_width": pad_width,
131
+ "pad_height": pad_height,
132
+ "sampled_nframes": 1,
133
+ }
134
+
135
+
136
def navit_resize_video(
    width: int,
    height: int,
    nframes: int,
    avg_fps: float,
    sample_fps: float,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit_each_frame: int,
    patch_limit_on_one_side: int,
    in_patch_limit_total: int | None,
    max_num_frames_each_video: int | None,
    fixed_output_tokens_each_frame: int | None,
):
    """Plan the per-frame resize for a video plus its sampled-frame count.

    Returns the dict from ``navit_resize_image`` with ``sampled_nframes``
    replaced by the computed frame budget.
    """
    # The sampling rate can never exceed the source frame rate.
    effective_fps = min(sample_fps, avg_fps)
    # Number of frames to sample at the target fps, at least one.
    frame_budget = max(round(nframes * effective_fps / avg_fps), 1)
    if max_num_frames_each_video is not None:
        frame_budget = min(frame_budget, max_num_frames_each_video)

    # A whole-video patch budget is divided evenly across the sampled frames.
    per_frame_limit = in_patch_limit_each_frame
    if in_patch_limit_total is not None:
        per_frame_limit = min(round(in_patch_limit_total / frame_budget),
                              per_frame_limit)

    plan = navit_resize_image(
        width,
        height,
        patch_size,
        merge_kernel_size,
        per_frame_limit,
        patch_limit_on_one_side,
        fixed_output_tokens_each_frame,
    )
    plan["sampled_nframes"] = frame_budget
    return plan
172
+
173
+
174
+ def real_sample_fps_and_max_num_frames(
175
+ type_name: Literal["video", "video_chunk"],
176
+ sample_fps: float,
177
+ max_num_frames_each_video: int | None,
178
+ ) -> tuple[int, int | None]:
179
+ if type_name == "video":
180
+ return sample_fps, max_num_frames_each_video
181
+ elif type_name == "video_chunk":
182
+ max_num_frames_each_video = None
183
+ sample_fps = math.inf
184
+ return sample_fps, max_num_frames_each_video
185
+ else:
186
+ return math.inf, None
187
+
188
+
189
def _to_pil(data: "str | bytes | Image.Image"):
    """Coerce an image payload into an RGB PIL image.

    Accepts an already-decoded PIL image, a data-URL or filesystem-path
    string, or raw encoded bytes. (Fix: the annotation previously said
    ``str | bytes`` even though the first branch handles ``Image.Image``.)

    Raises:
        ValueError: for any other payload type.
    """
    if isinstance(data, Image.Image):
        return data.convert("RGB")
    elif isinstance(data, str):
        if data.startswith("data:"):
            # Data URL, e.g. "data:image/png;base64,...." — assumes the
            # payload after the first comma is base64-encoded.
            raw_base64 = data.split(",")[1]
            return Image.open(io.BytesIO(
                base64.b64decode(raw_base64))).convert("RGB")
        else:
            # Any other string is treated as a filesystem path.
            return Image.open(data).convert("RGB")
    elif isinstance(data, bytes):
        return Image.open(io.BytesIO(data)).convert("RGB")
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")
204
+
205
+
206
def ensure_media_type(media: MediaInput) -> MediaInput:
    """Decode raw image payloads in place so downstream code sees PIL images.

    Mutates and returns the same ``media`` dict. Raises ``ValueError`` for
    media types other than "image" and "video_chunk".
    """
    kind = media['type']
    if kind == 'image':
        media['image'] = _to_pil(media['image'])
    elif kind == 'video_chunk':
        media['video_chunk'] = [_to_pil(frame) for frame in media['video_chunk']]
    else:
        raise ValueError(f"Unsupported media type: {media['type']}")
    return media
217
+
218
+
219
def _rescale_within(image, resize_to, raise_error_for_ill_resize):
    # Downscale (never upscale) preserving aspect ratio so the image fits in
    # `resize_to`; returns the resized PIL image, or None when the target
    # collapses to zero pixels and errors are suppressed.
    scale = min(resize_to[0] / image.width, resize_to[1] / image.height, 1.0)
    new_width = round(image.width * scale)
    new_height = round(image.height * scale)
    if new_width == 0 or new_height == 0:
        if raise_error_for_ill_resize:
            raise ValueError(
                f"Invalid resize to: {resize_to}, from image size: {image.size}"
            )
        return None
    return image.resize((new_width, new_height),
                        resample=Image.Resampling.BICUBIC)


def _pad_to(arr, resize_to, center):
    # Zero-pad `arr` (H, W, 3) up to (resize_to[1], resize_to[0], 3), either
    # centered or anchored at the top-left (padding on the right/bottom).
    h, w = arr.shape[:2]
    pad_w = resize_to[0] - w
    pad_h = resize_to[1] - h
    left = pad_w // 2 if center else 0
    top = pad_h // 2 if center else 0
    arr = np.pad(
        arr,
        ((top, pad_h - top), (left, pad_w - left), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    assert arr.shape == (resize_to[1], resize_to[0], 3)
    return arr


def image_to_np(
    image: Image.Image,
    resize_to: tuple[int, int] | None = None,
    mode: str = "resize",
    raise_error_for_ill_resize: bool = True,
) -> np.ndarray:
    """Convert an image to a numpy array, optionally resizing/padding first.

    Args:
        image: The PIL image to convert.
        resize_to: Target (width, height); None keeps the original size.
        mode: "resize" (plain bicubic resize), "rescale_and_pad_to_center"
            or "rescale_and_pad_to_rightbottom" (aspect-preserving downscale
            plus zero padding).
        raise_error_for_ill_resize: Whether to raise when the aspect-preserving
            resize would produce a zero-sized image; if False, an all-zero
            array of the target size is returned instead.

    Returns:
        A numpy array of shape (H, W, 3).
    """
    assert isinstance(image, Image.Image), "image must be a PIL Image"
    if resize_to is not None:
        if mode == "resize":
            image = image.resize(resize_to, resample=Image.Resampling.BICUBIC)
        elif mode in ("rescale_and_pad_to_center",
                      "rescale_and_pad_to_rightbottom"):
            # The two pad modes shared ~20 duplicated lines; factored into
            # _rescale_within + _pad_to (behavior unchanged).
            resized = _rescale_within(image, resize_to,
                                      raise_error_for_ill_resize)
            if resized is None:
                return np.zeros((resize_to[1], resize_to[0], 3),
                                dtype=np.uint8)
            image = _pad_to(np.asarray(resized), resize_to,
                            center=(mode == "rescale_and_pad_to_center"))
        else:
            raise ValueError(f"Invalid mode: {mode}")

    if isinstance(image, Image.Image):
        return np.asarray(image)
    return image
305
+
306
+
307
def navit_patchify(pixel_values: np.ndarray,
                   patch_size: int) -> dict[str, np.ndarray]:
    """Cut frames into non-overlapping square patches (NaViT layout).

    Args:
        pixel_values: np.ndarray of shape (t, h, w, c) with c == 3.
        patch_size: side length of each square patch.

    Returns:
        dict with:
        - "pixel_values": (t * h//patch_size * w//patch_size, c, patch_size, patch_size)
        - "grid_thw": np.ndarray [t, h//patch_size, w//patch_size]
    """
    num_frames, height, width, channels = pixel_values.shape
    assert channels == 3, "pixel_values must have 3 channels"

    rows = height // patch_size
    cols = width // patch_size
    blocks = pixel_values.reshape(num_frames, rows, patch_size, cols,
                                  patch_size, channels)
    # -> (t, rows, cols, c, patch_size, patch_size), then flatten patches.
    blocks = blocks.transpose(0, 1, 3, 5, 2, 4)
    flat_patches = blocks.reshape(-1, channels, patch_size, patch_size)
    return {
        "pixel_values": flat_patches,
        "grid_thw": np.array([num_frames, rows, cols]),
    }
330
+
331
+
332
def normalize(x: np.ndarray,
              mean,
              std_inv,
              pixels_dtype: np.dtype = np.float32) -> np.ndarray:
    """Map uint8 pixels in [0, 255] to normalized floats.

    Computes ``((x / 255) - mean) * std_inv`` using in-place updates so the
    result keeps ``pixels_dtype`` regardless of the dtypes of ``mean`` and
    ``std_inv``.

    Args:
        x: Pixel array of shape (..., 3), dtype uint8, range [0, 255].
        mean: Scalar or broadcastable per-channel mean to subtract.
        std_inv: Scalar or broadcastable reciprocal std to multiply by.
        pixels_dtype: dtype of the returned array.

    Returns:
        Normalized array of shape (..., 3) with dtype ``pixels_dtype``.
    """
    scaled = (x / 255.0).astype(pixels_dtype)
    scaled -= mean
    scaled *= std_inv
    return scaled
350
+
351
+
352
+ def _to_tensor(data, **kwargs):
353
+ import torch
354
+
355
+ if isinstance(data, np.ndarray):
356
+ return torch.from_numpy(data).to(**kwargs)
357
+ elif isinstance(data, torch.Tensor):
358
+ return data.to(**kwargs)
359
+ elif isinstance(data, list):
360
+ return [_to_tensor(item, **kwargs) for item in data]
361
+ elif isinstance(data, tuple):
362
+ return tuple(_to_tensor(item, **kwargs) for item in data)
363
+ elif isinstance(data, dict):
364
+ return {k: _to_tensor(v, **kwargs) for k, v in data.items()}
365
+ elif data is None:
366
+ return None
367
+ else:
368
+ raise ValueError(f"Unsupported data type: {type(data)}")
model-00067-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:015093acc292c291f8121bf31b1df34ab43858bc1f7df61ac5339b9553a39a97
3
+ size 2997644120
model-00068-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8573fbb3b9b6a3cbb1b16f6969bb98fcfca03c7762e077d5650c4689a02a5372
3
+ size 2997643656
model-00069-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bee387741974f5303469f9342a925e9f7dc4731f43abb007ae6ff5603acd5af5
3
+ size 2995075328
model-00071-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed264e7c759f0c17471938bdb5ad2eca524278ef2a10dcf8ea56ff8260f13463
3
+ size 2995074848
model-00074-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1c8779c070caac11bce56f6f2c38cc746c00e6013d9055f1c03b540388a149a
3
+ size 2995075088
model-00108-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ac4420cb056a6e5ab7312c8b888f20c9fc83d302b10633fa5542c84d4f3b01f
3
+ size 2995075136
model-00112-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe111ec3a43e96aa103a8c29b4b1cfe608293aad46594a304d998202f5602052
3
+ size 2997644072
model-00113-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58e22bd7c4cebbd46c9c6bcf457b17f4e75b76156c763805adf698c798fbfa48
3
+ size 2995074888
model-00114-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a13c2919b2eda729a428090d74e13ae9b3e24c9d36c08fbbc24f20fefb2927a
3
+ size 2997644120
model-00115-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:984b18f13d59c8f816b62ac22e893214e1fdd0a2c63cba2bc5dc2ae8d9cc7cf0
3
+ size 2997643672
model-00116-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cbf71faf91ee314b0912104c90a91a56c0a0db7de3446ddd8d99e17f0f234f4
3
+ size 2995075312
model-00117-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ec12ec2420671ab0b97f0e583b5fc91662f332fa03818b0525453c9defb9cb5
3
+ size 2997644120
model-00118-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3689dc18ebd537129e5ddce6f628920c0dd42bfef024afc746bfbbb0e0f7d613
3
+ size 2995074848
model-00119-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0943e0f70dfea18de632298a0c7e0cf01fe1aaabfd5c4332a9663b9aa16cefd8
3
+ size 2997644120
model-00120-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c12031e6e48f120ffda14db97419e536a292557d28564320794e6778b164310
3
+ size 2997643912
model-00121-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f0390d0a804ef27543401794e659e49b9691b0f4b746c5aa254bf68200dde92
3
+ size 2995075072
model-00122-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efdb09a2170b76999408a8390e12b2f3fed01e32ac98ecf966594fb293a9d372
3
+ size 2997644128
model-00123-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9715dd553effe3b0471aa2797cc7edb74fb6bad1059541522944013365c3ac1
3
+ size 2997643488
model-00124-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c554e040c157eb9c064100aa707e9c209e70eb431364e7947f528d0b66a756f2
3
+ size 2995075488
model-00125-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e50962f521a9dea55953bbaaef6c94b7ad357334d088bb8435344c870590bd0b
3
+ size 2997644104
model-00126-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a85ed284c50b83dcdce790267710fa700869b09376a5951071750ab6af8a88da
3
+ size 2995074864
model-00127-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1262da6377bcdedb3748c4bb2a0c5bc72dd1554bde887767734f61be3d407780
3
+ size 2997644120
model-00128-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4bcd5134e6a894fd240b9733464ea2a5b50629cd362b3b76ba944e4f36171c39
3
+ size 2997643728
model-00160-of-00160.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05cdcb42b66c30b4f236578e8eb17c3f501e6bf0b9ea74216c39af9a1cb99710
3
+ size 293085592
model.safetensors.index.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b34bef0556ddc02f77caa9f39dbc9834fa3ccd6811a7feea4d7b8ad3d979f77
3
+ size 18528168
modeling_deepseek.py ADDED
@@ -0,0 +1,1808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch DeepSeek model."""
21
+ import math
22
+ import warnings
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.distributed as dist
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+ from transformers.modeling_attn_mask_utils import \
35
+ _prepare_4d_causal_attention_mask
36
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
37
+ CausalLMOutputWithPast,
38
+ SequenceClassifierOutputWithPast)
39
+ from transformers.modeling_utils import PreTrainedModel
40
+ from transformers.pytorch_utils import (ALL_LAYERNORM_LAYERS,
41
+ is_torch_greater_or_equal_than_1_13)
42
+ from transformers.utils import (add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ is_flash_attn_2_available,
45
+ is_flash_attn_greater_or_equal_2_10, logging,
46
+ replace_return_docstrings)
47
+ from transformers.utils.import_utils import is_torch_fx_available
48
+
49
+ from .configuration_deepseek import DeepseekV3Config
50
+
51
+ if is_flash_attn_2_available():
52
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
53
+ from flash_attn.bert_padding import pad_input # noqa
54
+ from flash_attn.bert_padding import index_first_axis, unpad_input
55
+
56
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
57
+ # It means that the function will not be traced through and simply appear as a node in the graph.
58
+ if is_torch_fx_available():
59
+ if not is_torch_greater_or_equal_than_1_13:
60
+ import torch.fx
61
+
62
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(
63
+ _prepare_4d_causal_attention_mask)
64
+
65
+ logger = logging.get_logger(__name__)
66
+
67
+ _CONFIG_FOR_DOC = "DeepseekV3Config"
68
+
69
+
70
+ def _get_unpad_data(attention_mask):
71
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
72
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
73
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
74
+ cu_seqlens = F.pad(
75
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
76
+ return (
77
+ indices,
78
+ cu_seqlens,
79
+ max_seqlen_in_batch,
80
+ )
81
+
82
+
83
# code modified from transformers 4.48.3 to amend breaks in newer transformers versions
def get_usable_length(past_key_value,
                      new_seq_length: int,
                      layer_idx: Optional[int] = 0) -> int:
    """Return how much of the KV cache is usable before appending new tokens.

    Mirrors the removed ``DynamicCache.get_usable_length``: with an unbounded
    cache this is simply the current cached length; with a bounded cache that
    would overflow, it is the room left after reserving ``new_seq_length``.
    """
    limit = past_key_value.get_max_cache_shape()
    cached = past_key_value.get_seq_length(layer_idx)
    if limit is None or limit <= 0:
        return cached
    if cached + new_seq_length <= limit:
        return cached
    return limit - new_seq_length
92
+
93
+
94
class DeepseekV3RMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm).

    Normalizes in float32 for numerical stability, then casts back to the
    input dtype and applies a learned per-channel scale.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * x.to(orig_dtype)
111
+
112
+
113
+ ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)
114
+
115
+
116
class DeepseekV3RotaryEmbedding(nn.Module):
    """Standard rotary position embedding with lazily-grown cos/sin caches.

    Tables are precomputed up to ``max_position_embeddings`` at construction
    and rebuilt on demand when ``forward`` sees a longer sequence.
    """

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequency for each even channel index: base^(-2i/dim).
        inv_freq = 1.0 / (self.base**(
            torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # NOTE(review): reset to None after pre-building, which forces a
        # rebuild (with the runtime device/dtype) on the first forward call.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Recompute and cache cos/sin tables for positions [0, seq_len).
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached",
                             emb.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             emb.sin().to(dtype),
                             persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len,
                                    device=x.device,
                                    dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
167
+
168
+
169
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3
class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Set before super().__init__(), which calls _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)
        # Positions are divided by the scaling factor, linearly stretching
        # the rotary period to cover longer contexts.
        t = t / self.scaling_factor

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached",
                             emb.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             emb.sin().to(dtype),
                             persistent=False)
200
+
201
+
202
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3
class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Set before super().__init__(), which calls _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NTK-aware rescaling: grow the rotary base (rather than
            # compressing positions) once the sequence exceeds the trained
            # context, then rebuild inv_freq from the new base.
            base = self.base * ((self.scaling_factor * seq_len /
                                 self.max_position_embeddings) -
                                (self.scaling_factor - 1))**(self.dim /
                                                             (self.dim - 2))
            inv_freq = 1.0 / (base**(
                torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached",
                             emb.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             emb.sin().to(dtype),
                             persistent=False)
242
+
243
+
244
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(num_rotations,
                             dim,
                             base=10000,
                             max_position_embeddings=2048):
    """Return the (fractional) channel index whose rotary wavelength makes
    exactly ``num_rotations`` full turns over the trained context length."""
    wavelength_ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    return (dim * math.log(wavelength_ratio)) / (2 * math.log(base))
252
+
253
+
254
# Find dim range bounds based on rotations
def yarn_find_correction_range(low_rot,
                               high_rot,
                               dim,
                               base=10000,
                               max_position_embeddings=2048):
    """Return integer channel bounds [low, high] between which YaRN blends
    interpolated and extrapolated rotary frequencies."""
    lower = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    upper = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    # Clamp values just in case
    return max(lower, 0), min(upper, dim - 1)
265
+
266
+
267
def yarn_get_mscale(scale=1, mscale=1):
    """YaRN attention-magnitude correction: 1.0 when no scaling is applied,
    otherwise a logarithmic boost ``0.1 * mscale * ln(scale) + 1``."""
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
271
+
272
+
273
def yarn_linear_ramp_mask(min, max, dim):
    """Linear ramp from 0 to 1 over the index interval [min, max], clamped.

    Returns a float32 tensor of length ``dim`` where entries below ``min``
    are 0, entries above ``max`` are 1, and the span in between rises
    linearly.

    NOTE: the parameter names shadow the builtins ``min``/``max``; they are
    kept for backward compatibility with existing callers, but the body
    aliases them immediately so the builtins stay usable.
    """
    lo, hi = min, max
    if lo == hi:
        hi += 0.001  # Prevent singularity
    linear_func = (torch.arange(dim, dtype=torch.float32) - lo) / (hi - lo)
    return torch.clamp(linear_func, 0, 1)
280
+
281
+
282
class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """YaRN-scaled rotary embedding (NTK-by-parts interpolation + mscale)."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # All YaRN knobs must be set before super().__init__(), which
        # triggers the first _set_cos_sin_cache call.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        # Extrapolated (original base) vs interpolated (scaled) frequencies.
        freq_extra = 1.0 / (self.base**(
            torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        freq_inter = 1.0 / (self.scaling_factor * self.base**(
            torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))

        # Channels below `low` extrapolate, above `high` interpolate, and the
        # span in between blends linearly (NTK-by-parts).
        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32)
        inv_freq = freq_inter * (1 -
                                 inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(seq_len, device=device, dtype=torch.float32)

        freqs = torch.outer(t, inv_freq)

        # Attention-magnitude correction folded directly into the cached
        # cos/sin tables (ratio of the two mscale settings).
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale) /
            yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype),
                             persistent=False)
340
+
341
+
342
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)
348
+
349
+
350
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    # De-interleave the last dim: (d//2, 2) channel pairs are reordered so
    # all even channels come first, then all odd ones — matching the
    # (cos|cos), (sin|sin) half-split layout of the cached tables that
    # rotate_half assumes.
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
384
+
385
+
386
class DeepseekV3MLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``.

    Sizes default to the config's hidden/intermediate sizes but can be
    overridden (used for shared/routed experts with custom widths).
    """

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = (hidden_size
                            if hidden_size is not None else config.hidden_size)
        self.intermediate_size = (intermediate_size
                                  if intermediate_size is not None else
                                  config.intermediate_size)

        self.gate_proj = nn.Linear(self.hidden_size,
                                   self.intermediate_size,
                                   bias=False)
        self.up_proj = nn.Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False)
        self.down_proj = nn.Linear(self.intermediate_size,
                                   self.hidden_size,
                                   bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
410
+
411
+
412
class MoEGate(nn.Module):
    """Expert router: scores every token against all routed experts and picks
    ``top_k`` of them with DeepseekV3's group-limited ("noaux_tc") selection."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # top-k selection settings
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim)))
        if self.topk_method == "noaux_tc":
            # Per-expert bias used only for *selection*, never for the weights.
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((self.n_routed_experts)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Kaiming-uniform init of the gating projection."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route a ``(bsz, seq_len, hidden)`` batch.

        Returns:
            topk_idx: chosen expert ids, shape ``(bsz * seq_len, top_k)``.
            topk_weight: matching gate weights (already scaled by
                ``routed_scaling_factor``), same shape.
        """
        bsz, seq_len, hidden = hidden_states.shape
        flat = hidden_states.view(-1, hidden)

        # Gating logits are always computed in fp32 for numerical stability.
        logits = F.linear(flat.type(torch.float32),
                          self.weight.type(torch.float32), None)
        if self.scoring_func != "sigmoid":
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )
        scores = logits.sigmoid()

        if self.topk_method != "noaux_tc":
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )
        # Inference-only path: the aux-loss-free selection has no training branch.
        assert not self.training
        n_tokens = bsz * seq_len
        biased = scores.view(n_tokens, -1) + self.e_score_correction_bias.unsqueeze(0)

        # Rank expert groups by the sum of their two best (biased) scores,
        # keep only the `topk_group` best groups per token.
        group_scores = biased.view(n_tokens, self.n_group,
                                   -1).topk(2, dim=-1)[0].sum(dim=-1)  # [n, n_group]
        group_idx = torch.topk(group_scores,
                               k=self.topk_group,
                               dim=-1,
                               sorted=False)[1]  # [n, topk_group]
        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
        group_mask.scatter_(1, group_idx, 1)

        # Expand the group mask back to per-expert granularity.
        experts_per_group = self.n_routed_experts // self.n_group
        score_mask = (group_mask.unsqueeze(-1)
                      .expand(n_tokens, self.n_group, experts_per_group)
                      .reshape(n_tokens, -1))  # [n, e]
        masked_scores = biased.masked_fill(~score_mask.bool(), 0.0)
        _, topk_idx = torch.topk(masked_scores,
                                 k=self.top_k,
                                 dim=-1,
                                 sorted=False)
        # Weights come from the *unbiased* scores of the chosen experts.
        topk_weight = scores.gather(1, topk_idx)

        # Optionally renormalize the selected weights to sum to 1.
        if self.top_k > 1 and self.norm_topk_prob:
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor

        return topk_idx, topk_weight
491
+
492
+
493
class DeepseekV3MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Routed experts are selected per token by `MoEGate`; when
    `config.n_shared_experts` is set, a dense shared-expert MLP is always
    applied on top. Supports optional expert parallelism (EP) where each
    rank owns a contiguous slice of the routed experts.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        if hasattr(config, "ep_size") and config.ep_size > 1:
            # Expert-parallel setup: this rank only instantiates its own slice
            # of experts; the other slots stay None as placeholders so global
            # expert indices remain valid.
            assert config.ep_size == dist.get_world_size()
            self.ep_size = config.ep_size
            self.experts_per_rank = config.n_routed_experts // config.ep_size
            self.ep_rank = dist.get_rank()
            self.experts = nn.ModuleList([
                (DeepseekV3MLP(config,
                               intermediate_size=config.moe_intermediate_size)
                 if i >= self.ep_rank * self.experts_per_rank
                 and i < (self.ep_rank + 1) * self.experts_per_rank else None)
                for i in range(config.n_routed_experts)
            ])
        else:
            # Single-rank setup: instantiate every routed expert locally.
            self.ep_size = 1
            self.experts_per_rank = config.n_routed_experts
            self.ep_rank = 0
            self.experts = nn.ModuleList([
                DeepseekV3MLP(config,
                              intermediate_size=config.moe_intermediate_size)
                for i in range(config.n_routed_experts)
            ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            # Shared experts are fused into a single wider MLP.
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV3MLP(
                config=config, intermediate_size=intermediate_size)

    def forward(self, hidden_states):
        # `identity` keeps the un-flattened input for the shared-expert path.
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        # NOTE(review): `flat_topk_idx` is unused below — likely a leftover
        # from the removed training-time dispatch path.
        flat_topk_idx = topk_idx.view(-1)
        # NOTE(review): when `self.training` is True this method falls through
        # and implicitly returns None; this appears to be an inference-only
        # port — confirm callers only invoke it in eval mode.
        if not self.training:
            y = self.moe_infer(hidden_states, topk_idx,
                               topk_weight).view(*orig_shape)
            if self.config.n_shared_experts is not None:
                y = y + self.shared_experts(identity)
            return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Inference-time expert dispatch.

        Args:
            x: flattened tokens, shape (num_tokens, hidden).
            topk_ids: chosen expert ids per token, shape (num_tokens, top_k).
            topk_weight: matching gate weights, same shape as `topk_ids`.

        Returns:
            Tensor of shape (num_tokens, hidden): the gate-weighted sum of the
            selected experts' outputs for each token.
        """
        # Count how many tokens were routed to each expert.
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort (token, expert) pairs by expert id so each expert's tokens are
        # contiguous; `idxs // top_k` recovers the source token row.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        sorted_tokens_shape = sorted_tokens.shape
        if self.ep_size > 1:
            # --- expert-parallel exchange: send each token to the rank that
            # owns its expert, run the local experts, then send results back.
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size,
                                                        -1).sum(dim=1)
            # Exchange per-expert counts so every rank knows its incoming sizes.
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
            output_splits = (tokens_per_expert_group.view(
                self.ep_size, -1).sum(1).cpu().numpy().tolist())
            gathered_tokens = sorted_tokens.new_empty(
                tokens_per_expert_group.sum(dim=0).cpu().item(),
                sorted_tokens.shape[1])
            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
            dist.all_to_all(
                list(gathered_tokens.split(output_splits)),
                list(sorted_tokens.split(input_split_sizes)),
            )
            tokens_per_expert_post_gather = tokens_per_expert_group.view(
                self.ep_size, self.experts_per_rank).sum(dim=0)
            # Re-sort the gathered tokens by *local* expert index
            # (i % experts_per_rank maps a global slot to this rank's expert).
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0], ),
                                    dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s:s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        # Run each local expert on its contiguous slice of tokens.
        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs,
                         dim=0) if len(outputs) else sorted_tokens.new_empty(0)
        if self.ep_size > 1:
            # Undo the local re-sort, then all_to_all the results back to the
            # ranks the tokens came from (split sizes are swapped vs. above).
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
            dist.all_to_all(
                list(gathered_tokens.split(input_split_sizes)),
                list(new_x.split(output_splits)),
            )
            outs = gathered_tokens

        # Scatter expert outputs back to (token, slot) order and reduce the
        # top_k slots with their gate weights.
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (new_x.view(
            *topk_ids.shape, -1).type(topk_weight.dtype).mul_(
                topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
        return final_out
610
+
611
+
612
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
    # Fast path: nothing to repeat.
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it
    # into the head dimension.
    expanded = hidden_states[:, :, None, :, :].expand(batch, n_kv_heads, n_rep,
                                                      seq_len, head_dim)
    return expanded.reshape(batch, n_kv_heads * n_rep, seq_len, head_dim)
627
+
628
+
629
# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3
class DeepseekV3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    DeepseekV3 uses Multi-head Latent Attention (MLA): queries and key/values
    are produced through low-rank down/up projections
    (`q_lora_rank` / `kv_lora_rank`), and each head dimension is split into a
    RoPE part (`qk_rope_head_dim`) and a position-independent part
    (`qk_nope_head_dim`). The key RoPE component is shared across heads
    (MQA-style, see `kv_a_proj_with_mqa`).
    """

    def __init__(self,
                 config: DeepseekV3Config,
                 layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            # Fixed truncated message: "and will to errors" -> "and will lead to errors".
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class.")

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Full per-head query/key dim = non-RoPE part + RoPE part.
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            # Direct query projection (no low-rank bottleneck).
            self.q_proj = nn.Linear(self.hidden_size,
                                    self.num_heads * self.q_head_dim,
                                    bias=False)
        else:
            # Low-rank query path: down-project, normalize, up-project.
            self.q_a_proj = nn.Linear(self.hidden_size,
                                      config.q_lora_rank,
                                      bias=config.attention_bias)
            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank,
                                      self.num_heads * self.q_head_dim,
                                      bias=False)

        # Produces the compressed KV latent plus the shared (single-head) k RoPE part.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
        # Up-projects the latent to per-head (k_nope, v) pairs.
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads *
            (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim**(-0.5)
        if self.config.rope_scaling is not None:
            # YaRN-style mscale correction folded into the softmax scale.
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Build the rotary embedding module according to `config.rope_scaling`."""
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekV3RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Only forward the yarn knobs actually present in the config.
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ] if key in self.config.rope_scaling
                }
                self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq, heads * v_dim) -> (bsz, heads, seq, v_dim)
        return (tensor.view(bsz, seq_len, self.num_heads,
                            self.v_head_dim).transpose(1, 2).contiguous())

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Eager (matmul + softmax) attention.

        Returns a tuple of (attn_output, attn_weights or None, past_key_value).
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        # --- queries: direct or low-rank path, then split nope/rope parts.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # --- keys/values: compressed latent + shared single-head k RoPE part.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
            bsz, q_len, self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index.")
            kv_seq_len += get_usable_length(past_key_value, kv_seq_len,
                                            self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Rotary embedding only touches the rope sub-dimensions.
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Reassemble full-head queries/keys: [nope | rope]. Note k_pe has a
        # single head and broadcasts across all heads here.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                      self.q_head_dim)
        query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                    self.q_head_dim)
        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) *
            self.softmax_scale)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}")
        # Fix: dropped a redundant `assert attention_mask is not None` that made
        # the None-guard below dead code (and disappeared under `python -O`).
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(
                                                 query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.attention_dropout,
                                             training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len,
                                          self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
854
+
855
+
856
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3
class DeepseekV3FlashAttention2(DeepseekV3Attention):
    """
    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Flash-attention forward; same Q/K/V construction as the eager path,
        only the attention kernel and the required tensor layout differ.
        """
        # DeepseekV3FlashAttention2 attention does not support output_attentions
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        # Queries: direct or low-rank path, then split nope/rope parts.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
            bsz, q_len, self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # Fix: removed a duplicated `kv_seq_len = value_states.shape[-2]`.
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += get_usable_length(past_key_value, kv_seq_len,
                                            self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Reassemble full-head queries/keys: [nope | rope]; the single-head
        # k_pe broadcasts across all heads.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                      self.q_head_dim)
        query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                    self.q_head_dim)
        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

        # Flash-attn needs equal head dims for q/k/v; pad v up to q_head_dim
        # and slice the padding off again after the kernel.
        if self.q_head_dim != self.v_head_dim:
            value_states = F.pad(value_states,
                                 [0, self.q_head_dim - self.v_head_dim])

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DeepseekV3RMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = (self.q_proj.weight.dtype if self.q_lora_rank
                                is None else self.q_a_proj.weight.dtype)

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}.")

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            softmax_scale=self.softmax_scale,
        )
        # Drop the padding added to value_states (layout here is b, s, h, d).
        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, :self.v_head_dim]

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads *
                                          self.v_head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(query_states, key_states, value_states,
                                 attention_mask, query_length)

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
                                    query_length)
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
                    query_length):
        """Strip padding tokens so flash-attn's varlen kernel can be used.

        Returns the unpadded q/k/v, the gather indices for re-padding, and the
        (cu_seqlens, max_seqlen) pairs for queries and keys.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
            attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                              head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                                head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            # Prefill: queries share the key unpadding.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
                                    head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decode: one query per batch element.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
1126
+
1127
+
1128
# Dispatch table: maps `config._attn_implementation` to the attention module
# class instantiated by DeepseekV3DecoderLayer.
ATTENTION_CLASSES = {
    "eager": DeepseekV3Attention,
    "flash_attention_2": DeepseekV3FlashAttention2,
}
1132
+
1133
+
1134
class DeepseekV3DecoderLayer(nn.Module):
    """One transformer block: self-attention plus a dense or MoE MLP, both
    wrapped in pre-norm residual connections."""

    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        attn_cls = ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attn_cls(config=config, layer_idx=layer_idx)

        # Dense MLP for the first `first_k_dense_replace` layers; afterwards
        # every `moe_layer_freq`-th layer uses the MoE block.
        use_moe = (config.n_routed_experts is not None
                   and layer_idx >= config.first_k_dense_replace
                   and layer_idx % config.moe_layer_freq == 0)
        self.mlp = DeepseekV3MoE(config) if use_moe else DeepseekV3MLP(config)

        self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
                                                 torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # --- attention sub-layer (pre-norm residual) ---
        attn_input = self.input_layernorm(hidden_states)
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # --- feed-forward sub-layer (pre-norm residual) ---
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        outputs = (hidden_states, )
        if output_attentions:
            outputs += (self_attn_weights, )
        if use_cache:
            outputs += (present_key_value, )

        return outputs
1213
+
1214
+
1215
+ DeepseekV3_START_DOCSTRING = r"""
1216
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1217
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1218
+ etc.)
1219
+
1220
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1221
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1222
+ and behavior.
1223
+
1224
+ Parameters:
1225
+ config ([`DeepseekV3Config`]):
1226
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1227
+ load the weights associated with the model, only the configuration. Check out the
1228
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1229
+ """
1230
+
1231
+
1232
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3PreTrainedModel(PreTrainedModel):
    """Base class wiring DeepseekV3 into the HF `PreTrainedModel` machinery."""

    config_class = DeepseekV3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range);
        biases and padding rows are zeroed."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                # The padding embedding must stay zero.
                module.weight.data[module.padding_idx].zero_()
1255
+
1256
+
1257
# Shared `forward` argument documentation, injected via
# `@add_start_docstrings_to_model_forward` on each model's `forward`.
DeepseekV3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1325
+
1326
+
1327
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]

    Args:
        config: DeepseekV3Config
    """

    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
                                         self.padding_idx)
        self.layers = nn.ModuleList([
            DeepseekV3DecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        # Flash-attention 2 consumes the plain 2D padding mask directly; every
        # other backend needs the expanded 4D causal mask built in `forward`.
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = DeepseekV3RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token-embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token-embedding module."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack and return the final hidden states.

        See `DeepseekV3_INPUTS_DOCSTRING` for argument semantics. Returns a
        `BaseModelOutputWithPast` (or a tuple when `return_dict=False`).
        """
        # Fall back to config defaults for any flag not passed explicitly.
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0
        if use_cache:
            # Accept both the legacy tuple-of-tuples cache and the new Cache
            # class; legacy input is converted here and converted back below.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(
                    past_key_values)
            # `get_usable_length` is a module-level helper defined earlier in
            # this file; presumably returns the number of already-cached
            # positions — TODO confirm against its definition.
            past_key_values_length = get_usable_length(past_key_values,
                                                       seq_length)

        if position_ids is None:
            # Default positions continue from the cached prefix length.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers; an all-ones mask is
            # dropped entirely (None means "no padding" to flash-attn).
            attention_mask = (attention_mask if
                              (attention_mask is not None
                               and 0 in attention_mask) else None)
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache sits after attentions in the layer's output tuple.
                next_decoder_cache = layer_outputs[
                    2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = None
        if use_cache:
            # Mirror the input cache format on the way out.
            next_cache = (next_decoder_cache.to_legacy_cache()
                          if use_legacy_cache else next_decoder_cache)
        if not return_dict:
            return tuple(
                v for v in
                [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1487
+
1488
+
1489
class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
    """DeepseekV3 decoder with a language-modeling head (`lm_head`) on top."""

    # `lm_head.weight` may be tied to the input embeddings by the HF tying
    # machinery when the config requests it.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DeepseekV3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token-embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token-embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head projection."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head projection."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Swap in a different decoder backbone."""
        self.model = decoder

    def get_decoder(self):
        """Return the decoder backbone."""
        return self.model

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast,
                               config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

        >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        # Compute logits in fp32 for a numerically stable loss / softmax.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Trim `input_ids`/masks to the unprocessed suffix for incremental decoding.

        Standard HF generation hook: given the cache state, keep only the tokens
        that have not been fed through the model yet and rebuild position ids.
        """
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                # `seen_tokens` may be absent in some transformers versions;
                # use getattr for safe access.
                past_length = getattr(past_key_values, 'seen_tokens',
                                      cache_length)
                # NOTE(review): `get_max_length()` is deprecated (renamed
                # `get_max_cache_shape` in newer transformers) — confirm the
                # targeted transformers version still provides it.
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (attention_mask is not None
                    and attention_mask.shape[1] > input_ids.shape[1]):
                input_ids = input_ids[:, -(attention_mask.shape[1] -
                                           past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (max_cache_length is not None and attention_mask is not None
                    and cache_length + input_ids.shape[1] > max_cache_length):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        })
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder a legacy-format cache along the batch dim for beam search."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past), )
        return reordered_past
1680
+
1681
+
1682
@add_start_docstrings(
    """
    The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).

    [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeepseekV3Model(config)
        # Classification head applied to the last (non-padding) token.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token-embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token-embedding module."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the token just before the first pad token in each
                # row (assumes right padding; if a row has no pad token,
                # argmax returns 0, so the index becomes -1 = last token).
                sequence_lengths = (torch.eq(
                    input_ids, self.config.pad_token_id).int().argmax(-1) -
                                    1).to(logits.device)
            else:
                sequence_lengths = -1

        # Pool: one logit vector per row, taken at `sequence_lengths`.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device),
                               sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from num_labels / label dtype and
            # cache it on the config (standard HF behavior).
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long
                                              or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels),
                                labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits, ) + transformer_outputs[1:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
modeling_kimi_k25.py ADDED
@@ -0,0 +1,1248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025-2026 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for Kimi-K2.5.
5
+ #
6
+ # Licensing Information:
7
+ # - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
8
+ # - Other parts of the code are licensed under the MIT License.
9
+ #
10
+ # Apache License, Version 2.0:
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ #
23
+ # MIT License:
24
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
25
+ # of this software and associated documentation files (the "Software"), to deal
26
+ # in the Software without restriction, including without limitation the rights
27
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
28
+ # copies of the Software, and to permit persons to whom the Software is
29
+ # furnished to do so, subject to the following conditions:
30
+ #
31
+ # The above copyright notice and this permission notice shall be included in all
32
+ # copies or substantial portions of the Software.
33
+ #
34
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
35
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
36
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
37
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
38
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
39
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40
+ # SOFTWARE.
41
+ import math
42
+ from collections.abc import Sequence
43
+ from copy import deepcopy
44
+ from typing import Optional
45
+
46
+ import numpy as np
47
+ import torch
48
+ import torch.nn as nn
49
+ import torch.nn.functional as F
50
+ from transformers import activations
51
+
52
+ try:
53
+ from transformers.activations import PytorchGELUTanh
54
+ except ImportError:
55
+ from transformers.activations import GELUTanh
56
+ activations.PytorchGELUTanh = GELUTanh
57
+ PytorchGELUTanh = GELUTanh
58
+ from transformers.activations import PytorchGELUTanh
59
+ from transformers.cache_utils import Cache
60
+ from transformers.configuration_utils import PretrainedConfig
61
+ from transformers.modeling_utils import PreTrainedModel
62
+ from transformers.models.llava.modeling_llava import \
63
+ LlavaCausalLMOutputWithPast
64
+ from transformers.utils import is_flash_attn_2_available
65
+
66
+ from .configuration_kimi_k25 import KimiK25Config
67
+ from .modeling_deepseek import DeepseekV3ForCausalLM
68
+
69
+ # Flash attention imports
70
+ if is_flash_attn_2_available():
71
+ from flash_attn import flash_attn_varlen_func
72
+ else:
73
+ flash_attn_varlen_func = None
74
+
75
+
76
def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: torch.Tensor | None = None,
    k_cu_seqlens: torch.Tensor | None = None,
    max_seqlen_q: int | None = None,
    max_seqlen_k: int | None = None,
    deterministic: bool = False,
):
    """Multi-head attention using flash attention 2.

    Args:
        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
        q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
            The first element should be 0 and the last element should be q.shape[0].
        k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
            The first element should be 0 and the last element should be k.shape[0].
        max_seqlen_q / max_seqlen_k: maximum single-sequence length in q / k,
            forwarded to flash-attn.
        deterministic: request a deterministic flash-attn backward pass.

    Returns:
        output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
        where dim = num_heads * head_dim

    Raises:
        RuntimeError: if flash-attn 2 is not installed.
    """
    # Fail with an actionable message instead of "'NoneType' object is not
    # callable" when flash-attn is unavailable (see the import guard above).
    if flash_attn_varlen_func is None:
        raise RuntimeError(
            "flash_attn is not available; install flash-attn>=2 or select the "
            "'eager' vision attention implementation instead.")

    attn_out = flash_attn_varlen_func(
        q,
        k,
        v,
        q_cu_seqlens,
        k_cu_seqlens,
        max_seqlen_q,
        max_seqlen_k,
        causal=False,
        deterministic=deterministic,
    )
    # Some flash-attn versions return (out, softmax_lse, ...) tuples.
    if isinstance(attn_out, tuple):
        attn_out = attn_out[0]

    # (..., num_heads, head_dim) -> (..., num_heads * head_dim)
    attn_out = attn_out.flatten(start_dim=-2)

    return attn_out
117
+
118
+
119
def eager_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Pure-PyTorch fallback for packed (varlen) multi-head attention.

    Args:
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim).
        q_cu_seqlens: cumulative sequence lengths; positions within the same
            [q_cu_seqlens[i-1], q_cu_seqlens[i]) window may attend to each other.
        k_cu_seqlens: unused here (q and k share the packing layout).
        **kwargs: ignored; accepted for signature compatibility with
            `multihead_attention`.

    Returns:
        Tensor of shape (tot_seqlens, num_heads * head_dim).
    """
    seq_length = q.shape[0]
    # Block-diagonal mask: True where query i may attend to key j.
    attention_mask = torch.zeros([1, seq_length, seq_length],
                                 device=q.device,
                                 dtype=torch.bool)
    for i in range(1, len(q_cu_seqlens)):
        attention_mask[
            ...,
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
        ] = True
    # (seq, heads, dim) -> (heads, seq, dim)
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)

    attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    # BUGFIX: the original did `attn_weight += attention_mask`, which merely
    # adds +1.0 inside allowed blocks (softmax-invariant per row) and does NOT
    # suppress cross-sequence attention. Mask disallowed positions to -inf so
    # softmax assigns them zero probability.
    attn_weight = attn_weight.masked_fill(~attention_mask, float("-inf"))
    attn_weight = torch.softmax(attn_weight, dim=-1,
                                dtype=torch.float32).to(q.dtype)

    attn_output = attn_weight @ v
    attn_output = attn_output.transpose(0, 1)
    attn_output = attn_output.reshape(seq_length, -1)
    return attn_output
150
+
151
+
152
# Dispatch table mapping the configured `attn_implementation` string to the
# matching packed (varlen) attention callable.
VL_VISION_ATTENTION_FUNCTIONS = {
    "flash_attention_2": multihead_attention,
    "eager": eager_attention,
}
156
+
157
+
158
def _apply_rope_input_validation(x, freqs_cis):
    """Sanity-check shapes/dtypes before applying rotary embeddings.

    x is (..., num_heads, head_dim); freqs_cis is (..., head_dim/2)
    complex64, sharing x's leading (position) dimensions.
    """
    # x has one extra (num_heads) dimension just before head_dim.
    assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
    assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
    # Each complex rotation consumes one (real, imag) channel pair.
    assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
    assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
163
+
164
+
165
def get_rope_shape_decorate(func):
    """Wrap `func` so each new (grad-state, mode) combination is warmed up
    once with a fixed 64x64 shape before the real call goes through.

    The warm-up call primes the wrapped (torch.compile'd) kernel for the
    given autograd configuration so the first real call at an arbitrary
    shape does not pay the initial compilation cost.
    """
    seen_configs = set()

    def wrapped(org, interpolation_mode, shape):
        config = (org.requires_grad, torch.is_grad_enabled(),
                  interpolation_mode)
        if config not in seen_configs:
            seen_configs.add(config)
            # Prime the compiled kernel; result intentionally discarded.
            func(org, interpolation_mode, shape=(64, 64))
        return func(org, interpolation_mode, shape)

    return wrapped
176
+
177
+
178
@get_rope_shape_decorate
@torch.compile(dynamic=True)
def get_rope_shape(org, interpolation_mode, shape):
    """Resize a (H, W, C) positional-embedding grid to `shape` and flatten.

    Returns a (shape[0] * shape[1], C) tensor obtained by interpolating
    `org` with `interpolation_mode`.
    """
    # (H, W, C) -> (1, C, H, W), the layout F.interpolate expects.
    grid = org.permute((2, 0, 1)).unsqueeze(0)
    resized = F.interpolate(grid, size=shape, mode=interpolation_mode)
    # (1, C, h, w) -> (h, w, C) -> (h*w, C)
    resized = resized.squeeze(0).permute((1, 2, 0))
    return resized.flatten(end_dim=1)
186
+
187
+
188
def apply_rope(xq: torch.Tensor, xk: torch.Tensor,
               freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply 2D rotary position embedding to query and key tensors.

    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads, head_dim)
        xk: key, tensor of shape (..., num_heads, head_dim)
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
    Returns:
        xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
    """
    _apply_rope_input_validation(xq, freqs_cis)
    _apply_rope_input_validation(xk, freqs_cis)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2
    # Pair the last dimension into complex numbers: ..., num_heads, head_dim/2
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    # Bug fix: the key reshape previously used *xq.shape[:-1], a copy-paste
    # typo that silently scrambles (or crashes on) keys whose num_heads
    # differs from the query's. Use xk's own leading shape.
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    # Complex multiply rotates each (even, odd) channel pair by the
    # position-dependent angle encoded in freqs_cis.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)
210
+
211
+
212
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sin/cos positional encoding for a 1D grid of positions.

    From:
    https://github.com/OpenGVLab/InternVideo/blob/421f6d2361fc8f61a3394244571f2601a4e99e29/InternVideo2/multi_modality/models/backbones/internvideo2/pos_embed.py#L86

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Geometric frequency ladder: 1 / 10000^(2i/D) for i in [0, D/2).
    inv_freq = np.arange(half_dim, dtype=np.float32)
    inv_freq /= embed_dim / 2.0
    inv_freq = 1.0 / 10000**inv_freq  # (D/2,)

    # Outer product of positions and frequencies -> (M, D/2) phase angles.
    angles = np.outer(pos.reshape(-1), inv_freq)

    # First half sine, second half cosine -> (M, D).
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
233
+
234
+
235
def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
    """Build a 1D sin/cos positional-embedding table over `t_size` steps.

    t_size: int of the temporal size
    return:
        pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim]
        (with or without a leading cls-token row)
    """
    timeline = np.arange(t_size, dtype=np.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, timeline)
    if not cls_token:
        return table
    # Prepend an all-zero row for the [CLS] position.
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)
247
+
248
+
249
class Learnable2DInterpPosEmbDivided_fixed(nn.Module):
    """Divided space-time positional embedding.

    Spatial positions use a learnable (height, width, dim) grid that is
    resized on the fly to each image's patch grid; temporal positions use
    a fixed (non-learnable) 1D sin/cos table that is added on top of the
    spatial embedding for every frame.
    """

    def __init__(self,
                 height: int,
                 width: int,
                 num_frames: int,
                 dim: int,
                 interpolation_mode: str = 'bicubic') -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.num_frames = num_frames
        self.dim = dim
        self.interpolation_mode = interpolation_mode
        # Learnable 2D spatial table, interpolated per sample in forward().
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        # Fixed sin/cos temporal table, shape (num_frames, 1, dim); not
        # persisted in checkpoints since it is deterministic.
        self.register_buffer('time_weight',
                             torch.from_numpy(
                                 get_1d_sincos_pos_embed(
                                     self.dim,
                                     self.num_frames)).float().unsqueeze(1),
                             persistent=False)

        self.reset_parameters()

    def reset_parameters(self):
        # Standard-normal init for the learnable spatial table.
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings to packed patch features.

        Args:
            x: (sum(t*h*w), dim) packed patch embeddings.
            grid_thws: (N, 3) per-image (t, h, w) patch-grid sizes, in the
                same order the patches are packed in `x`.

        Returns:
            `x` with positional embeddings added, same shape as `x`.
        """
        pos_embs = []
        for t, h, w in grid_thws.tolist():
            assert t <= self.num_frames, f't:{t} > self.num_frames:{self.num_frames}'
            if (h, w) == self.weight.shape[:-1]:
                # Native resolution: no interpolation needed.
                pos_emb_2d = self.weight.flatten(end_dim=1)
            else:
                # Resize the learnable grid to this image's (h, w).
                pos_emb_2d = get_rope_shape(
                    self.weight,
                    interpolation_mode=self.interpolation_mode,
                    shape=(h, w),
                )

            if t == 1:
                pos_emb_3d = pos_emb_2d
            else:
                # Broadcast the spatial table over frames and add the
                # per-frame temporal embedding for the first t frames.
                pos_emb_3d = pos_emb_2d.unsqueeze(0).repeat(
                    t, 1, 1) + self.time_weight[0:t]

            pos_embs.append(pos_emb_3d.reshape(-1, pos_emb_3d.shape[-1]))

        out = x + torch.cat(pos_embs)
        return out
300
+
301
+
302
class MoonVision3dPatchEmbed(nn.Module):
    """Conv2d patch projection plus divided space-time positional embedding.

    Each input patch is projected to `out_dim` with a stride-`patch_size`
    convolution, then per-(t, h, w) positional embeddings are added.
    """

    def __init__(self,
                 out_dim: int,
                 in_dim: int = 3,
                 patch_size: int | tuple[int, int] = (14, 14),
                 pos_emb_height: int = 14,
                 pos_emb_width: int = 14,
                 pos_emb_time: int = 4,
                 pos_emb_type: str = 'divided_fixed'):
        super().__init__()
        assert isinstance(
            patch_size,
            int | Sequence), f'Invalid patch_size type: {type(patch_size)}'
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        assert (len(patch_size) == 2
                ), f'Expected patch_size to be a tuple of 2, got {patch_size}'
        self.patch_size = patch_size

        # Non-overlapping patch projection (kernel == stride).
        self.proj = nn.Conv2d(in_dim,
                              out_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

        if pos_emb_type == 'divided_fixed':
            self.pos_emb = Learnable2DInterpPosEmbDivided_fixed(
                height=pos_emb_height,
                width=pos_emb_width,
                num_frames=pos_emb_time,
                dim=out_dim)
        else:
            raise NotImplementedError(
                f'Not support pos_emb_type: {pos_emb_type}')

    def forward(self, x: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (L, Channels): input tensor
            grid_thws (N, 3): temporal, height and width

        Returns:
            (L, Cout) tensor
        """
        # NOTE(review): x goes straight into Conv2d, so each of the L rows
        # is presumably a pre-cut (in_dim, ph, pw) patch — confirm against
        # the image processor's packing.
        x = self.proj(x).view(x.size(0), -1)
        # apply positional embedding
        x = self.pos_emb(x, grid_thws)
        return x
351
+
352
+
353
class Rope2DPosEmbRepeated(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    This class is intended to be used in the following way:
    1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
    2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
    3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
    The rope is shared across all attention layers and all heads.

    Refs:
    - RoFormer: https://arxiv.org/abs/2104.09864
    - VisionLLaMA: https://arxiv.org/abs/2403.00522
    - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
        max_height (int): the maximum height of the 2D grid
        max_width (int): the maximum width of the 2D grid
        theta_base (float): the base of the theta
        device (str): the device to store the precomputed cis
    """

    def __init__(self,
                 dim: int,
                 max_height: int,
                 max_width: int,
                 theta_base=10000):
        super().__init__()
        self.dim = dim
        # Half the channels rotate with x, half with y, and each rotation
        # pairs two channels — hence divisibility by 4.
        assert self.dim % 4 == 0, 'dim must be divisible by 4'
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base

    def extra_repr(self):
        return f'dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}'

    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
        """Calculate the cis(freqs) for each position in the 2D grid.

        Return: complex tensor of shape (max_height, max_width, dim//2) and value:
            height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
            weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
        note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
        """
        N = self.max_height * self.max_width
        # Enumerate every grid cell via a flat index, then split into
        # (y, x) coordinates.
        flat_pos = torch.arange(0, N).float().to(device)
        x_pos = flat_pos % self.max_width
        y_pos = flat_pos // self.max_width
        dim_range = (torch.arange(0, self.dim,
                                  4)[:(self.dim // 4)].float().to(device)
                     )  # C/4
        freqs = 1.0 / (self.theta_base**(dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        # Unit-magnitude complex numbers at each phase angle: cis(theta).
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # Interleave x/y rotations along the channel axis: N, C/4, 2
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1),
             y_cis.unsqueeze(dim=-1)], dim=-1)
        # max_height, max_width, C/2
        freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
        return freqs_cis

    def get_freqs_cis(self, grid_thws: torch.Tensor,
                      device: torch.device) -> torch.Tensor:
        """
        Args:
            grid_thws (torch.Tensor): grid time, height and width

        Returns:
            freqs_cis: tensor of shape (sum(t * height * width), dim//2)
        """
        # Lazily build and cache the full-resolution table on first use;
        # non-persistent, so it is recomputed rather than checkpointed.
        if not hasattr(self, 'freqs_cis'):
            self.register_buffer('freqs_cis',
                                 self._precompute_freqs_cis(device),
                                 persistent=False)

        shapes = grid_thws.tolist()
        assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width
                   for t, h, w in shapes), (
                       shapes,
                       self.max_height,
                       self.max_width,
                   )
        # Crop the table to each (h, w) grid and repeat it across the t
        # frames: all frames share the same spatial rotary phases.
        freqs_cis = torch.cat(
            [
                self.freqs_cis[:h, :w].reshape(-1, self.dim // 2).repeat(t, 1)
                for t, h, w in shapes
            ],
            dim=0,
        )
        return freqs_cis
447
+
448
+
449
class MLP2(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear.

    Args:
        dims: [in_dim, hidden_dim, out_dim]
        bias: whether to use bias in linear layer.
    """

    def __init__(self, dims: list[int], activation, bias=True):
        super().__init__()
        assert len(dims) == 3
        in_dim, hidden_dim, out_dim = dims
        self.fc0 = nn.Linear(in_dim, hidden_dim, bias=bias)
        self.fc1 = nn.Linear(hidden_dim, out_dim, bias=bias)
        self.activation = activation
        # Fan-in scaled truncated-normal init; biases start at zero.
        for layer in (self.fc0, self.fc1):
            nn.init.trunc_normal_(layer.weight,
                                  std=math.sqrt(2 / layer.in_features))
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc1(self.activation(self.fc0(x)))
471
+
472
+
473
class MoonViTEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer over packed varlen sequences.

    Layout: LN -> fused-QKV attention (with 2D RoPE) -> residual,
            LN -> MLP -> residual.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        *,
        attn_implementation: str = 'flash_attention_2',
        activation=F.gelu,
        attn_bias: bool = False,
        use_deterministic_attn: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        self.attn_implementation = attn_implementation
        self.use_deterministic_attn = use_deterministic_attn

        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
        # Fused projection producing q, k and v in a single matmul.
        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)

    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor,
        rope_freqs_cis: torch.Tensor | None = None,
    ):
        """
        Args:
            x (torch.Tensor): (batch_size, seqlen, hidden_dim)
            cu_seqlens (torch.Tensor): cumulative lengths delimiting the
                packed sub-sequences in x.
        """
        xqkv = self.wqkv(x)

        qkv_shape = xqkv.size()[:-1] + (
            3,
            self.num_heads,
            self.hidden_size_per_attention_head,
        )
        # xqkv: (batch_size, seqlen, 3, nheads, headdim)
        xqkv = xqkv.view(*qkv_shape)
        xq, xk, xv = torch.unbind(xqkv, dim=-3)

        # Rotate q/k by the shared 2D rotary phases before attention.
        xq, xk = apply_rope(xq, xk, rope_freqs_cis)

        # Backend chosen by config: flash varlen kernel or eager fallback.
        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
        attn_out = attn_func(xq,
                             xk,
                             xv,
                             q_cu_seqlens=cu_seqlens,
                             k_cu_seqlens=cu_seqlens,
                             max_seqlen_k=max_seqlen,
                             max_seqlen_q=max_seqlen,
                             deterministic=self.use_deterministic_attn)

        attn_out = self.wo(attn_out)
        return attn_out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        rope_freqs_cis: torch.Tensor | None = None,
    ):
        # Attention sub-block (pre-norm residual).
        residual = hidden_states
        hidden_states = self.norm0(hidden_states)

        hidden_states = self.attention_qkvpacked(hidden_states, cu_seqlens,
                                                 max_seqlen, rope_freqs_cis)
        hidden_states = residual + hidden_states

        # MLP sub-block (pre-norm residual).
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
557
+
558
+
559
class MoonViT3dEncoder(nn.Module):
    """Stack of MoonViT encoder layers over packed patch sequences.

    Attention within each (t*h*w) image is full (non-causal); separate
    images never attend to each other thanks to the cu_seqlens packing.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_layers: int,
                 block_cfg: dict,
                 video_attn_type: str = 'spatial_temporal',
                 use_deterministic_attn: bool = False) -> None:
        super().__init__()

        assert video_attn_type == 'spatial_temporal', f'video_attn_type must be "spatial_temporal", got {video_attn_type}'
        self.video_attn_type = video_attn_type
        # Bug fix: self.use_deterministic_attn was read below without ever
        # being assigned, so constructing this module raised
        # AttributeError. Expose it as a backward-compatible keyword
        # parameter defaulting to False (the value implied elsewhere).
        self.use_deterministic_attn = use_deterministic_attn
        # Shared 2D RoPE table, sized for patch grids up to 512x512.
        self.rope_2d = Rope2DPosEmbRepeated(
            block_cfg['hidden_dim'] // block_cfg['num_heads'], 512, 512)
        self.blocks = nn.ModuleList([
            MoonViTEncoderLayer(
                **block_cfg,
                use_deterministic_attn=self.use_deterministic_attn)
            for _ in range(num_layers)
        ])
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (sum(t*h*w), hidden_dim) packed patch embeddings.
            grid_thws: (N, 3) per-image (t, h, w) grid sizes.

        Returns:
            (sum(t*h*w), hidden_dim) encoded features.
        """
        rope_freqs_cis = self.rope_2d.get_freqs_cis(
            grid_thws=grid_thws, device=hidden_states.device)

        # Per-image token counts with a leading 0 so cumsum yields
        # flash-attention style cu_seqlens.
        lengths = torch.cat((
            torch.zeros(1, dtype=grid_thws.dtype, device=grid_thws.device),
            grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2],
        ))

        max_seqlen = lengths.max()
        cu_seqlens = lengths.to(hidden_states.device).cumsum(dim=0,
                                                             dtype=torch.int32)
        for block in self.blocks:
            hidden_states = block(hidden_states,
                                  cu_seqlens,
                                  max_seqlen,
                                  rope_freqs_cis=rope_freqs_cis)

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states
604
+
605
+
606
def tpool_patch_merger(
    x: torch.Tensor,
    grid_thws: torch.Tensor,
    merge_kernel_size: tuple[int, int] = (2, 2),
) -> list[torch.Tensor]:
    """Average features over time, then group them into spatial merge windows.

    Args:
        x: (sum(t*h*w), d_model) packed patch features.
        grid_thws: (N, 3) per-image (t, h, w) grid sizes.
        merge_kernel_size: (kernel_h, kernel_w) spatial window size.

    Returns:
        One tensor per image of shape
        (h//kernel_h * w//kernel_w, kernel_h*kernel_w, d_model).
    """
    d_model = x.size(-1)
    kernel_h, kernel_w = merge_kernel_size

    outputs = []
    offset = 0
    for t, h, w in grid_thws.tolist():
        num_tokens = t * h * w
        seq = x[offset:offset + num_tokens]
        offset += num_tokens

        out_h, out_w = h // kernel_h, w // kernel_w
        # Factor each spatial axis into (blocks, within-block) components.
        windows = seq.view(t, out_h, kernel_h, out_w, kernel_w, d_model)
        # Bring the block axes next to each other, then pool over time.
        windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().mean(dim=0)
        outputs.append(windows.view(out_h * out_w, kernel_h * kernel_w, -1))

    return outputs
632
+
633
+
634
class MoonViT3dPretrainedModel(PreTrainedModel):
    """MoonViT 3D vision tower: patch embed -> transformer encoder -> merge."""

    config_class = None
    model_type = 'moonvit3d'
    _no_split_modules = ['PackingTransformer']
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # Work on a copy so later mutation of the caller's config object
        # cannot leak into this module.
        config = deepcopy(config)
        self.merge_kernel_size = config.merge_kernel_size
        self.patch_size = config.patch_size
        self.merge_type = config.merge_type

        self.patch_embed = MoonVision3dPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
            pos_emb_time=config.init_pos_emb_time,
            pos_emb_type=config.pos_emb_type,
        )

        self.encoder = MoonViT3dEncoder(hidden_dim=config.hidden_size,
                                        num_layers=config.num_hidden_layers,
                                        block_cfg={
                                            'num_heads':
                                            config.num_attention_heads,
                                            'hidden_dim':
                                            config.hidden_size,
                                            'mlp_dim':
                                            config.intermediate_size,
                                            'activation':
                                            PytorchGELUTanh(),
                                            'attn_bias':
                                            True,
                                            'attn_implementation':
                                            config._attn_implementation,
                                        },
                                        video_attn_type=config.video_attn_type)

    def forward(self, pixel_values: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values (torch.Tensor): The input pixel values.
            grid_thws (torch.Tensor): Temporal, height and width.

        Returns:
            torch.Tensor: The output tokens.
        """
        # grid_thws = grid_thws.to('cpu')
        assert grid_thws.ndim == 2, f'grid_thws should be 2D, got {grid_thws.ndim}'
        assert grid_thws.size(1) == 3, f'No support for thw: {grid_thws}'
        hidden_states = self.patch_embed(pixel_values, grid_thws)
        hidden_states = self.encoder(hidden_states, grid_thws)
        if self.merge_type == 'sd2_tpool':  # spatial downsampling 2x with temporal pooling all
            hidden_states = tpool_patch_merger(
                hidden_states,
                grid_thws,
                merge_kernel_size=self.merge_kernel_size)
        else:
            raise NotImplementedError(f'Not support {self.merge_type}')

        return hidden_states
699
+
700
+
701
+ # ============================================================================
702
+ # MM Projector Helper Classes (from mm_projector/modeling_mm_projectors.py)
703
+ # ============================================================================
704
+
705
+
706
class IdentityMap(nn.Module):
    """No-op projector: returns its input untouched.

    Used when the vision features already live in the language model's
    embedding space and no projection is required. Extra positional and
    keyword arguments are accepted and ignored for interface parity with
    the other projectors.
    """

    def forward(self, x, *args, **kwargs):
        return x
713
+
714
+
715
+ class MLP(nn.Module):
716
+
717
+ def __init__(self, config):
718
+ super().__init__()
719
+ # TODO, use faster LayerNorm
720
+ self.pre_norm = nn.LayerNorm(config.mm_hidden_size)
721
+ self.proj = nn.Sequential(
722
+ nn.Linear(config.mm_hidden_size, config.hidden_size), nn.GELU(),
723
+ nn.Linear(config.hidden_size, config.hidden_size))
724
+
725
+ def forward(self, x, *args, **kwargs):
726
+ assert isinstance(x,
727
+ list | tuple), f'x is not a list or tuple: {type(x)}'
728
+ lengths = [item.shape[0] for item in x]
729
+ x = torch.cat(x, dim=0)
730
+ x = self.pre_norm(x)
731
+ x = self.proj(x)
732
+ x = torch.split(x, lengths, dim=0)
733
+
734
+ return x
735
+
736
+
737
class PatchMergerMLP(nn.Module):
    """Projector that flattens each spatial merge window before an MLP.

    Features of shape (..., kernel_h*kernel_w, mm_hidden_size) are
    layer-normalized, flattened over the window axis, and mapped to the
    language-model hidden size.
    """

    def __init__(self, config):
        super().__init__()
        eps = config.projector_ln_eps
        # Width of one flattened merge window.
        self.hidden_size = config.mm_hidden_size * (
            config.merge_kernel_size[0] * config.merge_kernel_size[1])
        self.pre_norm = nn.LayerNorm(config.mm_hidden_size, eps=eps)
        self.proj = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, config.hidden_size),
        )

    def forward(self, x, *args, **kwargs):
        if isinstance(x, (list, tuple)):
            # Per-image tensors of shape (N_i, kernel, C).
            return [
                self.proj(self.pre_norm(item).view(item.shape[0], -1))
                for item in x
            ]
        # Batched tensor of shape (B, N, N_k, C).
        batch = x.shape[0]
        return self.proj(self.pre_norm(x).view(batch, -1, self.hidden_size))
762
+
763
+
764
class KimiK25PreTrainedModel(PreTrainedModel):
    """Base class wiring KimiK25Config into transformers' loading machinery."""

    config_class = KimiK25Config
    base_model_prefix = "model"
    # Modules that must stay whole on one device under model parallelism.
    _no_split_modules = [
        "MoonViT3dPretrainedModel",
        "MoonViTEncoderLayer",
        "DeepseekDecoderLayer",
        "PatchMergerMLP",
    ]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = False

    def _init_weights(self, module):
        # important: this ported version of Llava isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
        # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
        # Prefer a top-level initializer_range; fall back to the text
        # sub-config when it is only defined there.
        std = (self.config.initializer_range if hasattr(
            self.config, "initializer_range") else
               self.config.text_config.initializer_range)

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at zero after re-init.
                module.weight.data[module.padding_idx].zero_()
796
+
797
+
798
class VisionTowerConfig(PretrainedConfig):
    """Standalone config for the MoonViT3d vision tower.

    Flattens the `vt_*`-prefixed fields of a KimiK25Config vision config
    into the generic names MoonViT3dPretrainedModel expects.
    """

    model_type = 'moonvit3d'

    def __init__(self, config: KimiK25Config, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = config.patch_size
        self.init_pos_emb_height = config.init_pos_emb_height
        self.init_pos_emb_width = config.init_pos_emb_width
        self.init_pos_emb_time = config.init_pos_emb_time
        self.pos_emb_type = config.pos_emb_type
        # vt_* fields map onto the standard transformer attribute names.
        self.num_attention_heads = config.vt_num_attention_heads
        self.num_hidden_layers = config.vt_num_hidden_layers
        self.hidden_size = config.vt_hidden_size
        self.intermediate_size = config.vt_intermediate_size
        self.merge_kernel_size = config.merge_kernel_size
        self.video_attn_type = config.video_attn_type
        self.merge_type = config.merge_type
        # Propagate the attention backend choice down to the tower.
        self._attn_implementation = config._attn_implementation
816
+
817
+
818
class ProjectorConfig:
    """Thin view of a KimiK25Config exposing only the projector settings."""

    # (attribute on self, attribute on the source config)
    _FIELD_MAP = (
        ('mm_projector_type', 'mm_projector_type'),
        ('mm_hidden_size', 'mm_hidden_size'),
        # The projector's output width is the language model's hidden size.
        ('hidden_size', 'text_hidden_size'),
        ('merge_kernel_size', 'merge_kernel_size'),
        ('projector_hidden_act', 'projector_hidden_act'),
        ('projector_ln_eps', 'projector_ln_eps'),
    )

    def __init__(self, config: KimiK25Config):
        for dst, src in self._FIELD_MAP:
            setattr(self, dst, getattr(config, src))
828
+
829
+ # ref https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/llava/modeling_llava.py#L240
830
+ class KimiK25ForConditionalGeneration(KimiK25PreTrainedModel):
831
+
832
    def __init__(self, config: KimiK25Config):
        """Assemble vision tower, projector and language model from config."""
        super().__init__(config)

        # Vision tower and projector both read their settings from the
        # nested vision_config.
        vt_config = VisionTowerConfig(config.vision_config)
        self.vision_tower = MoonViT3dPretrainedModel(vt_config)

        proj_config = ProjectorConfig(config.vision_config)
        if proj_config.mm_projector_type == 'identity':
            self.mm_projector = IdentityMap()
        elif proj_config.mm_projector_type == 'mlp':
            self.mm_projector = MLP(proj_config)
        elif proj_config.mm_projector_type == 'patchmerger':
            self.mm_projector = PatchMergerMLP(proj_config)
        else:
            raise ValueError(
                f"Unsupported mm_projector_type: {proj_config.mm_projector_type}"
            )

        self.language_model = DeepseekV3ForCausalLM(config.text_config)
        self.post_init()

        # Keep the vision-side modules in the language model's dtype so
        # features flow between the two without explicit casts.
        if hasattr(self.language_model, 'dtype'):
            target_dtype = self.language_model.dtype
            self.vision_tower = self.vision_tower.to(dtype=target_dtype)
            self.mm_projector = self.mm_projector.to(dtype=target_dtype)
857
+
858
    # Thin delegation: embedding/decoder accessors simply forward to the
    # wrapped language model.

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()
878
+
879
    def resize_token_embeddings(self,
                                new_num_tokens: int | None = None,
                                pad_to_multiple_of=None) -> nn.Embedding:
        """Resize the LM token embeddings and keep vocab-size fields in sync."""
        model_embeds = self.language_model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds
888
+
889
    def _merge_input_ids_with_image_features(
        self,
        image_features: list[torch.Tensor],
        inputs_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
    ):
        """Splice per-image feature sequences into the text embedding stream.

        Each media-placeholder token in `input_ids` is expanded into the
        corresponding entry of `image_features` (placeholders and features
        are matched in order), producing a longer merged sequence.

        Args:
            image_features (:obj:`torch.Tensor` of shape :obj:`(num_image_tokens, embed_dim)`):
                The image features to merge with the input embeddings.
            inputs_embeds (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length, embed_dim)`):
                The input embeddings.
            input_ids (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`):
                The input ids.
            attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`):
                The attention mask.
            labels (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, *optional*):
                The labels.

        Returns:
            (final_embedding, final_attention_mask, final_labels,
            position_ids), all expanded to the merged length;
            final_labels is None when `labels` is None.
        """
        _, embed_dim = image_features[0].shape
        feature_lengths = [x.shape[0] for x in image_features]
        image_features = torch.cat(image_features, dim=0)

        image_token_index: int = self.config.media_placeholder_token_id
        pad_token_id: int = self.config.pad_token_id
        ignore_index: int = self.config.ignore_index

        batch_size, sequence_length = input_ids.shape
        # Left padding iff no sequence ends with a pad token.
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(pad_token_id))

        # 1. Create a mask to know where special image tokens are
        # Each text token occupies 1 merged slot; each placeholder occupies
        # as many slots as its image has features.
        _token_occupation_table = torch.ones_like(input_ids.flatten())
        _token_occupation_table[input_ids.flatten() ==
                                image_token_index] = torch.tensor(
                                    feature_lengths,
                                    dtype=torch.long,
                                    device=input_ids.device)
        _token_occupation_table = _token_occupation_table.reshape(
            input_ids.shape)

        max_embed_dim = _token_occupation_table.sum(-1).max().item()
        assert (
            max_embed_dim >= sequence_length
        ), f"The maximum embedding dimension ({max_embed_dim}) is less than the sequence length ({sequence_length})"
        batch_indices, non_image_indices = torch.where(
            input_ids != image_token_index)

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged image-text sequence.
        new_token_positions = torch.cumsum(_token_occupation_table, -1) - 1
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:,
                                                None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices,
                                                non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(batch_size,
                                           max_embed_dim,
                                           dtype=attention_mask.dtype,
                                           device=inputs_embeds.device)
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim),
                ignore_index,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask.
        final_embedding[batch_indices,
                        text_to_overwrite] = inputs_embeds[batch_indices,
                                                           non_image_indices]
        final_attention_mask[batch_indices,
                             text_to_overwrite] = attention_mask[
                                 batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices,
                         text_to_overwrite] = labels[batch_indices,
                                                     non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
        image_to_overwrite = torch.full((batch_size, max_embed_dim),
                                        True,
                                        dtype=torch.bool,
                                        device=inputs_embeds.device)
        image_to_overwrite[batch_indices, text_to_overwrite] = False
        # Exclude left-padding slots from the image fill.
        image_to_overwrite &= image_to_overwrite.cumsum(
            -1) - 1 >= nb_image_pad[:, None].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model are wrong. The number of image tokens is {image_to_overwrite.sum()} while"
                f" the number of image features given to the model is {image_features.shape[:-1].numel()}. "
                "This prevents correct indexing and breaks batch generation.")

        final_embedding[image_to_overwrite] = (
            image_features.contiguous().reshape(-1,
                                                embed_dim).to(target_device))
        final_attention_mask |= image_to_overwrite
        # Positions count only attended tokens; masked slots are pinned to 1.
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
            (final_attention_mask == 0), 1)

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels, position_ids
1022
+
1023
    def _extract_image_features(self, pixel_values: torch.Tensor,
                                grid_thws: torch.Tensor) -> list[torch.Tensor]:
        """
        Args:
            pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
                The pixel values of the images processed by image processor.
            grid_thws (:obj:`torch.Tensor` of shape :obj:`(batch_size, 3)`):
                The grid, height, width of the images.

        Returns:
            selected_image_feature (:obj:`torch.FloatTensor` of shape :obj:`(num_image_tokens, embed_dim)`):
                The selected image features to use as input to the projector head.

        """
        # Match the vision tower's parameter dtype (e.g. bf16) before the
        # conv patch embedding runs.
        target_dtype = self.vision_tower.patch_embed.proj.weight.dtype
        pixel_values = pixel_values.to(target_dtype)

        image_features = self.vision_tower(pixel_values, grid_thws)
        return image_features
1043
+
1044
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | list[torch.FloatTensor]
        | None = None,
        grid_thws: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | LlavaCausalLMOutputWithPast:
        r"""
        Multimodal causal-LM forward pass: embed ``input_ids``, splice
        vision-tower features into the media-placeholder positions, then run
        the wrapped language model and (optionally) compute the LM loss.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        assert self.vision_tower is not None, "vision_tower is not loaded"
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            # 1. Extract the text-token input embeddings.
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # 2. Merge text and images (prefill step only: seq_len != 1).
            if pixel_values is not None and len(
                    pixel_values) > 0 and input_ids.shape[1] != 1:
                image_features = self._extract_image_features(
                    pixel_values, grid_thws)
                if self.mm_projector:
                    image_features = self.mm_projector(image_features)

                # Match the text embeddings' dtype to the projected image
                # features before merging.
                inputs_embeds = inputs_embeds.to(
                    image_features[0].dtype)  # num_tokens, embed_dim
                inputs_embeds, attention_mask, labels, position_ids = (
                    self._merge_input_ids_with_image_features(
                        image_features,
                        inputs_embeds,
                        input_ids,
                        attention_mask,
                        labels,
                    ))

            # Decode step with cache: input_ids.shape[1] == 1 while
            # pixel_values were consumed during prefill (they are still passed
            # in, but the image embeddings already live in past_key_values).
            # The attention mask must be extended to cover the cached length.
            elif (past_key_values is not None and pixel_values is not None
                  and input_ids.shape[1] == 1):
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(
                    first_layer_past_key_value.float().sum(-2) == 0)

                # Get the target length
                target_length = input_ids.shape[1]
                past_length = first_layer_past_key_value.shape[-1]

                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(
                    -1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index,
                                        new_non_attended_tokens] = 0

                attention_mask = torch.cat(
                    (extended_attention_mask, attention_mask[:,
                                                             -target_length:]),
                    dim=1)
                # Position of the new token = number of attended tokens - 1.
                position_ids = torch.sum(attention_mask,
                                         dim=1).unsqueeze(-1) - 1

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n; padding (mask == 0)
            # positions are dropped from the loss when a mask is available.
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(
                    logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(
                    labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1).to(shift_logits.device),
            )

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1184
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        grid_thws=None,
        attention_mask=None,
        **kwargs,
    ):
        """Assemble model inputs for one `generate()` step.

        Trims `input_ids`/`attention_mask` against the KV-cache length,
        builds `position_ids` from the attention mask when absent, and
        forwards the vision inputs (`pixel_values`, `grid_thws`) unchanged.
        """
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                # `seen_tokens` may exceed `cache_length` for size-limited
                # caches; fall back to the cache length otherwise.
                past_length = getattr(past_key_values, 'seen_tokens',
                                      cache_length)
            else:
                # Legacy tuple cache: key tensor is (batch, heads, seq, dim).
                cache_length = past_length = past_key_values[0][0].shape[2]

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            #     input)
            if attention_mask is not None and attention_mask.shape[
                    1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] -
                                           past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            #     input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), assume input_ids only has unprocessed tokens.
            #     If media placeholder tokens are still present here, the image
            #     embeddings were already merged during prefill, so keep only
            #     the last token to avoid re-expanding them.
            elif self.config.media_placeholder_token_id in input_ids:
                input_ids = input_ids[:, input_ids.shape[1] - 1:]
            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
            # older attention values, as their corresponding values are not part of the input.
            if cache_length < past_length and attention_mask is not None:
                attention_mask = attention_mask[:, -(cache_length +
                                                     input_ids.shape[1]):]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "grid_thws": grid_thws,
        })
        return model_inputs
1247
    def _reorder_cache(self, *args, **kwargs):
        # Beam search reorders the batch dimension of the KV cache; delegate
        # to the wrapped language model, which owns the cache layout.
        return self.language_model._reorder_cache(*args, **kwargs)
preprocessor_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "kimi_k25_processor.KimiK25Processor",
4
+ "AutoImageProcessor": "kimi_k25_vision_processing.KimiK25VisionProcessor"
5
+ },
6
+ "media_proc_cfg": {
7
+ "in_patch_limit": 16384,
8
+ "patch_size": 14,
9
+ "image_mean": [
10
+ 0.5,
11
+ 0.5,
12
+ 0.5
13
+ ],
14
+ "image_std": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "merge_kernel_size": 2,
20
+ "fixed_output_tokens": null,
21
+ "patch_limit_on_one_side": 512,
22
+ "in_patch_limit_each_frame": 4096,
23
+ "in_patch_limit_video": null,
24
+ "sample_fps": 2.0,
25
+ "max_num_frames_each_video": null,
26
+ "temporal_merge_kernel_size": 4,
27
+ "timestamp_mode": "hh:mm:ss.fff",
28
+ "config_type": "media_proc.processors.moonvit.MoonViTMediaProcessorConfig"
29
+ }
30
+ }
tiktoken.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6c497a7469b33ced9c38afb1ad6e47f03f5e5dc05f15930799210ec050c5103
3
+ size 2795286
tokenization_kimi.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import OrderedDict
3
+ from logging import getLogger
4
+ from pathlib import Path
5
+ from shutil import copyfile
6
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
7
+
8
+ import tiktoken
9
+ from tiktoken.load import load_tiktoken_bpe
10
+ from tokenizers import AddedToken
11
+ from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
12
+ from transformers.tokenization_utils import PreTrainedTokenizer
13
+
14
+ from .tool_declaration_ts import encode_tools_to_typescript_style
15
+
16
+ logger = getLogger(__name__)
17
+ VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
18
+
19
+
20
class TikTokenTokenizer(PreTrainedTokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            The path to the Tiktoken model file.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|begin_of_text|>",`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|end_of_text|>"`):
            The end of sequence token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_249|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead. The second to last item in special_tokens.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_250|>"`):
            The token used for padding, for example when batching sequences of different lengths.
        additional_special_tokens (list of `str`, *optional*):
            A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
            skipped when decoding if `skip_special_tokens` is set to `True`.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    model_input_names = ["input_ids", "attention_mask"]

    # Maps special-token text -> token id; populated in __init__.
    special_tokens: Dict[str, int]

    # Number of ids reserved directly above the BPE vocabulary for specials.
    num_reserved_special_tokens = 256

    # tiktoken pre-tokenization regex: CJK runs, words with optional English
    # contractions, 1-3 digit groups, punctuation, and whitespace classes.
    pat_str = "|".join([
        r"""[\p{Han}]+""",
        r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
        r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
        r"""\p{N}{1,3}""",
        r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
        r"""\s*[\r\n]+""",
        r"""\s+(?!\S)""",
        r"""\s+""",
    ])

    def __init__(
        self,
        vocab_file,
        bos_token: Union[str, AddedToken] = "[BOS]",
        eos_token: Union[str, AddedToken] = "[EOS]",
        unk_token: Union[str, AddedToken, None] = None,
        pad_token: Union[str, AddedToken, None] = None,
        additional_special_tokens: List[str] = None,
        added_tokens_decoder: Optional[dict] = None,
        **kwargs,
    ):
        assert os.path.isfile(vocab_file), vocab_file

        if additional_special_tokens is None:
            additional_special_tokens = [
                "<|im_end|>",
                "<|im_user|>",
                "<|im_assistant|>",
                "<|start_header_id|>",
                "<|end_header_id|>",
                "[EOT]",
                "<|im_system|>",
                "<|im_middle|>",
            ]

        # id -> special-token text, taken from tokenizer_config.json when given.
        if added_tokens_decoder:
            special_tokens_mapping = {
                i: added_tokens_decoder[i].content
                for i in added_tokens_decoder
            }
        else:
            special_tokens_mapping = {}

        self.vocab_file = vocab_file
        mergeable_ranks = load_tiktoken_bpe(vocab_file)
        num_base_tokens = len(mergeable_ranks)
        # Special tokens occupy the id range directly above the BPE vocab;
        # unnamed slots fall back to "<|reserved_token_{id}|>".
        self.special_tokens = {
            special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
            for i in range(num_base_tokens, num_base_tokens +
                           self.num_reserved_special_tokens)
        }

        self.model = tiktoken.Encoding(
            name=Path(vocab_file).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {vocab_file}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens[str(bos_token)]
        self.eos_id: int = self.special_tokens[str(eos_token)]
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

        # NOTE(review): unk_token/pad_token default to None, so str(...)
        # would look up "None" and raise KeyError here — both must be
        # supplied (tokenizer_config.json provides "[UNK]" / "[PAD]").
        self.pad_id: int = self.special_tokens[str(pad_token)]
        self.unk_id: int = self.special_tokens[str(unk_token)]

        # GPT-2 style byte <-> printable-unicode tables, used to expose a
        # string-level vocab over raw-byte BPE pieces.
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        # id -> printable-unicode token string.  latin-1 decoding maps each
        # raw byte to one char so byte_encoder can re-encode it losslessly.
        self.decoder = {}
        for i in range(self.n_words):
            # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
            decoding = ''.join([
                self.byte_encoder[ord(char)] for char in
                self.model.decode_single_token_bytes(i).decode('latin-1')
            ])
            self.decoder[i] = decoding

        # Inverse mapping: token string -> id.
        self.encoder = {}
        for i in range(self.n_words):
            if i in self.decoder:
                self.encoder[self.decoder[i]] = i

        # NOTE(review): not referenced by any method visible in this file;
        # presumably consumed elsewhere — confirm before removing.
        self._token_config_cache = OrderedDict()
        self._cache_max_size = 128

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            added_tokens_decoder=added_tokens_decoder,
            **kwargs,
        )
        self.all_special_ids_set = set(self.all_special_ids)

    def encode(self,
               text: str,
               allow_special_tokens: bool = True,
               **kwargs) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            text (str): The input string to be encoded.
            allow_special_tokens (bool): When True, special-token text in the
                input is encoded to its special id; otherwise it is treated
                as ordinary text.

        Returns:
            list[int]: A list of token IDs.
        """
        # If there are other args, we should call super().encode because there are a lot of code
        # to handle those args. super().encode finally will call _tokenize and _convert_token_to_id.
        # NOTE: our encode method is not compatible with the super().encode method,
        # e.g. split_special_tokens' default is True in our encode method.
        if len(kwargs) > 0:
            logger.warning(f"Calling super().encode with {kwargs}")
            return super().encode(text, **kwargs)

        assert type(text) is str

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        texts = self.pre_tokenizer_process(text)

        all_substrs = []
        for text in texts:
            substrs = (
                substr for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
                for substr in self._split_whitespaces_or_nonwhitespaces(
                    text[i:i +
                         TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS))
            all_substrs.extend(substrs)

        t: List[int] = []
        for substr in all_substrs:
            if allow_special_tokens:
                t.extend(
                    # we should consider special token as a common token
                    self.model.encode(
                        substr,
                        allowed_special="all",
                    ))
            else:
                t.extend(
                    # we should consider special token as a common token
                    self.model.encode(
                        substr,
                        disallowed_special=(),
                    ))

        return t

    def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            token_ids (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # If there are other args, we should call super().decode because there are a lot of code
        # to handle those args. super().decode finally will call convert_tokens_to_string and _convert_id_to_token.
        if len(kwargs) > 0:
            return super().decode(token_ids, **kwargs)

        if type(token_ids) is int:
            token_ids = [token_ids]

        return self.model.decode(cast(List[int], token_ids))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
            s: str, max_consecutive_slice_len: int) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            # A whitespace/non-whitespace boundary resets the run counter.
            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]

    def pre_tokenizer_process(self, text: str) -> List[str]:
        """
        pre-tokenizes the input text into a list of tokens.
        This method is used to split the input text into smaller chunks for internal processing.
        Base implementation is a no-op returning the text as a single chunk;
        subclasses may override.
        """
        return [text]

    """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """

    @property
    def vocab_size(self) -> int:
        return self.n_words

    def get_vocab(self) -> Dict[str, int]:
        # Token strings use the byte-to-unicode surface form (see __init__).
        return self.encoder

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        return [self.decoder[t] for t in self.encode(text)]

    def _convert_token_to_id(self, token: str) -> int:
        # Unknown surface forms map to the unk id.
        return self.encoder.get(token, self.unk_id)

    def _convert_id_to_token(self, index: int) -> str:
        # NOTE(review): returns None for out-of-range ids rather than the
        # unk token string — confirm callers tolerate that.
        return self.decoder.get(index)

    @staticmethod
    def clean_up_tokenization(out_string: str) -> str:
        # Intentionally a no-op: tiktoken output needs no space cleanup.
        return out_string

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Invert the byte-to-unicode mapping, then decode the raw bytes as
        # UTF-8 (invalid sequences become replacement characters).
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c]
                          for c in text]).decode('utf-8', 'replace')
        return text

    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Copy the tiktoken model file into `save_directory`."""
        if not os.path.isdir(save_directory):
            raise ValueError(
                f"vocabulary path ({save_directory}) should be a directory")
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") +
            VOCAB_FILES_NAMES["vocab_file"])

        # Skip the copy when saving over the file we loaded from.
        if os.path.abspath(self.vocab_file) != os.path.abspath(
                out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file, )

    def apply_chat_template(self,
                            conversation,
                            tools: Optional[list[dict]] = None,
                            tokenize: bool = False,
                            add_generation_prompt: bool = True,
                            thinking: bool = True,
                            **kwargs):
        """Render a conversation with the chat template, optionally exposing
        `tools` to the template as a TypeScript-style declaration string
        (`tools_ts_str`)."""

        # Sort keys recursively so the rendered tool block is deterministic.
        tools = deep_sort_dict(tools)

        # Convert tools to TypeScript style string if tools are provided
        tools_ts_str = None
        if tools:
            try:
                tools_ts_str = encode_tools_to_typescript_style(tools)

            except Exception as e:
                # Best-effort: fall back to the raw `tools` list on failure.
                print(f"Failed to convert tools to TypeScript style: {e}")
                tools_ts_str = None

        # Store the TypeScript string in kwargs so it can be accessed by the template
        if tools_ts_str is not None:
            kwargs['tools_ts_str'] = tools_ts_str
        return super().apply_chat_template(
            conversation,
            tools=tools,
            tokenize=tokenize,
            add_generation_prompt=add_generation_prompt,
            thinking=thinking,
            **kwargs)
+
346
def deep_sort_dict(obj: Any) -> Any:
    """Recursively return a copy of *obj* with every dict's keys sorted.

    Lists are rebuilt with each element processed recursively; any other
    value (including None) is returned unchanged.
    """
    if isinstance(obj, dict):
        ordered = {}
        for key in sorted(obj):
            ordered[key] = deep_sort_dict(obj[key])
        return ordered
    if isinstance(obj, list):
        return [deep_sort_dict(element) for element in obj]
    return obj
tokenizer_config.json ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "163584": {
4
+ "content": "[BOS]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "163585": {
12
+ "content": "[EOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "163586": {
20
+ "content": "<|im_end|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "163587": {
28
+ "content": "<|im_user|>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "163588": {
36
+ "content": "<|im_assistant|>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "163590": {
44
+ "content": "<|start_header_id|>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "163591": {
52
+ "content": "<|end_header_id|>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "163593": {
60
+ "content": "[EOT]",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "163594": {
68
+ "content": "<|im_system|>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "163595": {
76
+ "content": "<|tool_calls_section_begin|>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "163596": {
84
+ "content": "<|tool_calls_section_end|>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "163597": {
92
+ "content": "<|tool_call_begin|>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "163598": {
100
+ "content": "<|tool_call_argument_begin|>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "163599": {
108
+ "content": "<|tool_call_end|>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "163601": {
116
+ "content": "<|im_middle|>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "163602": {
124
+ "content": "<|media_begin|>",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "163603": {
132
+ "content": "<|media_content|>",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ },
139
+ "163604": {
140
+ "content": "<|media_end|>",
141
+ "lstrip": false,
142
+ "normalized": false,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": true
146
+ },
147
+ "163605": {
148
+ "content": "<|media_pad|>",
149
+ "lstrip": false,
150
+ "normalized": false,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": true
154
+ },
155
+ "163606": {
156
+ "content": "<think>",
157
+ "lstrip": false,
158
+ "normalized": false,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": false
162
+ },
163
+ "163607": {
164
+ "content": "</think>",
165
+ "lstrip": false,
166
+ "normalized": false,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": false
170
+ },
171
+ "163838": {
172
+ "content": "[UNK]",
173
+ "lstrip": false,
174
+ "normalized": false,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": true
178
+ },
179
+ "163839": {
180
+ "content": "[PAD]",
181
+ "lstrip": false,
182
+ "normalized": false,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": true
186
+ }
187
+ },
188
+ "additional_special_tokens": [
189
+ "<|im_end|>",
190
+ "<|im_user|>",
191
+ "<|im_assistant|>",
192
+ "<|start_header_id|>",
193
+ "<|end_header_id|>",
194
+ "[EOT]",
195
+ "<|im_system|>",
196
+ "<|im_middle|>",
197
+ "<|media_begin|>",
198
+ "<|media_content|>",
199
+ "<|media_end|>",
200
+ "<|media_pad|>"
201
+ ],
202
+ "bos_token": "[BOS]",
203
+ "clean_up_tokenization_spaces": false,
204
+ "eos_token": "[EOS]",
205
+ "extra_special_tokens": {},
206
+ "model_max_length": 1000000000000000019884624838656,
207
+ "pad_token": "[PAD]",
208
+ "tokenizer_class": "TikTokenTokenizer",
209
+ "unk_token": "[UNK]",
210
+ "auto_map": {
211
+ "AutoTokenizer": [
212
+ "tokenization_kimi.TikTokenTokenizer",
213
+ null
214
+ ]
215
+ }
216
+ }