liuzijing2014 commited on
Commit
7128145
·
verified ·
1 Parent(s): 2b26517

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. README.md +175 -3
  3. chat_template.jinja +112 -0
  4. config.json +257 -0
  5. configuration_deepseek.py +214 -0
  6. configuration_kimi_k25.py +123 -0
  7. generation_config.json +4 -0
  8. hf_quant_config.json +77 -0
  9. kimi_k25_processor.py +165 -0
  10. kimi_k25_vision_processing.py +251 -0
  11. media_utils.py +368 -0
  12. model-00002-of-00119.safetensors +3 -0
  13. model-00004-of-00119.safetensors +3 -0
  14. model-00006-of-00119.safetensors +3 -0
  15. model-00007-of-00119.safetensors +3 -0
  16. model-00008-of-00119.safetensors +3 -0
  17. model-00013-of-00119.safetensors +3 -0
  18. model-00014-of-00119.safetensors +3 -0
  19. model-00015-of-00119.safetensors +3 -0
  20. model-00018-of-00119.safetensors +3 -0
  21. model-00043-of-00119.safetensors +3 -0
  22. model-00044-of-00119.safetensors +3 -0
  23. model-00047-of-00119.safetensors +3 -0
  24. model-00048-of-00119.safetensors +3 -0
  25. model-00054-of-00119.safetensors +3 -0
  26. model-00056-of-00119.safetensors +3 -0
  27. model-00058-of-00119.safetensors +3 -0
  28. model-00060-of-00119.safetensors +3 -0
  29. model-00067-of-00119.safetensors +3 -0
  30. model-00072-of-00119.safetensors +3 -0
  31. model-00085-of-00119.safetensors +3 -0
  32. model-00088-of-00119.safetensors +3 -0
  33. model-00092-of-00119.safetensors +3 -0
  34. model-00093-of-00119.safetensors +3 -0
  35. model-00099-of-00119.safetensors +3 -0
  36. model-00102-of-00119.safetensors +3 -0
  37. model-00107-of-00119.safetensors +3 -0
  38. model-00108-of-00119.safetensors +3 -0
  39. model-00110-of-00119.safetensors +3 -0
  40. model-00118-of-00119.safetensors +3 -0
  41. model.safetensors.index.json +3 -0
  42. modeling_deepseek.py +1808 -0
  43. modeling_kimi_k25.py +1251 -0
  44. preprocessor_config.json +30 -0
  45. processor_config.json +6 -0
  46. special_tokens_map.json +38 -0
  47. tiktoken.model +3 -0
  48. tokenization_kimi.py +393 -0
  49. tokenizer_config.json +216 -0
  50. tool_declaration_ts.py +479 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model.safetensors.index.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,175 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: text-generation
3
+ base_model:
4
+ - moonshotai/Kimi-K2.5
5
+ license: other
6
+ license_name: modified-mit
7
+ library_name: Model Optimizer
8
+ tags:
9
+ - nvidia
10
+ - ModelOpt
11
+ - Kimi-K2
12
+ - quantized
13
+ - FP4
14
+ ---
15
+ # Model Overview
16
+
17
+ ## Description:
18
+ The NVIDIA Kimi-K2.5-NVFP4 model is the quantized version of the Moonshot AI's Kimi-K2.5 model, which is an auto-regressive language model that uses an optimized transformer architecture. For more information, please check [here](https://huggingface.co/moonshotai/Kimi-K2.5). The NVIDIA Kimi-K2.5 NVFP4 model is quantized with [Model Optimizer](https://github.com/NVIDIA/Model-Optimizer).
19
+
20
+ This model is ready for commercial/non-commercial use. <br>
21
+
22
+ ## Third-Party Community Consideration
23
+ This model is not owned or developed by NVIDIA. This model has been developed and built to a third-party’s requirements for this application and use case; see link to Non-NVIDIA [(Kimi-K2.5) Model Card](https://huggingface.co/moonshotai/Kimi-K2.5).
24
+
25
+ ### License/Terms of Use:
26
+ Governing Terms: Use of this model is governed by the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
27
+
28
+ ADDITIONAL INFORMATION: [Modified MIT License](https://huggingface.co/moonshotai/Kimi-K2.5/blob/main/LICENSE).
29
+
30
+ ### Deployment Geography
31
+ Global
32
+
33
+ ### Use Case
34
+ This model is intended for developers and researchers building LLMs
35
+
36
+ ### Release Date
37
+ Hugging Face 02/02/2026 via https://huggingface.co/nvidia/Kimi-K2.5-NVFP4
38
+
39
+
40
+ ## Model Architecture:
41
+ **Architecture Type:** Transformers <br>
42
+ **Network Architecture:** DeepSeek V3 <br>
43
+ **Number of Model Parameters:** 1T
44
+
45
+ ## Input:
46
+ **Input Type(s):** Text, Image, Video <br>
47
+ **Input Format(s):** String, Undisclosed, Undisclosed <br>
48
+ **Input Parameters:** One-Dimensional (1D), Two-Dimensional (2D), Three-Dimensional (3D) <br>
49
+ **Other Properties Related to Input:** Context length: 256k
50
+
51
+ ## Output:
52
+ **Output Type(s):** Text <br>
53
+ **Output Format:** String <br>
54
+ **Output Parameters:** 1D (One Dimensional): Sequences <br>
55
+
56
+ Our AI models are designed and/or optimized to run on NVIDIA GPU-accelerated systems. By leveraging NVIDIA’s hardware (e.g. GPU cores) and software frameworks (e.g., CUDA libraries), the model achieves faster training and inference times compared to CPU-only solutions.
57
+
58
+ ## Software Integration:
59
+ **Supported Runtime Engine(s):** <br>
60
+ * vLLM <br>
61
+
62
+ **Supported Hardware Microarchitecture Compatibility:** <br>
63
+ * NVIDIA Blackwell <br>
64
+
65
+ **Preferred Operating System(s):** <br>
66
+ * Linux <br>
67
+ The integration of foundation and fine-tuned models into AI systems requires additional testing using use-case-specific data to ensure safe and effective deployment. Following the V-model methodology, iterative testing and validation at both unit and system levels are essential to mitigate risks, meet technical and functional requirements, and ensure compliance with safety and ethical standards before deployment.
68
+
69
+ ## Model Version(s):
70
+ ** The model is quantized with nvidia-modelopt **v0.41.0** <br>
71
+
72
+ ## Calibration Datasets:
73
+ * Calibration Dataset: [cnn_dailymail](https://huggingface.co/datasets/abisee/cnn_dailymail) <br>
74
+ ** Data collection method: Automated. <br>
75
+ ** Labeling method: Automated. <br>
76
+ ** Properties: CNN / DailyMail Dataset is an English-language dataset containing over 300k unique news articles as written by journalists at CNN and the Daily Mail. <br>
77
+
78
+ ## Training Dataset: <br>
79
+ ** Data Collection Method by dataset: Hybrid: Human, Automated <br>
80
+ ** Labeling Method by dataset: Hybrid: Human, Automated <br>
81
+ ** Data Modality: Text, Image, Video <br>
82
+ ** Training Data Size: undisclosed. <br>
83
+
84
+ ## Testing Dataset: <br>
85
+ ** Data Collection Method by dataset: Hybrid: Human, Automated <br>
86
+ ** Labeling Method by dataset: Hybrid: Human, Automated <br>
87
+ ** Properties: Undisclosed <br>
88
+
89
+ ## Evaluation Dataset: <br>
90
+ ** Data Collection Method by dataset: Hybrid: Human, Automated <br>
91
+ ** Labeling Method by dataset: Hybrid: Human, Automated <br>
92
+ ** Properties: Undisclosed <br>
93
+
94
+
95
+ ## Inference:
96
+ **Engine:** vLLM <br>
97
+ **Test Hardware:** B200 <br>
98
+
99
+ ## Post Training Quantization
100
+ This model was obtained by converting and quantizing the weights and activations of Kimi-K2.5 from INT4 to BF16 to NVFP4 data type, ready for inference with vLLM. Only the weights and activations of the linear operators within transformer blocks in MoE are quantized.
101
+
102
+ ## Usage
103
+
104
+
105
+ To serve this checkpoint with [vLLM](https://github.com/vllm-project/vllm), you can start the docker `vllm/vllm-openai:latest` and run the sample command below:
106
+
107
+ ```sh
108
+ python3 -m vllm.entrypoints.openai.api_server --model nvidia/Kimi-K2.5-NVFP4 --tensor-parallel-size 4 --tool-call-parser kimi_k2 --reasoning-parser kimi_k2 --trust-remote-code
109
+ ```
110
+
111
+ ## Evaluation
112
+ The accuracy benchmark results are presented in the table below:
113
+ <table>
114
+ <tr>
115
+ <td><strong>Precision</strong>
116
+ </td>
117
+ <td><strong>MMLU Pro</strong>
118
+ </td>
119
+ <td><strong>LiveCodeBench V6</strong>
120
+ </td>
121
+ <td><strong>SciCode</strong>
122
+ </td>
123
+ <td><strong>AIME 2025</strong>
124
+ </td>
125
+ </tr>
126
+ <tr>
127
+ <td>Baseline (official)
128
+ </td>
129
+ <td><strong>87.1</strong>
130
+ </td>
131
+ <td><strong>85.0</strong>
132
+ </td>
133
+ <td><strong>48.7</strong>
134
+ </td>
135
+ <td><strong>96.1</strong>
136
+ </td>
137
+ </tr>
138
+ <tr>
139
+ <td>Baseline (ours)
140
+ </td>
141
+ <td><strong>86.9</strong>
142
+ </td>
143
+ <td><strong>84.7</strong>
144
+ </td>
145
+ <td><strong>47.7</strong>
146
+ </td>
147
+ <td><strong>96.5</strong>
148
+ </td>
149
+ </tr>
150
+ <tr>
151
+ <td>NVFP4
152
+ </td>
153
+ <td><strong>87.3</strong>
154
+ </td>
155
+ <td><strong>84.0</strong>
156
+ </td>
157
+ <td><strong>48.7</strong>
158
+ </td>
159
+ <td><strong>96.3</strong>
160
+ </td>
161
+ </tr>
162
+ </table>
163
+
164
+ > Baseline (official) numbers are from the [Kimi-K2.5 model card](https://huggingface.co/moonshotai/Kimi-K2.5).
165
+ > Evaluation settings follow the same configuration as described in the [Kimi-K2.5 model card](https://huggingface.co/moonshotai/Kimi-K2.5).
166
+
167
+ ## Model Limitations:
168
+ The base model was trained on data that contains toxic language and societal biases originally crawled from the internet. Therefore, the model may amplify those biases and return toxic responses especially when prompted with toxic prompts. The model may generate answers that may be inaccurate, omit key information, or include irrelevant or redundant text producing socially unacceptable or undesirable text, even if the prompt itself does not include anything explicitly offensive.
169
+
170
+ ## Ethical Considerations
171
+
172
+ NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
173
+
174
+ Please report model quality, risk, security vulnerabilities or NVIDIA AI Concerns [here](https://app.intigriti.com/programs/nvidia/nvidiavdp/detail).
175
+
chat_template.jinja ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- macro render_content(msg) -%}
2
+ {%- set c = msg.get('content') -%}
3
+ {%- if c is string -%}
4
+ {{ c }}
5
+ {%- elif c is not none -%}
6
+ {% for content in c -%}
7
+ {% if content['type'] == 'image' or content['type'] == 'image_url' -%}
8
+ <|media_start|>image<|media_content|><|media_pad|><|media_end|>
9
+ {% elif content['type'] == 'video' or content['type']== 'video_url'-%}
10
+ <|kimi_k25_video_placeholder|>
11
+ {% else -%}
12
+ {{ content['text'] }}
13
+ {%- endif -%}
14
+ {%- endfor -%}
15
+ {%- endif -%}
16
+ {%- endmacro -%}
17
+
18
+ {% macro set_roles(message) -%}
19
+ {%- set role_name = message.get('name') or message['role'] -%}
20
+ {%- if message['role'] == 'user' -%}
21
+ <|im_user|>{{role_name}}<|im_middle|>
22
+ {%- elif message['role'] == 'assistant' -%}
23
+ <|im_assistant|>{{role_name}}<|im_middle|>
24
+ {%- else -%}
25
+ <|im_system|>{{role_name}}<|im_middle|>
26
+ {%- endif -%}
27
+ {%- endmacro -%}
28
+
29
+
30
+ {%- macro render_toolcalls(message) -%}
31
+ <|tool_calls_section_begin|>
32
+ {%- for tool_call in message['tool_calls'] -%}
33
+ {%- set formatted_id = tool_call['id'] -%}
34
+ <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>
35
+ {%- endfor -%}
36
+ <|tool_calls_section_end|>
37
+ {%- endmacro -%}
38
+
39
+
40
+ {# Find last non-tool-call assisitant message #}
41
+ {%- set ns = namespace(last_non_tool_call_assistant_msg=-1) -%}
42
+ {%- for idx in range(messages|length-1, -1, -1) -%}
43
+ {%- if messages[idx]['role'] == 'assistant' and not messages[idx].get('tool_calls') -%}
44
+ {%- set ns.last_non_tool_call_assistant_msg = idx -%}
45
+ {%- break -%}
46
+ {%- endif -%}
47
+ {%- endfor -%}
48
+
49
+ {# split all messages into history & suffix, reasoning_content in suffix should be reserved.#}
50
+ {%- set hist_msgs = messages[:ns.last_non_tool_call_assistant_msg+1] -%}
51
+ {%- set suffix_msgs = messages[ns.last_non_tool_call_assistant_msg+1:] -%}
52
+
53
+ {%- if tools -%}
54
+ {%- if tools_ts_str -%}
55
+ <|im_system|>tool_declare<|im_middle|>{{ tools_ts_str }}<|im_end|>
56
+ {%- else -%}
57
+ <|im_system|>tool_declare<|im_middle|>{{ tools | tojson(separators=(',', ':')) }}<|im_end|>
58
+ {%- endif -%}
59
+ {%- endif -%}
60
+
61
+ {%- if messages|length == 0 or messages[0]['role'] != 'system' -%}
62
+ <|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|>
63
+ {%- endif -%}
64
+
65
+ {%- for message in hist_msgs -%}
66
+ {{set_roles(message)}}
67
+ {%- if message['role'] == 'assistant' -%}
68
+ <think></think>{{render_content(message)}}
69
+ {%- if message.get('tool_calls') -%}
70
+ {{render_toolcalls(message)}}
71
+ {%- endif -%}
72
+ {%- elif message['role'] == 'tool' -%}
73
+ {%- set tool_call_id = message.tool_call_id -%}
74
+ ## Return of {{ tool_call_id }}
75
+ {{render_content(message)}}
76
+ {%- elif message['content'] is not none -%}
77
+ {{render_content(message)}}
78
+ {%- endif -%}
79
+ <|im_end|>
80
+ {%- endfor -%}
81
+
82
+ {%- for message in suffix_msgs -%}
83
+ {{set_roles(message)}}
84
+ {%- if message['role'] == 'assistant' -%}
85
+ {%- if thinking is defined and thinking is false -%}
86
+ <think></think>{{render_content(message)}}
87
+ {%- else -%}
88
+ {%- set rc = message.get('reasoning_content', '') -%}
89
+ <think>{{rc}}</think>{{render_content(message)}}
90
+ {%- endif -%}
91
+ {%- if message.get('tool_calls') -%}
92
+ {{render_toolcalls(message)}}
93
+ {%- endif -%}
94
+ {%- elif message['role'] == 'tool' -%}
95
+ {%- set tool_call_id = message.tool_call_id -%}
96
+ ## Return of {{ tool_call_id }}
97
+ {{render_content(message)}}
98
+ {%- elif message['content'] is not none -%}
99
+ {{render_content(message)}}
100
+ {%- endif -%}
101
+ <|im_end|>
102
+ {%- endfor -%}
103
+
104
+
105
+ {%- if add_generation_prompt -%}
106
+ <|im_assistant|>assistant<|im_middle|>
107
+ {%- if thinking is defined and thinking is false -%}
108
+ <think></think>
109
+ {%- else -%}
110
+ <think>
111
+ {%- endif -%}
112
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "KimiK25ForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_kimi_k25.KimiK25Config",
7
+ "AutoModel": "modeling_kimi_k25.KimiK25ForConditionalGeneration",
8
+ "AutoModelForCausalLM": "modeling_kimi_k25.KimiK25ForConditionalGeneration"
9
+ },
10
+ "bos_token_id": 163584,
11
+ "dtype": "bfloat16",
12
+ "eos_token_id": 163585,
13
+ "ignore_index": -100,
14
+ "media_placeholder_token_id": 163605,
15
+ "model_type": "kimi_k25",
16
+ "pad_token_id": 163839,
17
+ "quantization_config": {
18
+ "config_groups": {
19
+ "group_0": {
20
+ "input_activations": {
21
+ "dynamic": false,
22
+ "num_bits": 4,
23
+ "type": "float",
24
+ "group_size": 16
25
+ },
26
+ "weights": {
27
+ "dynamic": false,
28
+ "num_bits": 4,
29
+ "type": "float",
30
+ "group_size": 16
31
+ },
32
+ "targets": [
33
+ "Linear"
34
+ ]
35
+ }
36
+ },
37
+ "ignore": [
38
+ "language_model.lm_head",
39
+ "language_model.model.layers.0.self_attn*",
40
+ "language_model.model.layers.1.self_attn*",
41
+ "language_model.model.layers.10.self_attn*",
42
+ "language_model.model.layers.11.self_attn*",
43
+ "language_model.model.layers.12.self_attn*",
44
+ "language_model.model.layers.13.self_attn*",
45
+ "language_model.model.layers.14.self_attn*",
46
+ "language_model.model.layers.15.self_attn*",
47
+ "language_model.model.layers.16.self_attn*",
48
+ "language_model.model.layers.17.self_attn*",
49
+ "language_model.model.layers.18.self_attn*",
50
+ "language_model.model.layers.19.self_attn*",
51
+ "language_model.model.layers.2.self_attn*",
52
+ "language_model.model.layers.20.self_attn*",
53
+ "language_model.model.layers.21.self_attn*",
54
+ "language_model.model.layers.22.self_attn*",
55
+ "language_model.model.layers.23.self_attn*",
56
+ "language_model.model.layers.24.self_attn*",
57
+ "language_model.model.layers.25.self_attn*",
58
+ "language_model.model.layers.26.self_attn*",
59
+ "language_model.model.layers.27.self_attn*",
60
+ "language_model.model.layers.28.self_attn*",
61
+ "language_model.model.layers.29.self_attn*",
62
+ "language_model.model.layers.3.self_attn*",
63
+ "language_model.model.layers.30.self_attn*",
64
+ "language_model.model.layers.31.self_attn*",
65
+ "language_model.model.layers.32.self_attn*",
66
+ "language_model.model.layers.33.self_attn*",
67
+ "language_model.model.layers.34.self_attn*",
68
+ "language_model.model.layers.35.self_attn*",
69
+ "language_model.model.layers.36.self_attn*",
70
+ "language_model.model.layers.37.self_attn*",
71
+ "language_model.model.layers.38.self_attn*",
72
+ "language_model.model.layers.39.self_attn*",
73
+ "language_model.model.layers.4.self_attn*",
74
+ "language_model.model.layers.40.self_attn*",
75
+ "language_model.model.layers.41.self_attn*",
76
+ "language_model.model.layers.42.self_attn*",
77
+ "language_model.model.layers.43.self_attn*",
78
+ "language_model.model.layers.44.self_attn*",
79
+ "language_model.model.layers.45.self_attn*",
80
+ "language_model.model.layers.46.self_attn*",
81
+ "language_model.model.layers.47.self_attn*",
82
+ "language_model.model.layers.48.self_attn*",
83
+ "language_model.model.layers.49.self_attn*",
84
+ "language_model.model.layers.5.self_attn*",
85
+ "language_model.model.layers.50.self_attn*",
86
+ "language_model.model.layers.51.self_attn*",
87
+ "language_model.model.layers.52.self_attn*",
88
+ "language_model.model.layers.53.self_attn*",
89
+ "language_model.model.layers.54.self_attn*",
90
+ "language_model.model.layers.55.self_attn*",
91
+ "language_model.model.layers.56.self_attn*",
92
+ "language_model.model.layers.57.self_attn*",
93
+ "language_model.model.layers.58.self_attn*",
94
+ "language_model.model.layers.59.self_attn*",
95
+ "language_model.model.layers.6.self_attn*",
96
+ "language_model.model.layers.60.self_attn*",
97
+ "language_model.model.layers.7.self_attn*",
98
+ "language_model.model.layers.8.self_attn*",
99
+ "language_model.model.layers.9.self_attn*",
100
+ "mm_projector*",
101
+ "vision_tower*"
102
+ ],
103
+ "quant_algo": "NVFP4",
104
+ "kv_cache_scheme": {
105
+ "dynamic": false,
106
+ "num_bits": 8,
107
+ "type": "float"
108
+ },
109
+ "producer": {
110
+ "name": "modelopt",
111
+ "version": "0.41.0"
112
+ },
113
+ "quant_method": "modelopt"
114
+ },
115
+ "text_config": {
116
+ "_name_or_path": "",
117
+ "add_cross_attention": false,
118
+ "architectures": [
119
+ "DeepseekV3ForCausalLM"
120
+ ],
121
+ "attention_bias": false,
122
+ "attention_dropout": 0.0,
123
+ "auto_map": {
124
+ "AutoConfig": "configuration_deepseek.DeepseekV3Config",
125
+ "AutoModel": "modeling_deepseek.DeepseekV3Model",
126
+ "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
127
+ },
128
+ "aux_loss_alpha": 0.001,
129
+ "bad_words_ids": null,
130
+ "begin_suppress_tokens": null,
131
+ "bos_token_id": 163584,
132
+ "chunk_size_feed_forward": 0,
133
+ "cross_attention_hidden_size": null,
134
+ "decoder_start_token_id": null,
135
+ "diversity_penalty": 0.0,
136
+ "do_sample": false,
137
+ "dtype": "bfloat16",
138
+ "early_stopping": false,
139
+ "encoder_no_repeat_ngram_size": 0,
140
+ "eos_token_id": 163585,
141
+ "ep_size": 1,
142
+ "exponential_decay_length_penalty": null,
143
+ "finetuning_task": null,
144
+ "first_k_dense_replace": 1,
145
+ "forced_bos_token_id": null,
146
+ "forced_eos_token_id": null,
147
+ "hidden_act": "silu",
148
+ "hidden_size": 7168,
149
+ "id2label": {
150
+ "0": "LABEL_0",
151
+ "1": "LABEL_1"
152
+ },
153
+ "initializer_range": 0.02,
154
+ "intermediate_size": 18432,
155
+ "is_decoder": false,
156
+ "is_encoder_decoder": false,
157
+ "kv_lora_rank": 512,
158
+ "label2id": {
159
+ "LABEL_0": 0,
160
+ "LABEL_1": 1
161
+ },
162
+ "length_penalty": 1.0,
163
+ "max_length": 20,
164
+ "max_position_embeddings": 262144,
165
+ "min_length": 0,
166
+ "model_type": "deepseek_v3",
167
+ "moe_intermediate_size": 2048,
168
+ "moe_layer_freq": 1,
169
+ "n_group": 1,
170
+ "n_routed_experts": 384,
171
+ "n_shared_experts": 1,
172
+ "no_repeat_ngram_size": 0,
173
+ "norm_topk_prob": true,
174
+ "num_attention_heads": 64,
175
+ "num_beam_groups": 1,
176
+ "num_beams": 1,
177
+ "num_experts_per_tok": 8,
178
+ "num_hidden_layers": 61,
179
+ "num_key_value_heads": 64,
180
+ "num_nextn_predict_layers": 0,
181
+ "num_return_sequences": 1,
182
+ "output_attentions": false,
183
+ "output_hidden_states": false,
184
+ "output_scores": false,
185
+ "pad_token_id": 163839,
186
+ "prefix": null,
187
+ "pretraining_tp": 1,
188
+ "problem_type": null,
189
+ "pruned_heads": {},
190
+ "q_lora_rank": 1536,
191
+ "qk_nope_head_dim": 128,
192
+ "qk_rope_head_dim": 64,
193
+ "remove_invalid_values": false,
194
+ "repetition_penalty": 1.0,
195
+ "return_dict": true,
196
+ "return_dict_in_generate": false,
197
+ "rms_norm_eps": 1e-05,
198
+ "rope_scaling": {
199
+ "beta_fast": 32.0,
200
+ "beta_slow": 1.0,
201
+ "factor": 64.0,
202
+ "mscale": 1.0,
203
+ "mscale_all_dim": 1.0,
204
+ "original_max_position_embeddings": 4096,
205
+ "type": "yarn"
206
+ },
207
+ "rope_theta": 50000.0,
208
+ "routed_scaling_factor": 2.827,
209
+ "scoring_func": "sigmoid",
210
+ "sep_token_id": null,
211
+ "seq_aux": true,
212
+ "suppress_tokens": null,
213
+ "task_specific_params": null,
214
+ "temperature": 1.0,
215
+ "tf_legacy_loss": false,
216
+ "tie_encoder_decoder": false,
217
+ "tie_word_embeddings": false,
218
+ "tokenizer_class": null,
219
+ "top_k": 50,
220
+ "top_p": 1.0,
221
+ "topk_group": 1,
222
+ "topk_method": "noaux_tc",
223
+ "torchscript": false,
224
+ "typical_p": 1.0,
225
+ "use_bfloat16": false,
226
+ "use_cache": true,
227
+ "v_head_dim": 128,
228
+ "vocab_size": 163840
229
+ },
230
+ "tie_word_embeddings": false,
231
+ "transformers_version": "4.57.1",
232
+ "use_unified_vision_chunk": true,
233
+ "video_placeholder": "<|kimi_k25_video_placeholder|>",
234
+ "vision_config": {
235
+ "init_pos_emb_height": 64,
236
+ "init_pos_emb_time": 4,
237
+ "init_pos_emb_width": 64,
238
+ "merge_kernel_size": [
239
+ 2,
240
+ 2
241
+ ],
242
+ "merge_type": "sd2_tpool",
243
+ "mm_hidden_size": 1152,
244
+ "mm_projector_type": "patchmerger",
245
+ "model_type": "",
246
+ "patch_size": 14,
247
+ "pos_emb_type": "divided_fixed",
248
+ "projector_hidden_act": "gelu",
249
+ "projector_ln_eps": 1e-05,
250
+ "text_hidden_size": 7168,
251
+ "video_attn_type": "spatial_temporal",
252
+ "vt_hidden_size": 1152,
253
+ "vt_intermediate_size": 4304,
254
+ "vt_num_attention_heads": 16,
255
+ "vt_num_hidden_layers": 27
256
+ }
257
+ }
configuration_deepseek.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/configuration_deepseek.py
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+ logger = logging.get_logger(__name__)
7
+
8
+ DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
9
+
10
+
11
+ class DeepseekV3Config(PretrainedConfig):
12
+ r"""
13
+ This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
14
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
15
+ defaults will yield a similar configuration to that of the DeepSeek-V3.
16
+
17
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
18
+ documentation from [`PretrainedConfig`] for more information.
19
+
20
+
21
+ Args:
22
+ vocab_size (`int`, *optional*, defaults to 129280):
23
+ Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
24
+ `inputs_ids` passed when calling [`DeepseekV3Model`]
25
+ hidden_size (`int`, *optional*, defaults to 4096):
26
+ Dimension of the hidden representations.
27
+ intermediate_size (`int`, *optional*, defaults to 11008):
28
+ Dimension of the MLP representations.
29
+ moe_intermediate_size (`int`, *optional*, defaults to 1407):
30
+ Dimension of the MoE representations.
31
+ num_hidden_layers (`int`, *optional*, defaults to 32):
32
+ Number of hidden layers in the Transformer decoder.
33
+ num_nextn_predict_layers (`int`, *optional*, defaults to 1):
34
+ Number of nextn predict layers in the DeepSeekV3 Model.
35
+ num_attention_heads (`int`, *optional*, defaults to 32):
36
+ Number of attention heads for each attention layer in the Transformer decoder.
37
+ n_shared_experts (`int`, *optional*, defaults to None):
38
+ Number of shared experts, None means dense model.
39
+ n_routed_experts (`int`, *optional*, defaults to None):
40
+ Number of routed experts, None means dense model.
41
+ routed_scaling_factor (`float`, *optional*, defaults to 1.0):
42
+ Scaling factor or routed experts.
43
+ topk_method (`str`, *optional*, defaults to `gready`):
44
+ Topk method used in routed gate.
45
+ n_group (`int`, *optional*, defaults to None):
46
+ Number of groups for routed experts.
47
+ topk_group (`int`, *optional*, defaults to None):
48
+ Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
49
+ num_experts_per_tok (`int`, *optional*, defaults to None):
50
+ Number of selected experts, None means dense model.
51
+ moe_layer_freq (`int`, *optional*, defaults to 1):
52
+ The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
53
+ first_k_dense_replace (`int`, *optional*, defaults to 0):
54
+ Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
55
+ \--k dense layers--/
56
+ norm_topk_prob (`bool`, *optional*, defaults to False):
57
+ Whether to normalize the weights of the routed experts.
58
+ scoring_func (`str`, *optional*, defaults to 'softmax'):
59
+ Method of computing expert weights.
60
+ aux_loss_alpha (`float`, *optional*, defaults to 0.001):
61
+ Auxiliary loss weight coefficient.
62
+ seq_aux (`bool`, *optional*, defaults to True):
63
+ Whether to compute the auxiliary loss for each individual sample.
64
+ num_key_value_heads (`int`, *optional*):
65
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
66
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
67
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
68
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
69
+ by meanpooling all the original heads within that group. For more details checkout [this
70
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
71
+ `num_attention_heads`.
72
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
73
+ The non-linear activation function (function or string) in the decoder.
74
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
75
+ The maximum sequence length that this model might ever be used with.
76
+ initializer_range (`float`, *optional*, defaults to 0.02):
77
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
78
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
79
+ The epsilon used by the rms normalization layers.
80
+ use_cache (`bool`, *optional*, defaults to `True`):
81
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
82
+ relevant if `config.is_decoder=True`.
83
+ pad_token_id (`int`, *optional*):
84
+ Padding token id.
85
+ bos_token_id (`int`, *optional*, defaults to 1):
86
+ Beginning of stream token id.
87
+ eos_token_id (`int`, *optional*, defaults to 2):
88
+ End of stream token id.
89
+ pretraining_tp (`int`, *optional*, defaults to 1):
90
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
91
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
92
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
93
+ issue](https://github.com/pytorch/pytorch/issues/76232).
94
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
95
+ Whether to tie weight embeddings
96
+ rope_theta (`float`, *optional*, defaults to 10000.0):
97
+ The base period of the RoPE embeddings.
98
+ rope_scaling (`Dict`, *optional*):
99
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
100
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
101
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
102
+ `max_position_embeddings` to the expected new maximum.
103
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
104
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
105
+ attention_dropout (`float`, *optional*, defaults to 0.0):
106
+ The dropout ratio for the attention probabilities.
107
+
108
+ ```python
109
+ >>> from transformers import DeepseekV3Model, DeepseekV3Config
110
+
111
+ >>> # Initializing a Deepseek-V3 style configuration
112
+ >>> configuration = DeepseekV3Config()
113
+
114
+ >>> # Accessing the model configuration
115
+ >>> configuration = model.config
116
+ ```"""
117
+
118
+ model_type = "deepseek_v3"
119
+ keys_to_ignore_at_inference = ["past_key_values"]
120
+
121
+ def __init__(
122
+ self,
123
+ vocab_size=129280,
124
+ hidden_size=7168,
125
+ intermediate_size=18432,
126
+ moe_intermediate_size=2048,
127
+ num_hidden_layers=61,
128
+ num_nextn_predict_layers=1,
129
+ num_attention_heads=128,
130
+ num_key_value_heads=128,
131
+ n_shared_experts=1,
132
+ n_routed_experts=256,
133
+ ep_size=1,
134
+ routed_scaling_factor=2.5,
135
+ kv_lora_rank=512,
136
+ q_lora_rank=1536,
137
+ qk_rope_head_dim=64,
138
+ v_head_dim=128,
139
+ qk_nope_head_dim=128,
140
+ topk_method='noaux_tc',
141
+ n_group=8,
142
+ topk_group=4,
143
+ num_experts_per_tok=8,
144
+ moe_layer_freq=1,
145
+ first_k_dense_replace=3,
146
+ norm_topk_prob=True,
147
+ scoring_func='sigmoid',
148
+ aux_loss_alpha=0.001,
149
+ seq_aux=True,
150
+ hidden_act="silu",
151
+ max_position_embeddings=4096,
152
+ initializer_range=0.02,
153
+ rms_norm_eps=1e-6,
154
+ use_cache=True,
155
+ pad_token_id=None,
156
+ bos_token_id=0,
157
+ eos_token_id=1,
158
+ pretraining_tp=1,
159
+ tie_word_embeddings=False,
160
+ rope_theta=10000.0,
161
+ rope_scaling=None,
162
+ attention_bias=False,
163
+ attention_dropout=0.0,
164
+ **kwargs,
165
+ ):
166
+ self.vocab_size = vocab_size
167
+ self.max_position_embeddings = max_position_embeddings
168
+ self.hidden_size = hidden_size
169
+ self.intermediate_size = intermediate_size
170
+ self.moe_intermediate_size = moe_intermediate_size
171
+ self.num_hidden_layers = num_hidden_layers
172
+ self.num_nextn_predict_layers = num_nextn_predict_layers
173
+ self.num_attention_heads = num_attention_heads
174
+ self.n_shared_experts = n_shared_experts
175
+ self.n_routed_experts = n_routed_experts
176
+ self.ep_size = ep_size
177
+ self.routed_scaling_factor = routed_scaling_factor
178
+ self.kv_lora_rank = kv_lora_rank
179
+ self.q_lora_rank = q_lora_rank
180
+ self.qk_rope_head_dim = qk_rope_head_dim
181
+ self.v_head_dim = v_head_dim
182
+ self.qk_nope_head_dim = qk_nope_head_dim
183
+ self.topk_method = topk_method
184
+ self.n_group = n_group
185
+ self.topk_group = topk_group
186
+ self.num_experts_per_tok = num_experts_per_tok
187
+ self.moe_layer_freq = moe_layer_freq
188
+ self.first_k_dense_replace = first_k_dense_replace
189
+ self.norm_topk_prob = norm_topk_prob
190
+ self.scoring_func = scoring_func
191
+ self.aux_loss_alpha = aux_loss_alpha
192
+ self.seq_aux = seq_aux
193
+ # for backward compatibility
194
+ if num_key_value_heads is None:
195
+ num_key_value_heads = num_attention_heads
196
+
197
+ self.num_key_value_heads = num_key_value_heads
198
+ self.hidden_act = hidden_act
199
+ self.initializer_range = initializer_range
200
+ self.rms_norm_eps = rms_norm_eps
201
+ self.pretraining_tp = pretraining_tp
202
+ self.use_cache = use_cache
203
+ self.rope_theta = rope_theta
204
+ self.rope_scaling = rope_scaling
205
+ self.attention_bias = attention_bias
206
+ self.attention_dropout = attention_dropout
207
+
208
+ super().__init__(
209
+ pad_token_id=pad_token_id,
210
+ bos_token_id=bos_token_id,
211
+ eos_token_id=eos_token_id,
212
+ tie_word_embeddings=tie_word_embeddings,
213
+ **kwargs,
214
+ )
configuration_kimi_k25.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+ try:
4
+ from configuration_deepseek import DeepseekV3Config
5
+ except ImportError:
6
+ from .configuration_deepseek import DeepseekV3Config
7
+
8
+
9
class KimiK25VisionConfig(PretrainedConfig):
    """Configuration for the Kimi-K2.5 vision tower and multimodal projector.

    Args:
        patch_size: Spatial patch size of the vision transformer.
        init_pos_emb_height: Initial position-embedding grid height.
        init_pos_emb_width: Initial position-embedding grid width.
        init_pos_emb_time: Initial temporal extent of the position embedding.
        pos_emb_type: Position-embedding scheme identifier.
        vt_num_attention_heads: Number of attention heads in the vision tower.
        vt_num_hidden_layers: Number of hidden layers in the vision tower.
        vt_hidden_size: Hidden size of the vision tower.
        vt_intermediate_size: FFN intermediate size in the vision tower.
        merge_kernel_size: Spatial kernel used when merging patches.
        video_attn_type: Attention pattern used for video inputs.
        merge_type: Patch-merge strategy identifier.
        _attn_implementation: Attention backend (e.g. 'flash_attention_2').
        mm_projector_type: Multimodal projector architecture name.
        mm_hidden_size: Projector input size; defaults to ``vt_hidden_size``.
        projector_hidden_act: Activation function used inside the projector.
        projector_ln_eps: LayerNorm epsilon used inside the projector.
        ignore_index: Label value ignored by the loss.
        media_placeholder_token_id: Token id standing in for media content.
        pad_token_id: Padding token id (forwarded to ``PretrainedConfig``).
        use_unified_vision_chunk: Whether images/videos share one chunk path.
        video_placeholder: Literal placeholder string for videos in prompts.
        text_hidden_size: Hidden size of the text model the projector maps to.
    """

    def __init__(
            self,
            patch_size: int = 14,
            init_pos_emb_height: int = 64,
            init_pos_emb_width: int = 64,
            init_pos_emb_time: int = 4,
            pos_emb_type: str = 'divided_fixed',
            vt_num_attention_heads: int = 16,
            vt_num_hidden_layers: int = 27,
            vt_hidden_size: int = 1152,
            vt_intermediate_size: int = 4304,
            merge_kernel_size: tuple = (2, 2),
            video_attn_type: str = 'spatial_temporal',
            merge_type: str = 'sd2_tpool',
            _attn_implementation: str = 'flash_attention_2',
            # MM Projector parameters
            mm_projector_type: str = 'patchmerger',
            mm_hidden_size: int | None = None,
            projector_hidden_act: str = "gelu",
            projector_ln_eps: float = 1e-5,
            # Other parameters
            ignore_index: int = -100,
            media_placeholder_token_id: int = 163605,
            pad_token_id: int = 0,
            use_unified_vision_chunk: bool = True,
            video_placeholder="<|kimi_k25_video_placeholder|>",
            text_hidden_size=7168,
            **vision_config_kwargs):
        self.patch_size = patch_size
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        self.init_pos_emb_time = init_pos_emb_time
        self.pos_emb_type = pos_emb_type
        self.vt_num_attention_heads = vt_num_attention_heads
        self.vt_num_hidden_layers = vt_num_hidden_layers
        self.vt_hidden_size = vt_hidden_size
        self.vt_intermediate_size = vt_intermediate_size
        self.merge_kernel_size = merge_kernel_size
        self.video_attn_type = video_attn_type
        self.merge_type = merge_type

        # MM Projector config
        self.mm_projector_type = mm_projector_type
        self.mm_hidden_size = mm_hidden_size if mm_hidden_size is not None else vt_hidden_size
        self.projector_hidden_act = projector_hidden_act
        self.projector_ln_eps = projector_ln_eps
        self.text_hidden_size = text_hidden_size

        # BUGFIX: these constructor arguments were previously accepted but
        # silently dropped, so they did not survive save/load round trips.
        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id
        self.use_unified_vision_chunk = use_unified_vision_chunk
        self.video_placeholder = video_placeholder

        # BUGFIX: PretrainedConfig.__init__ was never called, so base-class
        # machinery (serialization defaults, handling of extra kwargs)
        # was skipped entirely.
        super().__init__(pad_token_id=pad_token_id, **vision_config_kwargs)

        # Set after super().__init__ so any default injected by the base
        # class cannot clobber the explicitly requested attention backend.
        self._attn_implementation = _attn_implementation
60
+
61
+
62
class KimiK25Config(PretrainedConfig):
    """Kimi-K2.5 model configuration.

    Args:
        text_config (dict | DeepseekV3Config): Configuration for the text model.
        vision_config (dict | KimiK25VisionConfig): Configuration for the
            vision tower and multimodal projector.
        ignore_index (int): The ignore index for the loss function.
        media_placeholder_token_id (int): The token ID to use for media placeholders.
        pad_token_id (int): The token ID to use for padding.
        use_unified_vision_chunk (bool): Whether images and video chunks share
            a unified processing path.
        video_placeholder (str): Literal placeholder string marking video
            positions in raw prompt text.
    """

    model_type = "kimi_k25"

    def __init__(
        self,
        text_config: dict | DeepseekV3Config = None,
        vision_config: dict | KimiK25VisionConfig = None,
        # Other parameters
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        use_unified_vision_chunk: bool = True,
        video_placeholder="<|kimi_k25_video_placeholder|>",
        **kwargs,
    ):
        # Accept either plain dicts (e.g. loaded from config.json) or
        # already-constructed config objects.
        if isinstance(text_config, dict):
            text_config = DeepseekV3Config(**text_config)
        if isinstance(vision_config, dict):
            vision_config = KimiK25VisionConfig(**vision_config)
        self.text_config = text_config
        self.vision_config = vision_config
        # Other config
        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id
        self.use_unified_vision_chunk = use_unified_vision_chunk
        self.video_placeholder = video_placeholder
        # Surface the text model's quantization config at the top level so
        # loaders that look for `config.quantization_config` find it.
        # getattr on a None text_config safely yields None here.
        if getattr(self.text_config, "quantization_config", None) is not None:
            self.quantization_config = self.text_config.quantization_config

        super().__init__(pad_token_id=pad_token_id, **kwargs)
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "max_length": 262144,
3
+ "eos_token_id": 163586
4
+ }
hf_quant_config.json ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "producer": {
3
+ "name": "modelopt",
4
+ "version": "0.41.0"
5
+ },
6
+ "quantization": {
7
+ "quant_algo": "NVFP4",
8
+ "kv_cache_quant_algo": "FP8",
9
+ "group_size": 16,
10
+ "exclude_modules": [
11
+ "language_model.lm_head",
12
+ "language_model.model.layers.0.self_attn*",
13
+ "language_model.model.layers.1.self_attn*",
14
+ "language_model.model.layers.10.self_attn*",
15
+ "language_model.model.layers.11.self_attn*",
16
+ "language_model.model.layers.12.self_attn*",
17
+ "language_model.model.layers.13.self_attn*",
18
+ "language_model.model.layers.14.self_attn*",
19
+ "language_model.model.layers.15.self_attn*",
20
+ "language_model.model.layers.16.self_attn*",
21
+ "language_model.model.layers.17.self_attn*",
22
+ "language_model.model.layers.18.self_attn*",
23
+ "language_model.model.layers.19.self_attn*",
24
+ "language_model.model.layers.2.self_attn*",
25
+ "language_model.model.layers.20.self_attn*",
26
+ "language_model.model.layers.21.self_attn*",
27
+ "language_model.model.layers.22.self_attn*",
28
+ "language_model.model.layers.23.self_attn*",
29
+ "language_model.model.layers.24.self_attn*",
30
+ "language_model.model.layers.25.self_attn*",
31
+ "language_model.model.layers.26.self_attn*",
32
+ "language_model.model.layers.27.self_attn*",
33
+ "language_model.model.layers.28.self_attn*",
34
+ "language_model.model.layers.29.self_attn*",
35
+ "language_model.model.layers.3.self_attn*",
36
+ "language_model.model.layers.30.self_attn*",
37
+ "language_model.model.layers.31.self_attn*",
38
+ "language_model.model.layers.32.self_attn*",
39
+ "language_model.model.layers.33.self_attn*",
40
+ "language_model.model.layers.34.self_attn*",
41
+ "language_model.model.layers.35.self_attn*",
42
+ "language_model.model.layers.36.self_attn*",
43
+ "language_model.model.layers.37.self_attn*",
44
+ "language_model.model.layers.38.self_attn*",
45
+ "language_model.model.layers.39.self_attn*",
46
+ "language_model.model.layers.4.self_attn*",
47
+ "language_model.model.layers.40.self_attn*",
48
+ "language_model.model.layers.41.self_attn*",
49
+ "language_model.model.layers.42.self_attn*",
50
+ "language_model.model.layers.43.self_attn*",
51
+ "language_model.model.layers.44.self_attn*",
52
+ "language_model.model.layers.45.self_attn*",
53
+ "language_model.model.layers.46.self_attn*",
54
+ "language_model.model.layers.47.self_attn*",
55
+ "language_model.model.layers.48.self_attn*",
56
+ "language_model.model.layers.49.self_attn*",
57
+ "language_model.model.layers.5.self_attn*",
58
+ "language_model.model.layers.50.self_attn*",
59
+ "language_model.model.layers.51.self_attn*",
60
+ "language_model.model.layers.52.self_attn*",
61
+ "language_model.model.layers.53.self_attn*",
62
+ "language_model.model.layers.54.self_attn*",
63
+ "language_model.model.layers.55.self_attn*",
64
+ "language_model.model.layers.56.self_attn*",
65
+ "language_model.model.layers.57.self_attn*",
66
+ "language_model.model.layers.58.self_attn*",
67
+ "language_model.model.layers.59.self_attn*",
68
+ "language_model.model.layers.6.self_attn*",
69
+ "language_model.model.layers.60.self_attn*",
70
+ "language_model.model.layers.7.self_attn*",
71
+ "language_model.model.layers.8.self_attn*",
72
+ "language_model.model.layers.9.self_attn*",
73
+ "mm_projector*",
74
+ "vision_tower*"
75
+ ]
76
+ }
77
+ }
kimi_k25_processor.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.feature_extraction_utils import BatchFeature
2
+ from transformers.processing_utils import ProcessorMixin
3
+ from transformers.utils import logging
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+
8
class KimiK25Processor(ProcessorMixin):
    r"""
    Constructs a KimiK25 processor which wraps a KimiK25 image processor and a tokenizer into a single processor.

    [`KimiK25Processor`] offers all the functionalities of [`KimiK25ImageProcessor`] and [`TikTokenTokenizer`]. See the
    [`~KimiK25Processor.__call__`] and [`~KimiK25Processor.decode`] for more information.

    Args:
        image_processor ([`KimiK25ImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`TikTokenTokenizer`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        **kwargs,
    ):
        super().__init__(image_processor,
                         tokenizer,
                         chat_template=chat_template)
        self.media_processor = image_processor
        # A special temporal placeholder to be replaced by actual video placeholders
        self.video_placeholder = "<|kimi_k25_video_placeholder|>"

    def update_raw_text(self, text: str, video_prompts: list[str]) -> str:
        """Replace each video placeholder in ``text`` with its chunk prompts.

        Asserts that the number of placeholders matches ``video_prompts``.
        """
        video_count = text.count(self.video_placeholder)
        if video_count == 0:
            return text
        assert video_count == len(video_prompts)
        text_parts = text.split(self.video_placeholder)
        assert len(text_parts) == len(video_prompts) + 1
        # Interleave the split text with the per-video chunk prompts.
        pieces = []
        for part, prompt in zip(text_parts, video_prompts):
            pieces.append(part)
            pieces.append(prompt)
        pieces.append(text_parts[-1])
        return "".join(pieces)

    def preprocess_medias(
            self, medias: list[dict]) -> tuple[list[dict], list[str]]:
        """Expand videos into chunk inputs and collect their text prompts.

        Returns:
            ``(updated_medias, video_prompts)``: the flat media list with each
            video replaced by its chunks, plus one concatenated prompt string
            per original video.

        Raises:
            ValueError: For media types other than 'image' and 'video'.
        """
        # BUGFIX: return annotation previously claimed ``list[dict]`` while
        # the function returns a 2-tuple.
        updated_medias = []
        video_prompts = []
        for media in medias:
            if media['type'] == 'image':
                updated_medias.append(media)
            elif media['type'] == 'video':
                video_chunks = self.media_processor.split_video_chunks(
                    media['video'])
                updated_medias.extend(video_chunks)
                video_prompts.append("".join(
                    [vc['prompt'] for vc in video_chunks]))
            else:
                raise ValueError(f"unsupported media type: {media['type']}")
        return updated_medias, video_prompts

    def __call__(self,
                 messages: list[dict] = None,
                 medias: list[dict] = None,
                 text: str = None,
                 return_tensors: str = "pt",
                 **kwargs) -> BatchFeature:
        """
        Process multimodal inputs for Kimi-K2.5 model.

        This processor accepts ordered messages and extracts both media and text in a single pass.
        text will be automatically updated if video input detected in messages

        Args:
            messages: List of message dicts with 'role' and 'content' fields.
                If provided, medias and text will be extracted automatically.
            medias: Pre-extracted list of media dicts. If None, extracted from messages.
            text: Pre-formatted text string. If None, generated via apply_chat_template.
            return_tensors: Format of returned tensors ('pt', 'np', 'tf'). Default: 'pt'.
            **kwargs: Additional arguments passed to tokenizer.apply_chat_template.

        Returns:
            BatchFeature with fields: input_ids, attention_mask, pixel_values, grid_thws.
        """
        if messages is None and (medias is None or text is None):
            raise ValueError(
                "Provide either 'messages' or both 'medias' and 'text'")

        # The two input modes ('medias'+'text' vs 'messages') previously
        # duplicated the whole pipeline; both now share one path.
        if medias is None:
            medias = self._extract_medias_from_messages(messages)
        updated_medias, video_prompts = self.preprocess_medias(medias)
        preprocessed = self.media_processor.preprocess(
            updated_medias, return_tensors=return_tensors)

        # Generate text if not provided
        if text is None:
            text = self.tokenizer.apply_chat_template(messages, **kwargs)
        text = self.update_raw_text(text, video_prompts)

        text_inputs = self.tokenizer(text, return_tensors=return_tensors)
        return BatchFeature(data={**text_inputs, **preprocessed.data})

    @staticmethod
    def _extract_medias_from_messages(messages: list[dict]) -> list[dict]:
        """
        Extract media items from messages in a single pass.

        This is an optimized version that processes messages only once.
        Kept as internal method since external callers should use __call__.
        """
        medias = []
        for msg in messages:
            # Only user turns carry media payloads.
            if msg['role'] != 'user' or not msg.get('content'):
                continue

            for content_part in msg['content']:
                if not isinstance(content_part, dict):
                    continue

                content_type = content_part.get('type')
                if content_type in ['video_url', 'video']:
                    medias.append({
                        'type': 'video',
                        'video': content_part['video_url']['url'],
                        'first_frame_timestamp': 0.0
                    })
                elif content_type in ['image_url', 'image']:
                    medias.append({
                        'type': 'image',
                        'image': content_part['image_url'],
                    })
        return medias

    def apply_chat_template(self, messages, **kwargs):
        # Thin delegation to the underlying tokenizer.
        return self.tokenizer.apply_chat_template(messages, **kwargs)

    def batch_decode(self, *args, **kwargs):
        # Thin delegation to the underlying tokenizer.
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        # Thin delegation to the underlying tokenizer.
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        return ['input_ids', 'attention_mask', 'pixel_values', 'grid_thws']
kimi_k25_vision_processing.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image processor class for Kimi-K2.5.
2
+ """
3
+
4
+ import json
5
+ from typing import Any, Dict, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from transformers.image_processing_utils import (BaseImageProcessor,
11
+ BatchFeature)
12
+ from transformers.utils import TensorType
13
+
14
+ from .media_utils import (MediaInput, VideoChunkInput, _to_tensor,
15
+ ensure_media_type, get_video_meta, image_to_np,
16
+ navit_patchify, navit_resize_image,
17
+ navit_resize_video, normalize,
18
+ real_sample_fps_and_max_num_frames, timestamp_as_str)
19
+
20
+ try:
21
+ from mecord import VideoReader
22
+ except ImportError:
23
+ VideoReader = None
24
+
25
+
26
def resampling(video_bytes: bytes,
               sample_indices: list[int],
               key_indices=None,
               frame_time_info=None,
               num_threads=4) -> list[Image.Image]:
    """Decode only the requested frames of a video.

    Args:
        video_bytes: Encoded video payload accepted by ``VideoReader``.
        sample_indices: Frame indices to extract.
        key_indices: Optional keyframe index hints to speed up seeking.
        frame_time_info: Optional per-frame timing hints.
        num_threads: Decoder thread count.

    Returns:
        The requested frames as PIL images (return annotation fixed: the
        original ``-> str`` did not match the returned list of images).
    """
    video = VideoReader(video_bytes,
                        num_threads=num_threads,
                        frame_time_info=frame_time_info,
                        key_indices=key_indices)
    # extract target frames
    frames = video[sample_indices]
    frames = [Image.fromarray(frame) for frame in frames]
    return frames
39
+
40
+
41
class KimiK25VisionProcessor(BaseImageProcessor):
    """Vision pre-processor for Kimi-K2.5.

    Resizes/pads images and video chunks, normalizes pixel values, and
    patchifies them into flat 'pixel_values' plus per-item 'grid_thws'.
    """

    model_type = "kimi_k25"

    def __init__(
        self,
        media_proc_cfg: dict,
        **kwargs,
    ):
        # media_proc_cfg carries every sizing knob used below (patch_size,
        # merge_kernel_size, patch limits, sample_fps, timestamp_mode, ...).
        super().__init__(**kwargs)
        self.media_proc_cfg = media_proc_cfg
        self.num_frames_per_chunk = media_proc_cfg[
            'temporal_merge_kernel_size']

    def media_tokens_calculator(self, media: MediaInput) -> int:
        """Return the number of vision tokens this media item will produce."""
        media = ensure_media_type(media)
        ret = self.get_resize_config(media)
        return ret['num_tokens']

    @classmethod
    def make_chunk_prompt(cls, timestamp_text: str) -> str:
        """Build the per-chunk prompt: timestamp followed by media markers."""
        return f"{timestamp_text}<|media_begin|>video<|media_content|><|media_pad|><|media_end|>"

    def split_video_chunks(self,
                           video_url: str | bytes) -> list[VideoChunkInput]:
        """Sample frames from a video and group them into temporal chunks.

        Each chunk holds up to `temporal_merge_kernel_size` consecutive
        sampled frames plus a timestamped prompt string. (Return annotation
        fixed: originally declared ``list[list[Image.Image]]``.)
        """
        # video_url should be base64 str or bytes
        video_spec = get_video_meta(video_url)
        # Effective sampling rate cannot exceed the source fps.
        sample_fps = min(self.media_proc_cfg['sample_fps'], video_spec.fps)
        sampled_nframes = max(
            round(video_spec.num_frames * sample_fps / video_spec.fps), 1)
        # Evenly spaced frame indices across the whole video.
        frame_inds = np.linspace(0, video_spec.num_frames - 1,
                                 sampled_nframes).round().astype(int)
        frame_inds = frame_inds.tolist()
        sampled_frame_ids = []
        temporal_merge_kernel_size = self.media_proc_cfg[
            "temporal_merge_kernel_size"]
        num_chunks = 0
        chunk_timestamp = []
        for i in range(0, len(frame_inds), temporal_merge_kernel_size):
            sampled_frame_ids.extend(frame_inds[i:i +
                                                temporal_merge_kernel_size])
            # A chunk's timestamp is the presentation time of its first frame.
            start_time = frame_inds[i] / float(video_spec.fps)
            timestamp_text = timestamp_as_str(
                start_time, self.media_proc_cfg["timestamp_mode"])
            chunk_timestamp.append(timestamp_text)
            num_chunks += 1

        # Decode only the selected frames.
        sampled_frames = resampling(video_url, sampled_frame_ids)
        chunks = []
        for chunk_id in range(num_chunks):
            chunk = sampled_frames[chunk_id *
                                   temporal_merge_kernel_size:(chunk_id + 1) *
                                   temporal_merge_kernel_size]
            chunks.append(
                VideoChunkInput(type="video_chunk",
                                video_chunk=chunk,
                                prompt=self.make_chunk_prompt(
                                    chunk_timestamp[chunk_id])))
        return chunks

    def get_resize_config(self, media_input: MediaInput) -> dict:
        """Compute resize/pad geometry and token count for one media item.

        Raises:
            ValueError: For types other than 'image' and 'video_chunk'.
        """
        if media_input['type'] == 'image':
            w, h = media_input['image'].size
            ret = navit_resize_image(
                w, h, self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                self.media_proc_cfg['in_patch_limit'],
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['fixed_output_tokens'])
            return ret
        elif media_input['type'] == 'video_chunk':
            # Geometry is derived from the first frame; frames in one chunk
            # are assumed to share the same size.
            frame = media_input['video_chunk'][0]
            width, height = frame.size
            num_frames = len(media_input["video_chunk"])
            # Chunks carry no timing; fps is a placeholder here.
            fps = 1.0

            sample_fps, max_num_frames_each_video = real_sample_fps_and_max_num_frames(
                media_input["type"],
                self.media_proc_cfg['sample_fps'],
                self.media_proc_cfg['max_num_frames_each_video'],
            )

            # Per-frame patch budget falls back to the global image limit.
            in_patch_limit_each_frame = self.media_proc_cfg[
                'in_patch_limit_each_frame']
            if in_patch_limit_each_frame is None:
                in_patch_limit_each_frame = self.media_proc_cfg[
                    'in_patch_limit']

            ret = navit_resize_video(
                width,
                height,
                num_frames,
                fps,
                sample_fps,
                self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                in_patch_limit_each_frame,
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['in_patch_limit_video'],
                max_num_frames_each_video,
                self.media_proc_cfg['fixed_output_tokens'],
            )
            return ret
        else:
            raise ValueError("Unsupported type: {}".format(
                media_input['type']))

    def resize_image(self, image: Image.Image, new_width: int, new_height: int,
                     pad_width: int, pad_height: int) -> np.ndarray:
        """Resize to (new_width, new_height) and zero-pad on bottom/right."""
        image_np = image_to_np(image, (new_width, new_height), "resize")
        image_np = np.pad(
            image_np,
            ((0, pad_height), (0, pad_width), (0, 0)),
            mode="constant",
            constant_values=0,
        )
        return image_np

    def preprocess(
        self,
        medias: list[MediaInput],
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """
        Preprocess atomic vision inputs (images/video_chunk) into model-ready tensors.

        Args:
            medias: List of MediaInput.
            return_tensors: Desired output format ('pt', 'np', 'tf', or None).

        Returns:
            BatchFeature containing 'pixel_values' and 'grid_thws' tensors
            (empty dict when no medias were given).
        """
        if not isinstance(medias, list):
            medias = [medias]
        if medias:
            pixel_values = []
            for item in medias:
                item = ensure_media_type(item)
                resize_config = self.get_resize_config(item)
                new_width, new_height, pad_width, pad_height = resize_config[
                    'new_width'], resize_config['new_height'], resize_config[
                        'pad_width'], resize_config['pad_height']
                if item['type'] == 'image':
                    image = item['image']
                    image_np = self.resize_image(image, new_width, new_height,
                                                 pad_width, pad_height)
                    # Add a singleton temporal axis so images match videos.
                    pixel_values.append(np.expand_dims(image_np, axis=0))
                elif item['type'] == 'video_chunk':
                    pixels = []
                    for frame in item['video_chunk']:
                        frame_np = self.resize_image(frame, new_width,
                                                     new_height, pad_width,
                                                     pad_height)
                        pixels.append(frame_np)
                    pixel_values.append(np.stack(pixels, axis=0))
                else:
                    raise ValueError("Unsupported type: {}".format(
                        item['type']))
            normalized_pixel_values = []
            # Pre-compute channel mean and reciprocal std for normalization.
            image_std_inv = 1.0 / np.array(self.media_proc_cfg['image_std'])
            image_mean = np.array(self.media_proc_cfg['image_mean'])
            for pixels in pixel_values:
                pixels = normalize(pixels, image_mean, image_std_inv)
                pixels_and_thw = navit_patchify(
                    pixels,
                    self.media_proc_cfg['patch_size'],
                )
                normalized_pixel_values.append(pixels_and_thw)

            # Flatten all patches into one tensor; keep one (t, h, w) row
            # per media item.
            pixel_values = torch.cat([
                _to_tensor(pixel_value['pixel_values'])
                for pixel_value in normalized_pixel_values
            ])
            grid_thws = torch.cat([
                _to_tensor(pixel_value['grid_thw'],
                           dtype=torch.int64).unsqueeze(0)
                for pixel_value in normalized_pixel_values
            ])

            data = {
                'pixel_values': pixel_values,
                'grid_thws': grid_thws,
            }

        else:
            data = {}

        return BatchFeature(data=data, tensor_type=return_tensors)

    def __repr__(self):
        return f"KimiK25VisionProcessor(media_proc_cfg={self.media_proc_cfg})"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize, including media_proc_cfg and dropping the live
        processor reference if present."""
        output = super().to_dict()
        output["media_proc_cfg"] = self.media_proc_cfg
        if "media_processor" in output:
            del output["media_processor"]
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        """Rebuild from a serialized dict (inverse of ``to_dict``)."""
        config = config_dict.copy()
        media_proc_cfg = config.pop("media_proc_cfg", {})
        return cls(media_proc_cfg=media_proc_cfg, **config, **kwargs)

    def to_json_string(self):
        """JSON-serialize, converting array-like values via ``.tolist()``."""
        dictionary = self.to_dict()
        for key, value in dictionary.items():
            if hasattr(value, 'tolist'):
                dictionary[key] = value.tolist()
        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
media_utils.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import math
4
+ import os
5
+ from datetime import datetime, timezone
6
+ from typing import List, Literal, Optional, TypedDict
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+ from pydantic import BaseModel, Field
11
+
12
+ try:
13
+ from mecord import VideoReader
14
+ except ImportError:
15
+ VideoReader = None
16
+
17
+
18
class VideoSpec(BaseModel):
    """Metadata describing a decoded video stream."""

    # BUGFIX: was `media_type: str = Literal['video']`, which assigned the
    # typing.Literal object itself as the default value instead of the
    # string 'video'.
    media_type: Literal['video'] = 'video'
    height: int = Field(..., gt=0, description="video frame height")
    width: int = Field(..., gt=0, description="video frame width")
    num_frames: int = Field(..., gt=0, description="num frames")
    fps: float = Field(..., gt=0, description="average fps")

    # optional, help to accelerate video reading
    # BUGFIX: annotated Optional so the `None` defaults are valid values.
    key_indices: Optional[list[int]] = Field(None, description="key indices")
    frame_time_info: Optional[dict] = Field(None,
                                            description="frame time info")
28
+
29
+
30
class ImageInput(TypedDict):
    # Discriminator: always the literal string 'image'.
    type: Literal['image']
    # The image payload; `ensure_media_type` normalizes str/bytes inputs
    # into an RGB PIL image before use.
    image: Image.Image
33
+
34
+
35
class VideoChunkInput(TypedDict):
    # Discriminator: always the literal string 'video_chunk'.
    type: Literal['video_chunk']
    # Consecutive sampled frames belonging to one temporal chunk.
    video_chunk: List[Image.Image]
    # BUGFIX: TypedDict fields cannot carry default values (`= None` is
    # rejected at class-creation time on recent Python versions); the value
    # may simply be None.
    prompt: Optional[str]


# Union of all media payloads accepted by the processors.
MediaInput = ImageInput | VideoChunkInput
42
+
43
+
44
def get_video_meta(video_src: bytes | str | os.PathLike,
                   accurate: bool = True) -> VideoSpec:
    """Probe a video and return its metadata.

    Args:
        video_src: Raw video bytes, a filesystem path, or a
            'data:video/mp4;base64,...' data URL.
        accurate: Whether to fully initialize the reader for exact metadata.

    Returns:
        A VideoSpec with dimensions, frame count, fps and decode hints.
        (Return annotation fixed: originally declared ``-> dict``.)

    Raises:
        AssertionError: If the video reports non-positive dimensions,
            frame count, or fps.
    """
    if isinstance(video_src, os.PathLike):
        video_src = str(video_src)
    # if b64 string, decode to bytes
    if isinstance(video_src,
                  str) and video_src.startswith('data:video/mp4;base64,'):
        video_src = base64.b64decode(video_src.split(',')[1])
    video = VideoReader(video_src, auto_init=accurate, num_threads=1)
    assert video.num_frames > 0, "Invalid video format."
    assert video.original_width > 0 and video.original_height > 0, (
        "Invalid video format.")
    assert video.avg_fps > 0, "Invalid video format."
    return VideoSpec(media_type='video',
                     height=video.original_height,
                     width=video.original_width,
                     num_frames=video.num_frames,
                     fps=video.avg_fps,
                     key_indices=video.key_indices,
                     frame_time_info=video.frame_time_info)
65
+
66
+
67
def timestamp_as_str(timestamp: float,
                     timestamp_mode: str = "hh:mm:ss.fff") -> str:
    """Render a timestamp (in seconds) as a clock-style string.

    Supported modes: "hh:mm:ss.fff", "mm:ss.fff", "mm:ss".

    Raises:
        ValueError: For any other mode string.
    """
    clock = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    # Fractional part rendered as zero-padded milliseconds.
    millis = f".{int((timestamp % 1) * 1000):03d}"
    if timestamp_mode == "hh:mm:ss.fff":
        return clock.strftime("%H:%M:%S") + millis
    if timestamp_mode == "mm:ss.fff":
        return clock.strftime("%M:%S") + millis
    if timestamp_mode == "mm:ss":
        return clock.strftime("%M:%S")
    raise ValueError(f"Invalid timestamp mode: {timestamp_mode}")
83
+
84
+
85
+ def navit_resize_image(
86
+ width: int,
87
+ height: int,
88
+ patch_size: int,
89
+ merge_kernel_size: int,
90
+ in_patch_limit: int,
91
+ patch_limit_on_one_side: int,
92
+ fixed_output_tokens: int | None,
93
+ ):
94
+ # Apply the patch limits.
95
+ s1 = math.sqrt(
96
+ in_patch_limit /
97
+ (max(1.0, width // patch_size) * max(1.0, height // patch_size)))
98
+ s2 = patch_limit_on_one_side * patch_size / width
99
+ s3 = patch_limit_on_one_side * patch_size / height
100
+ scale = min(1.0, s1, s2, s3)
101
+ new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale))
102
+ new_w = min(new_w, patch_limit_on_one_side * patch_size)
103
+ new_h = min(new_h, patch_limit_on_one_side * patch_size)
104
+
105
+ # Calculate the padding to make the height and width divisible by the merge kernel size and patch size.
106
+ factor = merge_kernel_size * patch_size
107
+
108
+ pad_height = (factor - new_h % factor) % factor
109
+ pad_width = (factor - new_w % factor) % factor
110
+
111
+ if fixed_output_tokens is not None:
112
+ num_tokens = fixed_output_tokens
113
+ else:
114
+ # Calculate new dimensions after padding and patching
115
+ token_height = (new_h + pad_height) // factor
116
+ token_width = (new_w + pad_width) // factor
117
+
118
+ assert token_height * merge_kernel_size <= patch_limit_on_one_side, (
119
+ f"token_height {token_height} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
120
+ )
121
+ assert token_width * merge_kernel_size <= patch_limit_on_one_side, (
122
+ f"token_width {token_width} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
123
+ )
124
+
125
+ num_tokens = token_height * token_width
126
+ return {
127
+ "num_tokens": num_tokens,
128
+ "new_width": new_w,
129
+ "new_height": new_h,
130
+ "pad_width": pad_width,
131
+ "pad_height": pad_height,
132
+ "sampled_nframes": 1,
133
+ }
134
+
135
+
136
def navit_resize_video(
    width: int,
    height: int,
    nframes: int,
    avg_fps: float,
    sample_fps: float,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit_each_frame: int,
    patch_limit_on_one_side: int,
    in_patch_limit_total: int | None,
    max_num_frames_each_video: int | None,
    fixed_output_tokens_each_frame: int | None,
):
    """Compute per-frame resize geometry plus the number of sampled frames.

    Delegates the spatial computation to ``navit_resize_image`` and adds
    the temporal sampling result under 'sampled_nframes'.
    """
    # Cannot sample faster than the source frame rate.
    effective_fps = min(sample_fps, avg_fps)
    sampled_nframes = max(round(nframes * effective_fps / avg_fps), 1)
    if max_num_frames_each_video is not None:
        sampled_nframes = min(sampled_nframes, max_num_frames_each_video)

    # Spread a whole-video patch budget evenly across the sampled frames.
    if in_patch_limit_total is not None:
        per_frame_budget = round(in_patch_limit_total / sampled_nframes)
        in_patch_limit_each_frame = min(per_frame_budget,
                                        in_patch_limit_each_frame)

    ret = navit_resize_image(
        width,
        height,
        patch_size,
        merge_kernel_size,
        in_patch_limit_each_frame,
        patch_limit_on_one_side,
        fixed_output_tokens_each_frame,
    )
    ret["sampled_nframes"] = sampled_nframes
    return ret
172
+
173
+
174
+ def real_sample_fps_and_max_num_frames(
175
+ type_name: Literal["video", "video_chunk"],
176
+ sample_fps: float,
177
+ max_num_frames_each_video: int | None,
178
+ ) -> tuple[int, int | None]:
179
+ if type_name == "video":
180
+ return sample_fps, max_num_frames_each_video
181
+ elif type_name == "video_chunk":
182
+ max_num_frames_each_video = None
183
+ sample_fps = math.inf
184
+ return sample_fps, max_num_frames_each_video
185
+ else:
186
+ return math.inf, None
187
+
188
+
189
def _to_pil(data: str | bytes | Image.Image):
    """Decode *data* into an RGB PIL image.

    Accepts an already-loaded ``PIL.Image`` (converted and passed through),
    a ``data:`` URL or filesystem-path string, or raw encoded bytes.
    (The original annotation omitted ``Image.Image`` even though the first
    branch handles it.)

    Raises:
        ValueError: if *data* is none of the supported types.
    """
    if isinstance(data, Image.Image):
        return data.convert("RGB")
    if isinstance(data, str):
        if data.startswith("data:"):
            # data URL: everything after the first comma is the base64
            # payload (base64 alphabet never contains a comma).
            raw_base64 = data.split(",", 1)[1]
            return Image.open(io.BytesIO(
                base64.b64decode(raw_base64))).convert("RGB")
        # Otherwise treat the string as a local file path.
        return Image.open(data).convert("RGB")
    if isinstance(data, bytes):
        return Image.open(io.BytesIO(data)).convert("RGB")
    raise ValueError(f"Unsupported data type: {type(data)}")
204
+
205
+
206
def ensure_media_type(media: MediaInput) -> MediaInput:
    """Normalize *media* in place so image content is held as RGB PIL images.

    Returns the same mapping for call-chaining convenience.

    Raises:
        ValueError: for media types other than 'image' and 'video_chunk'.
    """
    kind = media['type']
    if kind == 'image':
        media['image'] = _to_pil(media['image'])
    elif kind == 'video_chunk':
        media['video_chunk'] = [_to_pil(f) for f in media['video_chunk']]
    else:
        raise ValueError(f"Unsupported media type: {media['type']}")
    return media
217
+
218
+
219
def _shrink_to_fit(
    image: Image.Image,
    resize_to: tuple[int, int],
    raise_error_for_ill_resize: bool,
) -> Image.Image | None:
    """Downscale *image* to fit inside ``resize_to`` (never upscale).

    Returns the resized PIL image, or ``None`` when the target collapses to
    a zero-sized edge and ``raise_error_for_ill_resize`` is False.
    """
    scale = min(resize_to[0] / image.width, resize_to[1] / image.height, 1.0)
    new_width = round(image.width * scale)
    new_height = round(image.height * scale)
    if new_width == 0 or new_height == 0:
        if raise_error_for_ill_resize:
            raise ValueError(
                f"Invalid resize to: {resize_to}, from image size: {image.size}"
            )
        return None
    return image.resize((new_width, new_height),
                        resample=Image.Resampling.BICUBIC)


def image_to_np(
    image: Image.Image,
    resize_to: tuple[int, int] | None = None,
    mode: str = "resize",
    raise_error_for_ill_resize: bool = True,
) -> np.ndarray:
    """Convert an image to a numpy array, optionally resizing it first.

    Args:
        image: The PIL image to convert.
        resize_to: Target ``(width, height)``; ``None`` skips resizing.
        mode: One of ``"resize"`` (plain bicubic resize),
            ``"rescale_and_pad_to_center"`` (aspect-preserving downscale,
            zero-padded symmetrically), or
            ``"rescale_and_pad_to_rightbottom"`` (same, padded only on the
            right/bottom edges).
        raise_error_for_ill_resize: Whether to raise when the rescale would
            collapse an edge to zero pixels; otherwise an all-zero array of
            the target size is returned.

    Returns:
        A numpy array of shape ``(height, width, 3)`` when resized with a
        pad mode, otherwise the direct array conversion of the image.

    Raises:
        ValueError: for an unknown *mode* or an ill-sized rescale.
    """
    assert isinstance(image, Image.Image), "image must be a PIL Image"
    if resize_to is not None:
        if mode == "resize":
            image = image.resize(resize_to, resample=Image.Resampling.BICUBIC)

        elif mode in ("rescale_and_pad_to_center",
                      "rescale_and_pad_to_rightbottom"):
            resized = _shrink_to_fit(image, resize_to,
                                     raise_error_for_ill_resize)
            if resized is None:
                # Ill-sized rescale with errors suppressed: black canvas.
                return np.zeros((resize_to[1], resize_to[0], 3),
                                dtype=np.uint8)

            pad_width = resize_to[0] - resized.width
            pad_height = resize_to[1] - resized.height
            if mode == "rescale_and_pad_to_center":
                padding_left = pad_width // 2
                padding_top = pad_height // 2
            else:
                # Right/bottom padding only: content anchored top-left.
                padding_left = 0
                padding_top = 0

            out = np.pad(
                np.asarray(resized),
                ((padding_top, pad_height - padding_top),
                 (padding_left, pad_width - padding_left), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            assert out.shape == (resize_to[1], resize_to[0], 3)
            return out

        else:
            raise ValueError(f"Invalid mode: {mode}")

    if isinstance(image, Image.Image):
        return np.asarray(image)
    return image
305
+
306
+
307
def navit_patchify(pixel_values: np.ndarray,
                   patch_size: int) -> dict[str, np.ndarray]:
    """Reshape pixel values into non-overlapping square patches.

    Args:
        pixel_values: np.ndarray, shape (t, h, w, c); h and w must be
            multiples of ``patch_size``.
        patch_size: side length of each square patch.

    Returns:
        dict[str, np.ndarray]
        - pixel_values: np.ndarray, shape (t * h//patch_size * w//patch_size, c, patch_size, patch_size)
        - grid_thw: np.ndarray, (t, h//patch_size, w//patch_size)

    Raises:
        ValueError: when h or w is not divisible by ``patch_size``.
    """
    T, H, W, C = pixel_values.shape
    assert C == 3, "pixel_values must have 3 channels"
    # Fail with a clear message instead of an opaque reshape error.
    if H % patch_size or W % patch_size:
        raise ValueError(
            f"Spatial dims ({H}, {W}) must be divisible by patch_size "
            f"{patch_size}")

    patches = pixel_values.reshape(T, H // patch_size, patch_size,
                                   W // patch_size, patch_size, C)
    # (T, H//patch_size, W//patch_size, C, patch_size, patch_size)
    patches = patches.transpose(0, 1, 3, 5, 2, 4)
    patches = patches.reshape(-1, C, patch_size, patch_size)
    grid_thw = np.array([T, H // patch_size, W // patch_size])
    return {"pixel_values": patches, "grid_thw": grid_thw}
330
+
331
+
332
def normalize(x: np.ndarray,
              mean,
              std_inv,
              pixels_dtype: np.dtype = np.float32) -> np.ndarray:
    """Scale uint8 pixels to [0, 1] and standardize.

    Args:
        x: Pixel array of shape (..., 3), dtype uint8, values in [0, 255].
        mean: Per-channel mean to subtract (in the [0, 1] domain).
        std_inv: Reciprocal of the per-channel std, multiplied in
            (multiplication is cheaper than division).
        pixels_dtype: Floating dtype of the result.

    Returns:
        Standardized array of shape (..., 3) with dtype ``pixels_dtype``.
    """
    # The division produces a fresh array, so the in-place ops below never
    # touch the caller's input.
    scaled = (x / 255.0).astype(pixels_dtype)
    scaled -= mean
    scaled *= std_inv
    return scaled
350
+
351
+
352
+ def _to_tensor(data, **kwargs):
353
+ import torch
354
+
355
+ if isinstance(data, np.ndarray):
356
+ return torch.from_numpy(data).to(**kwargs)
357
+ elif isinstance(data, torch.Tensor):
358
+ return data.to(**kwargs)
359
+ elif isinstance(data, list):
360
+ return [_to_tensor(item, **kwargs) for item in data]
361
+ elif isinstance(data, tuple):
362
+ return tuple(_to_tensor(item, **kwargs) for item in data)
363
+ elif isinstance(data, dict):
364
+ return {k: _to_tensor(v, **kwargs) for k, v in data.items()}
365
+ elif data is None:
366
+ return None
367
+ else:
368
+ raise ValueError(f"Unsupported data type: {type(data)}")
model-00002-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1823f86fd14504910f563a89c0b62f2df7d45c4252d8024a7c740464aa6f44f8
3
+ size 4997053752
model-00004-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e41df680614ad1cc82225761f94ba31c14d4a63ceb26696437107dd768e3d2e
3
+ size 4996135968
model-00006-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29c17f0419661491d71da404c0266542baa67914129154dfc6e5621f567e2724
3
+ size 4996136088
model-00007-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce99e6abe45df10cf0f7a4fcdd16ec556601f70392ceba9012bad92e150aebe8
3
+ size 4997468176
model-00008-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a604a4d701f1d8c841f8c886a97d1d1277d5812f2751e578325e04afa31d367c
3
+ size 4996136208
model-00013-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e596730f17df46ce04020e4c321055b2d805f464deea6618c8d94b80faf9c8d2
3
+ size 4997467856
model-00014-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62ba23c5fe2438002591080407abaeb6d76b6b9f9e58f44ca0c2eaaaad5a9d8d
3
+ size 4996136520
model-00015-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75315289ec29c9b9643a22e4b5110c4d6a9872fa5b7657af93581d054dc9ebb9
3
+ size 4997467856
model-00018-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d984a8f018a03f56a8ee3c81b07917ba08699299ece95e1187947a7e770d88c6
3
+ size 4996136520
model-00043-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c1c181dd0731624c186f5146195bd589b1bdf10e4b6f516bade6f4664302911
3
+ size 4996138360
model-00044-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09b857be543ca4da13bd4dd3dac8a4c68a93c00e3eb3f3a013a000a1c6ebb748
3
+ size 4997470656
model-00047-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87c603cbda78c02bd0257a357d66fe173727b4562b40062b79be06adc8843c19
3
+ size 4996138600
model-00048-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c5a8dc76a26ac0b95e72e4472b4cfa8c8da18c873169137e30c4e245ad3b2f1
3
+ size 4997470416
model-00054-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5308ec24d5165a49b2f1d25f9e3ebd6a79b27d03a82ac435aed2c18319f6a67
3
+ size 4997470200
model-00056-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b30acbb89f7a4b68e599a7b1f0f76d6ce64acd8aa1980892a2e62d5fa8e6342
3
+ size 4997470200
model-00058-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a9e78eabc6ea95abdb11cba942eace4844f896852801f8b22c3dadf5ffd3b98
3
+ size 4997470200
model-00060-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7df78bfde7612eb054707e79037a1fa3864155912c6c3fa308cf9940be3fcfe
3
+ size 4997470200
model-00067-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6205cb4b74ab4b39f453d8f2ce9644027f8b977785e9b4df4933b2d0846d8eb0
3
+ size 4996138936
model-00072-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5fcf835ab39c6dd30e67248e56c9cc9823b7464a0fb33cad8bf798facb36efd
3
+ size 4996137624
model-00085-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ed5e051983453abc5f3c36aff92ecfa6dc418af457ffb7c7307753b7550b1b5
3
+ size 4997470552
model-00088-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:958215e41271ef632349bd8ca3b44987e0e23df3e8e6c680fdd115bc1e0b2436
3
+ size 4996138696
model-00092-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d395a0a814cff109096da0d74514962d8aff75f54090b961b367e9101c26f32
3
+ size 4996138936
model-00093-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62985d59998c1288f36e5ea1c4302060b96baa529c76920484fc88eacb907073
3
+ size 4997470200
model-00099-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db9a7d0d31c5e55a63bc64170f15cd594cc3f85cff16bde8763a4d6cfcecfb23
3
+ size 4997470200
model-00102-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:122ad83ebde781897522862950e4bcc74e80e0dfa49cff36cacb37fa74dda263
3
+ size 4996138936
model-00107-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d375edd624a38537e942b42eb20d5b895b58fbce2578c76f576d7d4ce52621da
3
+ size 4997470200
model-00108-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b14f35510e4be932d79c60a9c5431e631382b7dfbd6102475721f279c15ef7a
3
+ size 4996138936
model-00110-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb2c88c490f885ec929091fb23602c52e58d5db7bbcd7e06b42726eb9d1622bf
3
+ size 4913062912
model-00118-of-00119.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:874e706179dadd3f530ba932dca28e7fdc9e53c76c48f9b4649eec3ba58cf0e5
3
+ size 3919879576
model.safetensors.index.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e8723756225d00dea4ea2b89b4d59bc262f94118a7d8d2e2f5d9e04071fd265
3
+ size 30783853
modeling_deepseek.py ADDED
@@ -0,0 +1,1808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch DeepSeek model."""
21
+ import math
22
+ import warnings
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.distributed as dist
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+ from transformers.modeling_attn_mask_utils import \
35
+ _prepare_4d_causal_attention_mask
36
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
37
+ CausalLMOutputWithPast,
38
+ SequenceClassifierOutputWithPast)
39
+ from transformers.modeling_utils import PreTrainedModel
40
+ from transformers.pytorch_utils import (ALL_LAYERNORM_LAYERS,
41
+ is_torch_greater_or_equal_than_1_13)
42
+ from transformers.utils import (add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ is_flash_attn_2_available,
45
+ is_flash_attn_greater_or_equal_2_10, logging,
46
+ replace_return_docstrings)
47
+ from transformers.utils.import_utils import is_torch_fx_available
48
+
49
+ from .configuration_deepseek import DeepseekV3Config
50
+
51
+ if is_flash_attn_2_available():
52
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
53
+ from flash_attn.bert_padding import pad_input # noqa
54
+ from flash_attn.bert_padding import index_first_axis, unpad_input
55
+
56
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
57
+ # It means that the function will not be traced through and simply appear as a node in the graph.
58
+ if is_torch_fx_available():
59
+ if not is_torch_greater_or_equal_than_1_13:
60
+ import torch.fx
61
+
62
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(
63
+ _prepare_4d_causal_attention_mask)
64
+
65
+ logger = logging.get_logger(__name__)
66
+
67
+ _CONFIG_FOR_DOC = "DeepseekV3Config"
68
+
69
+
70
+ def _get_unpad_data(attention_mask):
71
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
72
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
73
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
74
+ cu_seqlens = F.pad(
75
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
76
+ return (
77
+ indices,
78
+ cu_seqlens,
79
+ max_seqlen_in_batch,
80
+ )
81
+
82
+
83
+ # code modified from transformers 4.48.3 to amend breaks in newer transformers versions
84
def get_usable_length(past_key_value,
                      new_seq_length: int,
                      layer_idx: Optional[int] = 0) -> int:
    """Return how many cached positions remain usable for this layer.

    Mirrors the ``Cache.get_usable_length`` helper from transformers 4.48.3:
    when appending ``new_seq_length`` tokens would overflow a bounded cache,
    the usable prefix shrinks to what still fits.
    """
    cache_capacity = past_key_value.get_max_cache_shape()
    cached_length = past_key_value.get_seq_length(layer_idx)
    is_bounded = cache_capacity is not None and cache_capacity > 0
    if is_bounded and cached_length + new_seq_length > cache_capacity:
        return cache_capacity - new_seq_length
    return cached_length
92
+
93
+
94
class DeepseekV3RMSNorm(nn.Module):
    """Root-mean-square layer norm; equivalent to T5LayerNorm."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize in float32 for numerical stability, then cast back to
        # the input dtype before applying the learned scale.
        orig_dtype = hidden_states.dtype
        states_f32 = hidden_states.to(torch.float32)
        mean_square = states_f32.pow(2).mean(-1, keepdim=True)
        normalized = states_f32 * torch.rsqrt(mean_square +
                                              self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)
111
+
112
+
113
+ ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)
114
+
115
+
116
class DeepseekV3RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with a lazily rebuilt cos/sin cache."""

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies: base^(-2i/dim) over even dimension indices.
        inv_freq = 1.0 / (self.base**(
            torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # NOTE(review): deliberately reset to None right after priming the
        # cache above, so the first forward() rebuilds cos/sin on the actual
        # input device/dtype — presumably a workaround for device mismatches;
        # confirm before changing.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Precompute cos/sin tables for positions [0, seq_len).
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached",
                             emb.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             emb.sin().to(dtype),
                             persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Rebuild when the cache was invalidated (None) or is too short.
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len,
                                    device=x.device,
                                    dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
167
+
168
+
169
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3
170
class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """Rotary embedding with linear position scaling (positions divided by
    ``scaling_factor``). Credits to the Reddit user /u/kaiokendev."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached,
                                 device=device,
                                 dtype=self.inv_freq.dtype)
        # Linear interpolation: compress the position index space.
        positions = positions / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Duplicate the halves to match the rotate_half convention (same
        # calculation as the paper under a different permutation).
        duplicated = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached",
                             duplicated.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             duplicated.sin().to(dtype),
                             persistent=False)
200
+
201
+
202
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3
203
class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which calls
        # _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        # When exceeding the trained context length, grow the RoPE base
        # (NTK-aware scaling) and re-register inv_freq; within the trained
        # length the base class's inv_freq is reused as-is.
        if seq_len > self.max_position_embeddings:
            base = self.base * ((self.scaling_factor * seq_len /
                                 self.max_position_embeddings) -
                                (self.scaling_factor - 1))**(self.dim /
                                                             (self.dim - 2))
            inv_freq = 1.0 / (base**(
                torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached",
                             emb.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             emb.sin().to(dtype),
                             persistent=False)
242
+
243
+
244
+ # Inverse dim formula to find dim based on number of rotations
245
def yarn_find_correction_dim(num_rotations,
                             dim,
                             base=10000,
                             max_position_embeddings=2048):
    """Invert the RoPE wavelength formula: return the (fractional) dimension
    index whose frequency completes ``num_rotations`` rotations over the
    original context length."""
    wavelength_ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    return dim * math.log(wavelength_ratio) / (2 * math.log(base))
252
+
253
+
254
+ # Find dim range bounds based on rotations
255
def yarn_find_correction_range(low_rot,
                               high_rot,
                               dim,
                               base=10000,
                               max_position_embeddings=2048):
    """Return the (low, high) dimension band over which YaRN blends the
    interpolated and extrapolated frequencies, clamped to valid indices."""
    lower = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    upper = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    # Clamp values just in case the formula lands outside [0, dim-1].
    return max(lower, 0), min(upper, dim - 1)
265
+
266
+
267
def yarn_get_mscale(scale=1, mscale=1):
    """Attention-magnitude ("mscale") correction used by YaRN.

    Identity (1.0) when the context is not extended (scale <= 1); otherwise
    grows logarithmically with the extension factor.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
271
+
272
+
273
def yarn_linear_ramp_mask(min, max, dim):
    """Linear 0→1 ramp across dimension indices ``min``..``max``, clamped.

    NOTE: the parameter names shadow the ``min``/``max`` builtins; they are
    kept as-is for backward compatibility with keyword callers.
    """
    if min == max:
        max += 0.001  # Prevent singularity (zero-width ramp below).

    ramp = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    return torch.clamp(ramp, 0, 1)
280
+
281
+
282
class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with YaRN context-extension
    scaling (blended interpolation/extrapolation plus magnitude correction).
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # All YaRN hyperparameters must be set before super().__init__,
        # which calls _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        # Extrapolated (original base) and interpolated (scaled) inverse
        # frequencies, blended per-dimension below.
        freq_extra = 1.0 / (self.base**(
            torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        freq_inter = 1.0 / (self.scaling_factor * self.base**(
            torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))

        # Dimension band over which the ramp transitions between the two.
        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32)
        inv_freq = freq_inter * (1 -
                                 inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(seq_len, device=device, dtype=torch.float32)

        freqs = torch.outer(t, inv_freq)

        # Attention-magnitude correction ratio, folded directly into the
        # cached cos/sin tables.
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale) /
            yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype),
                             persistent=False)
340
+
341
+
342
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
343
def rotate_half(x):
    """Rotates half the hidden dims of the input: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
348
+
349
+
350
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
351
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    # De-interleave the head dim: values stored pairwise
    # (a0, b0, a1, b1, ...) are regrouped to (a0, a1, ..., b0, b1, ...)
    # so they line up with the half-duplicated cos/sin tables and the
    # rotate_half convention below.
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
384
+
385
+
386
class DeepseekV3MLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: down(act(gate(x)) * up(x))."""

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        # Explicit sizes override the config; used by shared-expert MLPs.
        self.hidden_size = (hidden_size
                            if hidden_size is not None else config.hidden_size)
        self.intermediate_size = (intermediate_size
                                  if intermediate_size is not None else
                                  config.intermediate_size)

        self.gate_proj = nn.Linear(self.hidden_size,
                                   self.intermediate_size,
                                   bias=False)
        self.up_proj = nn.Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False)
        self.down_proj = nn.Linear(self.intermediate_size,
                                   self.hidden_size,
                                   bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
410
+
411
+
412
class MoEGate(nn.Module):
    """Router for the MoE layers.

    Scores every routed expert per token with a learned linear gate and
    selects `top_k` experts via the grouped "noaux_tc" method: experts are
    partitioned into `n_group` groups, the best `topk_group` groups are
    kept, and the final top-k is taken only within those groups, using a
    per-expert score-correction bias for the *selection* step (the returned
    weights use the uncorrected scores).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        # Gating weight: [n_routed_experts, hidden_size].
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim)))
        if self.topk_method == "noaux_tc":
            # Bias added to scores only when *choosing* experts, not when
            # weighting their outputs.
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((self.n_routed_experts)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # NOTE(review): e_score_correction_bias is left uninitialized here;
        # presumably it is always loaded from a checkpoint — confirm.

    def forward(self, hidden_states):
        """Return (topk_idx, topk_weight), each of shape [bsz*seq_len, top_k]."""
        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        # Gate logits are computed in float32 for numerical stability.
        logits = F.linear(hidden_states.type(torch.float32),
                          self.weight.type(torch.float32), None)
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        ### select top-k experts
        if self.topk_method == "noaux_tc":
            # Inference-only selection path (no auxiliary load-balancing loss).
            assert not self.training
            scores_for_choice = scores.view(
                bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
            # Group score = sum of the 2 best corrected expert scores per group.
            group_scores = (scores_for_choice.view(
                bsz * seq_len, self.n_group,
                -1).topk(2, dim=-1)[0].sum(dim=-1))  # [n, n_group]
            group_idx = torch.topk(group_scores,
                                   k=self.topk_group,
                                   dim=-1,
                                   sorted=False)[1]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            # Expand the group mask to a per-expert mask.
            score_mask = (group_mask.unsqueeze(-1).expand(
                bsz * seq_len, self.n_group,
                self.n_routed_experts // self.n_group).reshape(
                    bsz * seq_len, -1))  # [n, e]
            # Zero out experts in the dropped groups before the final top-k.
            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(),
                                                       0.0)  # [n, e]
            _, topk_idx = torch.topk(tmp_scores,
                                     k=self.top_k,
                                     dim=-1,
                                     sorted=False)
            # Weights come from the *uncorrected* sigmoid scores.
            topk_weight = scores.gather(1, topk_idx)
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor

        return topk_idx, topk_weight
491
+
492
+
493
class DeepseekV3MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Routed experts are selected per-token by `MoEGate`; optional shared
    experts (a single wider MLP) are always applied and added to the routed
    output. Supports expert parallelism (`config.ep_size` > 1) where each
    rank instantiates only its own slice of the expert list and tokens are
    exchanged with `torch.distributed.all_to_all`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        if hasattr(config, "ep_size") and config.ep_size > 1:
            assert config.ep_size == dist.get_world_size()
            self.ep_size = config.ep_size
            self.experts_per_rank = config.n_routed_experts // config.ep_size
            self.ep_rank = dist.get_rank()
            # Only materialize the experts owned by this rank; remote slots
            # are None placeholders so global indices still line up.
            self.experts = nn.ModuleList([
                (DeepseekV3MLP(config,
                               intermediate_size=config.moe_intermediate_size)
                 if i >= self.ep_rank * self.experts_per_rank
                 and i < (self.ep_rank + 1) * self.experts_per_rank else None)
                for i in range(config.n_routed_experts)
            ])
        else:
            self.ep_size = 1
            self.experts_per_rank = config.n_routed_experts
            self.ep_rank = 0
            self.experts = nn.ModuleList([
                DeepseekV3MLP(config,
                              intermediate_size=config.moe_intermediate_size)
                for i in range(config.n_routed_experts)
            ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV3MLP(
                config=config, intermediate_size=intermediate_size)

    def forward(self, hidden_states):
        # NOTE(review): inference-only copy — there is no training branch, so
        # forward implicitly returns None when self.training is True.
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)  # NOTE(review): unused here
        if not self.training:
            y = self.moe_infer(hidden_states, topk_idx,
                               topk_weight).view(*orig_shape)
            if self.config.n_shared_experts is not None:
                # Shared experts see the raw (un-routed) input.
                y = y + self.shared_experts(identity)
            return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Dispatch flattened tokens `x` [n_tokens, h] to their top-k experts
        and return the weighted sum of expert outputs, same shape as `x`."""
        # One-hot count of how many tokens each expert receives.
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort (token, expert) pairs by expert id so each expert gets a
        # contiguous slice; idxs maps sorted position -> flat pair index.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        sorted_tokens_shape = sorted_tokens.shape
        if self.ep_size > 1:
            # Exchange per-expert token counts, then the tokens themselves,
            # so each rank receives exactly the tokens for its local experts.
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size,
                                                        -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
            output_splits = (tokens_per_expert_group.view(
                self.ep_size, -1).sum(1).cpu().numpy().tolist())
            gathered_tokens = sorted_tokens.new_empty(
                tokens_per_expert_group.sum(dim=0).cpu().item(),
                sorted_tokens.shape[1])
            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
            dist.all_to_all(
                list(gathered_tokens.split(output_splits)),
                list(sorted_tokens.split(input_split_sizes)),
            )
            tokens_per_expert_post_gather = tokens_per_expert_group.view(
                self.ep_size, self.experts_per_rank).sum(dim=0)
            # Re-sort gathered tokens by local expert index.
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0], ),
                                    dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s:s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        # Run each (local) expert over its contiguous slice of tokens.
        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs,
                         dim=0) if len(outputs) else sorted_tokens.new_empty(0)
        if self.ep_size > 1:
            # Undo the local expert sort, then all_to_all results back to the
            # ranks that originally owned the tokens.
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
            dist.all_to_all(
                list(gathered_tokens.split(input_split_sizes)),
                list(new_x.split(output_splits)),
            )
            outs = gathered_tokens

        # Undo the expert sort and combine the k expert outputs per token
        # with their gate weights.
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (new_x.view(
            *topk_ids.shape, -1).type(topk_weight.dtype).mul_(
                topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
        return final_out
610
+
611
+
612
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    if n_rep == 1:
        # Nothing to repeat; avoid an unnecessary expand/reshape.
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    # Insert a repeat axis after the head dimension, broadcast it to n_rep,
    # then fold it into the head dimension.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_kv_heads, n_rep,
                                                 seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
627
+
628
+
629
# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3
class DeepseekV3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Multi-head Latent Attention (MLA): queries optionally go through a
    # low-rank bottleneck (q_lora_rank); keys/values are reconstructed from a
    # compressed latent (kv_lora_rank) plus a small shared rotary part
    # (qk_rope_head_dim) that carries the position information.

    def __init__(self,
                 config: DeepseekV3Config,
                 layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class.")

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Per-head query/key width = non-rotary part + rotary part.
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            # Full-rank query projection.
            self.q_proj = nn.Linear(self.hidden_size,
                                    self.num_heads * self.q_head_dim,
                                    bias=False)
        else:
            # Low-rank query path: down-project, RMS-normalize, up-project.
            self.q_a_proj = nn.Linear(self.hidden_size,
                                      config.q_lora_rank,
                                      bias=config.attention_bias)
            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank,
                                      self.num_heads * self.q_head_dim,
                                      bias=False)

        # Produces the compressed KV latent plus a single shared rotary key.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
        # Expands the latent into per-head non-rotary keys and values.
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads *
            (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim**(-0.5)
        if self.config.rope_scaling is not None:
            # YaRN-style mscale correction applied to the softmax scale.
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        # Select the rotary-embedding implementation from config.rope_scaling:
        # None -> vanilla, else "linear" / "dynamic" NTK / "yarn".
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekV3RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Forward only the YaRN knobs that are actually present.
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ] if key in self.config.rope_scaling
                }
                self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # [bsz, seq, heads*v_dim] -> [bsz, heads, seq, v_dim].
        return (tensor.view(bsz, seq_len, self.num_heads,
                            self.v_head_dim).transpose(1, 2).contiguous())

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Eager (matmul + softmax) attention. Returns
        (attn_output, attn_weights or None, past_key_value)."""
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        # Query projection (full-rank or low-rank path).
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        # Split each query head into its non-rotary and rotary parts.
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Compressed KV latent + one shared rotary key per position.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
            bsz, q_len, self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index.")
            kv_seq_len += get_usable_length(past_key_value, kv_seq_len,
                                            self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Rotary embedding is applied only to the rotary halves.
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Reassemble full query/key heads: [nope | rotary]. The shared k_pe
        # (1 head) is broadcast into every key head by the assignment below.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                      self.q_head_dim)
        query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                    self.q_head_dim)
        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) *
            self.softmax_scale)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}")
        # NOTE(review): this assert makes the following None-check redundant —
        # an attention mask is required by this implementation.
        assert attention_mask is not None
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(
                                                 query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.attention_dropout,
                                             training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len,
                                          self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
854
+
855
+
856
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3
class DeepseekV3FlashAttention2(DeepseekV3Attention):
    """
    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        # DeepseekV3FlashAttention2 attention does not support output_attentions
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        # Query projection (full-rank or low-rank path), same as eager.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
            bsz, q_len, self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        kv_seq_len = value_states.shape[-2]

        # NOTE(review): duplicate of the assignment above; harmless.
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += get_usable_length(past_key_value, kv_seq_len,
                                            self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Reassemble full heads: [nope | rotary]; the single shared k_pe head
        # is broadcast across all key heads by the assignment.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                      self.q_head_dim)
        query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                    self.q_head_dim)
        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

        # flash-attn needs q/k/v head dims to match; pad values up to
        # q_head_dim and slice the padding off the output below.
        if self.q_head_dim != self.v_head_dim:
            value_states = F.pad(value_states,
                                 [0, self.q_head_dim - self.v_head_dim])

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DeepseekV3RMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = (self.q_proj.weight.dtype if self.q_lora_rank
                                is None else self.q_a_proj.weight.dtype)

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}.")

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            softmax_scale=self.softmax_scale,
        )
        # Drop the value padding added above (if any).
        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, :self.v_head_dim]

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads *
                                          self.v_head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            # Varlen path: strip padding, run flash-attn on packed sequences,
            # then scatter outputs back to the padded layout.
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(query_states, key_states, value_states,
                                 attention_mask, query_length)

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
                                    query_length)
        else:
            # No padding anywhere: use the plain dense kernel.
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
                    query_length):
        # Remove padded positions and build the cumulative-seqlen metadata
        # required by flash_attn_varlen_func.
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
            attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                              head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                                head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            # Prefill: queries share the key layout exactly.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
                                    head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decode: one query per batch element.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
1126
+
1127
+
1128
# Maps config._attn_implementation to the attention module the decoder
# layers instantiate ("eager" math attention vs. flash-attn 2 kernels).
ATTENTION_CLASSES = {
    "eager": DeepseekV3Attention,
    "flash_attention_2": DeepseekV3FlashAttention2,
}
1132
+
1133
+
1134
class DeepseekV3DecoderLayer(nn.Module):
    """One transformer decoder layer: pre-norm self-attention followed by a
    pre-norm feed-forward block (dense MLP, or MoE for the layers selected
    by `first_k_dense_replace` / `moe_layer_freq`), each with a residual
    connection."""

    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
            config=config, layer_idx=layer_idx)

        # Layers before `first_k_dense_replace` (and any layer not on the
        # `moe_layer_freq` stride) stay dense; the rest use MoE.
        is_moe_layer = (config.n_routed_experts is not None
                        and layer_idx >= config.first_k_dense_replace
                        and layer_idx % config.moe_layer_freq == 0)
        self.mlp = DeepseekV3MoE(config) if is_moe_layer else DeepseekV3MLP(
            config)
        self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
                                                 torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # Self-attention sub-block (pre-norm + residual).
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (pre-norm + residual).
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + ffn_out

        outputs = (hidden_states, )
        if output_attentions:
            outputs += (self_attn_weights, )
        if use_cache:
            outputs += (present_key_value, )
        return outputs
1213
+
1214
+
1215
# Shared docstring fragment injected into the model classes via the
# `@add_start_docstrings` decorator below.
DeepseekV3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DeepseekV3Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
1230
+
1231
+
1232
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3PreTrainedModel(PreTrainedModel):
    """Base class wiring DeepseekV3 models into the transformers infrastructure
    (auto classes, device sharding, flash-attention and Cache support)."""

    # Config type used by `from_pretrained` / auto classes.
    config_class = DeepseekV3Config
    # Attribute name that holds the backbone when a head wraps it.
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep each decoder layer on a single device when auto-sharding with accelerate.
    _no_split_modules = ["DeepseekV3DecoderLayer"]
    # KV caches manage their own device placement; accelerate must not move them.
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range);
        zero out biases and the padding-token embedding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
1255
+
1256
+
1257
# Forward-signature documentation injected into each model's `forward` via
# `add_start_docstrings_to_model_forward`.
DeepseekV3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1325
+
1326
+
1327
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]

    Args:
        config: DeepseekV3Config
    """

    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
                                         self.padding_idx)
        self.layers = nn.ModuleList([
            DeepseekV3DecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        # Flash-attention consumes the plain 2-D padding mask; all other
        # backends receive an expanded 4-D causal mask (see `forward`).
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = DeepseekV3RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)

        # NOTE(review): this flag is declared but the layer loop below never
        # branches on it — gradient checkpointing appears unimplemented here;
        # confirm before relying on `gradient_checkpointing_enable()`.
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding table."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding table (used by `resize_token_embeddings`)."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the full decoder stack and return the final hidden states.

        See `DeepseekV3_INPUTS_DOCSTRING` for argument semantics. Returns a
        `BaseModelOutputWithPast` (or tuple when `return_dict=False`).
        """
        # Fall back to config defaults for any flag the caller left unset.
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0
        if use_cache:
            # Normalize a legacy tuple-of-tuples cache into a `DynamicCache`;
            # the original format is restored before returning.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(
                    past_key_values)
            # `get_usable_length` is a module-level helper defined earlier in
            # this file; it yields the number of already-cached positions.
            past_key_values_length = get_usable_length(past_key_values,
                                                       seq_length)

        if position_ids is None:
            # Continue position numbering after the cached prefix.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers; drop it entirely when no
            # position is actually masked (flash-attn handles causality itself).
            attention_mask = (attention_mask if
                              (attention_mask is not None
                               and 0 in attention_mask) else None)
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                # Record the hidden state *entering* each layer.
                all_hidden_states += (hidden_states, )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache sits after the optional attention weights in the tuple.
                next_decoder_cache = layer_outputs[
                    2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = None
        if use_cache:
            # Hand the cache back in the same format the caller supplied.
            next_cache = (next_decoder_cache.to_legacy_cache()
                          if use_legacy_cache else next_decoder_cache)
        if not return_dict:
            return tuple(
                v for v in
                [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1487
+
1488
+
1489
class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
    """DeepseekV3 backbone with a language-modeling head (`lm_head`) on top."""

    # `lm_head.weight` may be tied to the input embeddings.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DeepseekV3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding table."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding table."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head (output projection)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (output projection)."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the decoder backbone."""
        self.model = decoder

    def get_decoder(self):
        """Return the decoder backbone."""
        return self.model

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast,
                               config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in
                `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100`
                are ignored (masked), the loss is only computed for the tokens with labels in
                `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

        >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Fall back to config defaults for any flag the caller left unset.
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        # Compute the loss (and return logits) in fp32 for numerical stability.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Trim `input_ids`/`attention_mask` against the cache and build the
        model kwargs for one `generate()` step."""
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                # `seen_tokens` may not exist in some transformers versions;
                # use getattr for safe access.
                past_length = getattr(past_key_values, 'seen_tokens',
                                      cache_length)
                # NOTE(review): `Cache.get_max_length` is deprecated/removed in
                # recent transformers (replaced by `get_max_cache_shape`) —
                # confirm the targeted transformers version still provides it.
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: length read from the key tensor's seq dim.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (attention_mask is not None
                    and attention_mask.shape[1] > input_ids.shape[1]):
                input_ids = input_ids[:, -(attention_mask.shape[1] -
                                           past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (max_cache_length is not None and attention_mask is not None
                    and cache_length + input_ids.shape[1] > max_cache_length):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        })
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder a legacy-format cache along the batch dim for beam search."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past), )
        return reordered_past
1680
+
1681
+
1682
@add_start_docstrings(
    """
    The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).

    [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeepseekV3Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding table."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding table."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Without a pad token we cannot locate the last real token per row.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the first pad token minus one = last real token.
                # When a row has no pad token, argmax over all-False is 0, so
                # the index becomes -1 and the last position is used.
                sequence_lengths = (torch.eq(
                    input_ids, self.config.pad_token_id).int().argmax(-1) -
                                    1).to(logits.device)
            else:
                # With `inputs_embeds` padding cannot be detected; use last token.
                sequence_lengths = -1

        # Pool: one logit vector per row, taken at its last real token.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device),
                               sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from label dtype / num_labels.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long
                                              or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels),
                                labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits, ) + transformer_outputs[1:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
modeling_kimi_k25.py ADDED
@@ -0,0 +1,1251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025-2026 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for Kimi-K2.5.
5
+ #
6
+ # Licensing Information:
7
+ # - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
8
+ # - Other parts of the code are licensed under the MIT License.
9
+ #
10
+ # Apache License, Version 2.0:
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ #
23
+ # MIT License:
24
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
25
+ # of this software and associated documentation files (the "Software"), to deal
26
+ # in the Software without restriction, including without limitation the rights
27
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
28
+ # copies of the Software, and to permit persons to whom the Software is
29
+ # furnished to do so, subject to the following conditions:
30
+ #
31
+ # The above copyright notice and this permission notice shall be included in all
32
+ # copies or substantial portions of the Software.
33
+ #
34
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
35
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
36
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
37
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
38
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
39
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40
+ # SOFTWARE.
41
+ import math
42
+ from collections.abc import Sequence
43
+ from copy import deepcopy
44
+ from typing import Optional
45
+
46
+ import numpy as np
47
+ import torch
48
+ import torch.nn as nn
49
+ import torch.nn.functional as F
50
+ from transformers import activations
51
+
52
+ try:
53
+ from transformers.activations import PytorchGELUTanh
54
+ except ImportError:
55
+ from transformers.activations import GELUTanh
56
+ activations.PytorchGELUTanh = GELUTanh
57
+ PytorchGELUTanh = GELUTanh
58
+ from transformers.activations import PytorchGELUTanh
59
+ from transformers.cache_utils import Cache
60
+ from transformers.configuration_utils import PretrainedConfig
61
+ from transformers.modeling_utils import PreTrainedModel
62
+ from transformers.models.llava.modeling_llava import \
63
+ LlavaCausalLMOutputWithPast
64
+ from transformers.utils import is_flash_attn_2_available
65
+
66
+ from .configuration_kimi_k25 import KimiK25Config
67
+ from .modeling_deepseek import DeepseekV3ForCausalLM
68
+
69
+ # Flash attention imports
70
+ if is_flash_attn_2_available():
71
+ from flash_attn import flash_attn_varlen_func
72
+ else:
73
+ flash_attn_varlen_func = None
74
+
75
+
76
def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: torch.Tensor | None = None,
    k_cu_seqlens: torch.Tensor | None = None,
    max_seqlen_q: int | None = None,
    max_seqlen_k: int | None = None,
    deterministic: bool = False,
):
    """Non-causal multi-head attention backed by FlashAttention-2 varlen.

    Args:
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim)
            (or batched (batch_size, seqlen, num_heads, head_dim)).
        q_cu_seqlens: cumulative sequence lengths of q; first element 0,
            last element q.shape[0].
        k_cu_seqlens: cumulative sequence lengths of k; first element 0,
            last element k.shape[0].
        max_seqlen_q / max_seqlen_k: longest individual sequence in q / k.
        deterministic: request a deterministic backward pass from flash-attn.

    Returns:
        Tensor of shape (tot_seqlens, num_heads * head_dim) — the head and
        head-dim axes are merged into a single feature dimension.
    """
    result = flash_attn_varlen_func(
        q,
        k,
        v,
        q_cu_seqlens,
        k_cu_seqlens,
        max_seqlen_q,
        max_seqlen_k,
        causal=False,
        deterministic=deterministic,
    )
    # Some flash-attn versions return (output, softmax_lse, ...); keep output only.
    output = result[0] if isinstance(result, tuple) else result
    # Merge (num_heads, head_dim) into one feature dimension.
    return output.flatten(start_dim=-2)
117
+
118
+
119
def eager_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Pure-PyTorch fallback for packed, block-diagonal (varlen) attention.

    Args:
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim).
        q_cu_seqlens: cumulative sequence lengths; tokens may only attend to
            tokens of their own segment, mirroring
            `flash_attn_varlen_func(..., causal=False)`.
        k_cu_seqlens: accepted for interface parity; the block-diagonal mask
            is built from `q_cu_seqlens` (assumes q and k segments coincide).

    Returns:
        Tensor of shape (tot_seqlens, num_heads * head_dim).
    """
    seq_length = q.shape[0]
    # Block-diagonal mask: True where attention is allowed.
    attention_mask = torch.zeros([1, seq_length, seq_length],
                                 device=q.device,
                                 dtype=torch.bool)
    for i in range(1, len(q_cu_seqlens)):
        attention_mask[
            ...,
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
        ] = True
    # (seq, heads, dim) -> (heads, seq, dim) for batched matmul.
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)

    attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    # BUGFIX: the boolean mask must *suppress* cross-segment scores. The
    # previous `attn_weight += attention_mask` promoted the bool mask to
    # float, adding +1.0 to allowed positions while leaving disallowed ones
    # untouched — every token could still attend across segment boundaries.
    # Fill disallowed positions with the dtype minimum before the softmax so
    # their post-softmax weight is (effectively) zero.
    attn_weight = attn_weight.masked_fill(
        ~attention_mask, torch.finfo(attn_weight.dtype).min)
    attn_weight = torch.softmax(attn_weight, dim=-1,
                                dtype=torch.float32).to(q.dtype)

    attn_output = attn_weight @ v
    attn_output = attn_output.transpose(0, 1)
    # Merge (num_heads, head_dim) into one feature dimension.
    attn_output = attn_output.reshape(seq_length, -1)
    return attn_output
150
+
151
+
152
# Dispatch table mapping the configured `attn_implementation` string to the
# packed (varlen) attention kernel defined above.
VL_VISION_ATTENTION_FUNCTIONS = {
    "flash_attention_2": multihead_attention,
    "eager": eager_attention,
}
156
+
157
+
158
def _apply_rope_input_validation(x, freqs_cis):
    """Check that ``x`` (..., num_heads, head_dim) is compatible with the
    precomputed complex rotary table ``freqs_cis`` (..., head_dim/2)."""
    x_shape, cis_shape = x.shape, freqs_cis.shape
    # x carries one extra (num_heads) axis relative to the rope table.
    assert x.ndim == freqs_cis.ndim + 1, (x_shape, cis_shape)
    # Leading (position) dimensions must line up one-to-one.
    assert x_shape[:-2] == cis_shape[:-1], (x_shape, cis_shape)
    # Each complex entry rotates one pair of real channels.
    assert x_shape[-1] == 2 * cis_shape[-1], (x_shape, cis_shape)
    assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
163
+
164
+
165
def get_rope_shape_decorate(func):
    """Wrap a torch.compile'd resize function with a one-time warm-up.

    For each (requires_grad, grad_enabled, interpolation_mode) combination the
    very first call is preceded by an extra invocation at a fixed (64, 64)
    shape. This primes one stable compiler specialization per mode before
    dynamic shapes start flowing through.
    """
    seen_keys = set()

    def wrapper(org, interpolation_mode, shape):
        cache_key = (org.requires_grad, torch.is_grad_enabled(),
                     interpolation_mode)
        if cache_key not in seen_keys:
            seen_keys.add(cache_key)
            # Warm-up call; the result is discarded.
            func(org, interpolation_mode, shape=(64, 64))
        return func(org, interpolation_mode, shape)

    return wrapper
176
+
177
+
178
@get_rope_shape_decorate
@torch.compile(dynamic=True)
def get_rope_shape(org, interpolation_mode, shape):
    """Resize a learned (H, W, C) positional-embedding grid to ``shape``.

    The grid is moved to channels-first for F.interpolate, resized with the
    requested mode (e.g. 'bicubic'), then flattened back to (h*w, C) rows.
    Compiled with dynamic shapes; the decorator warms up the compile cache.
    """
    return (F.interpolate(
        org.permute((2, 0, 1)).unsqueeze(0),
        size=shape,
        mode=interpolation_mode,
    ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))
186
+
187
+
188
def apply_rope(xq: torch.Tensor, xk: torch.Tensor,
               freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply (2D) rotary position embedding to query and key tensors.

    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads, head_dim)
        xk: key, tensor of shape (..., num_heads, head_dim)
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
    Returns:
        xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
    """

    def _validate(x: torch.Tensor) -> None:
        # Same checks as the module-level validation helper, kept local so
        # this function is self-contained.
        assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
        assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape,
                                                      freqs_cis.shape)
        assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape,
                                                        freqs_cis.shape)
        assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype

    _validate(xq)
    _validate(xk)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2
    # Pair adjacent real channels into complex numbers: ..., num_heads, head_dim/2
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    # BUG FIX: this view previously used xq.shape, which breaks (or silently
    # mis-shapes) whenever xk has a different number of heads than xq
    # (e.g. grouped-query attention).
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    # Complex multiply == rotate each channel pair by its precomputed angle.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)
210
+
211
+
212
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Build fixed 1D sin/cos positional embeddings.

    Adapted from InternVideo2:
    https://github.com/OpenGVLab/InternVideo/blob/421f6d2361fc8f61a3394244571f2601a4e99e29/InternVideo2/multi_modality/models/backbones/internvideo2/pos_embed.py#L86

    Args:
        embed_dim: output dimension per position (must be even).
        pos: array of positions to encode, any shape with M elements.

    Returns:
        (M, embed_dim) array laid out as [sin(pos*omega) | cos(pos*omega)].
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Geometric frequency ladder: omega_i = 1 / 10000^(i / (embed_dim/2)).
    omega = np.arange(half_dim, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    # Outer product position x frequency -> phase angles.
    angles = np.outer(pos.reshape(-1), omega)  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
233
+
234
+
235
def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
    """Temporal sin/cos position table.

    Args:
        embed_dim: embedding dimension.
        t_size: number of temporal positions.
        cls_token: when True, prepend one all-zero row for a [CLS] slot.

    Returns:
        (t_size, embed_dim) array, or (1 + t_size, embed_dim) with cls_token.
    """
    table = get_1d_sincos_pos_embed_from_grid(
        embed_dim, np.arange(t_size, dtype=np.float32))
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)
247
+
248
+
249
class Learnable2DInterpPosEmbDivided_fixed(nn.Module):
    """'Divided' space-time positional embedding.

    Spatial part: a learnable (height, width, dim) grid, resized via
    `get_rope_shape` (F.interpolate) when an input grid has a different
    spatial shape. Temporal part: a fixed (non-learnable) 1D sin/cos table
    of length `num_frames`, added per frame when t > 1.
    """

    def __init__(self,
                 height: int,
                 width: int,
                 num_frames: int,
                 dim: int,
                 interpolation_mode: str = 'bicubic') -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.num_frames = num_frames
        self.dim = dim
        self.interpolation_mode = interpolation_mode
        # Learnable spatial grid; initialized in reset_parameters().
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        # Fixed sin/cos temporal table, shape (num_frames, 1, dim); not
        # persisted since it is deterministic and cheap to rebuild.
        self.register_buffer('time_weight',
                             torch.from_numpy(
                                 get_1d_sincos_pos_embed(
                                     self.dim,
                                     self.num_frames)).float().unsqueeze(1),
                             persistent=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the spatial grid from a standard normal distribution."""
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings to packed patch tokens.

        Args:
            x: (sum(t*h*w), dim) packed tokens for all grids.
            grid_thws: (N, 3) per-grid (t, h, w).

        Returns:
            x plus positional embeddings, same shape as x.
        """
        pos_embs = []
        for t, h, w in grid_thws.tolist():
            assert t <= self.num_frames, f't:{t} > self.num_frames:{self.num_frames}'
            if (h, w) == self.weight.shape[:-1]:
                # Native resolution: use the learned grid directly.
                pos_emb_2d = self.weight.flatten(end_dim=1)
            else:
                # Resize the learned grid to this image's patch grid.
                pos_emb_2d = get_rope_shape(
                    self.weight,
                    interpolation_mode=self.interpolation_mode,
                    shape=(h, w),
                )

            if t == 1:
                pos_emb_3d = pos_emb_2d
            else:
                # Broadcast the spatial table over frames and add the fixed
                # per-frame temporal offsets.
                pos_emb_3d = pos_emb_2d.unsqueeze(0).repeat(
                    t, 1, 1) + self.time_weight[0:t]

            pos_embs.append(pos_emb_3d.reshape(-1, pos_emb_3d.shape[-1]))

        out = x + torch.cat(pos_embs)
        return out
300
+
301
+
302
class MoonVision3dPatchEmbed(nn.Module):
    """Conv2d patch projection followed by space-time positional embedding."""

    def __init__(self,
                 out_dim: int,
                 in_dim: int = 3,
                 patch_size: int | tuple[int, int] = (14, 14),
                 pos_emb_height: int = 14,
                 pos_emb_width: int = 14,
                 pos_emb_time: int = 4,
                 pos_emb_type: str = 'divided_fixed'):
        super().__init__()
        assert isinstance(
            patch_size,
            int | Sequence), f'Invalid patch_size type: {type(patch_size)}'
        # Normalize a scalar patch size to (h, w).
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        assert (len(patch_size) == 2
                ), f'Expected patch_size to be a tuple of 2, got {patch_size}'
        self.patch_size = patch_size

        # Non-overlapping patch projection (kernel == stride == patch size).
        self.proj = nn.Conv2d(in_dim,
                              out_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

        if pos_emb_type == 'divided_fixed':
            self.pos_emb = Learnable2DInterpPosEmbDivided_fixed(
                height=pos_emb_height,
                width=pos_emb_width,
                num_frames=pos_emb_time,
                dim=out_dim)
        else:
            raise NotImplementedError(
                f'Not support pos_emb_type: {pos_emb_type}')

    def forward(self, x: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (L, Channels): input tensor
            grid_hws (N, 3): temporal, height and width

        Returns:
            (L, Cout) tensor
        """
        # NOTE(review): x goes straight into a Conv2d, so each of the L rows
        # is presumably a pre-sliced (in_dim, patch_h, patch_w) patch and the
        # conv output is (L, out_dim, 1, 1) before the flatten — confirm
        # against the processor that builds pixel_values.
        x = self.proj(x).view(x.size(0), -1)
        # apply positional embedding
        x = self.pos_emb(x, grid_thws)
        return x
351
+
352
+
353
class Rope2DPosEmbRepeated(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    This class is intended to be used in the following way:
    1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
    2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
    3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
    The rope is shared across all attention layers and all heads.

    Refs:
        - RoFormer: https://arxiv.org/abs/2104.09864
        - VisionLLaMA: https://arxiv.org/abs/2403.00522
        - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
        max_height (int): the maximum height of the 2D grid
        max_width (int): the maximum width of the 2D grid
        theta_base (float): the base of the theta
        device (str): the device to store the precomputed cis
    """

    def __init__(self,
                 dim: int,
                 max_height: int,
                 max_width: int,
                 theta_base=10000):
        super().__init__()
        self.dim = dim
        # Each grid position needs dim/2 complex rotations split evenly
        # between the x and y axes, hence dim % 4 == 0.
        assert self.dim % 4 == 0, 'dim must be divisible by 4'
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base

    def extra_repr(self):
        return f'dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}'

    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
        """Calculate the cis(freqs) for each position in the 2D grid.

        Return: complex tensor of shape (max_height, max_width, dim//2) and value:
            height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
            weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
        note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
        """
        N = self.max_height * self.max_width
        flat_pos = torch.arange(0, N).float().to(device)
        # Recover (x, y) coordinates from the flattened row-major index.
        x_pos = flat_pos % self.max_width
        y_pos = flat_pos // self.max_width
        dim_range = (torch.arange(0, self.dim,
                                  4)[:(self.dim // 4)].float().to(device)
                     )  # C/4
        freqs = 1.0 / (self.theta_base**(dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # Interleave x/y rotations along the channel axis.
        # N, C/4, 2
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1),
             y_cis.unsqueeze(dim=-1)], dim=-1)
        # max_height, max_width, C/2
        freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
        return freqs_cis

    def get_freqs_cis(self, grid_thws: torch.Tensor,
                      device: torch.device) -> torch.Tensor:
        """
        Args:
            grid_thws (torch.Tensor): grid time, height and width

        Returns:
            freqs_cis: tensor of shape (sum(t * height * width), dim//2)
        """
        # Lazily build and cache the table as a non-persistent buffer on
        # first use, so it is created on the right device and never saved
        # in checkpoints.
        if not hasattr(self, 'freqs_cis'):
            self.register_buffer('freqs_cis',
                                 self._precompute_freqs_cis(device),
                                 persistent=False)

        shapes = grid_thws.tolist()
        assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width
                   for t, h, w in shapes), (
                       shapes,
                       self.max_height,
                       self.max_width,
                   )
        # "Repeated": every frame of a clip reuses the same spatial table
        # (repeat(t, 1)), so positions encode space only, not time.
        freqs_cis = torch.cat(
            [
                self.freqs_cis[:h, :w].reshape(-1, self.dim // 2).repeat(t, 1)
                for t, h, w in shapes
            ],
            dim=0,
        )
        return freqs_cis
447
+
448
+
449
class MLP2(nn.Module):
    """Two-layer feed-forward block with truncated-normal initialization.

    Args:
        dims: [in_dim, hidden_dim, out_dim]
        activation: callable applied between the two linear layers.
        bias: whether to use bias in linear layer.
    """

    def __init__(self, dims: list[int], activation, bias=True):
        super().__init__()
        assert len(dims) == 3
        in_dim, hidden_dim, out_dim = dims
        self.fc0 = nn.Linear(in_dim, hidden_dim, bias=bias)
        self.fc1 = nn.Linear(hidden_dim, out_dim, bias=bias)
        self.activation = activation
        # He-style init: std scaled by fan-in; biases start at zero.
        for layer in (self.fc0, self.fc1):
            nn.init.trunc_normal_(layer.weight,
                                  std=math.sqrt(2 / layer.in_features))
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.activation(self.fc0(x))
        return self.fc1(hidden)
471
+
472
+
473
class MoonViTEncoderLayer(nn.Module):
    """Pre-norm transformer encoder block for the vision tower.

    Operates on packed (varlen) token sequences: `cu_seqlens` marks the
    boundaries of the individual images/clips inside the flat token axis.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        *,
        attn_implementation: str = 'flash_attention_2',
        activation=F.gelu,
        attn_bias: bool = False,
        use_deterministic_attn: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        self.attn_implementation = attn_implementation
        self.use_deterministic_attn = use_deterministic_attn

        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
        # Fused QKV projection; split into q/k/v in attention_qkvpacked.
        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)

    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor,
        rope_freqs_cis: torch.Tensor | None = None,
    ):
        """Fused-QKV self-attention over packed tokens.

        Args:
            x (torch.Tensor): (batch_size, seqlen, hidden_dim)
            cu_seqlens (torch.Tensor): cumulative sequence lengths delimiting
                the packed sequences.
            max_seqlen: longest single-sequence length in the pack.
            rope_freqs_cis: precomputed rotary table for these positions.
        """
        xqkv = self.wqkv(x)

        qkv_shape = xqkv.size()[:-1] + (
            3,
            self.num_heads,
            self.hidden_size_per_attention_head,
        )
        # xqkv: (batch_size, seqlen, 3, nheads, headdim)
        xqkv = xqkv.view(*qkv_shape)
        xq, xk, xv = torch.unbind(xqkv, dim=-3)

        # Rotate q/k by the shared 2D rope table before attention.
        xq, xk = apply_rope(xq, xk, rope_freqs_cis)

        # Dispatch to flash-attn or the eager fallback per configuration.
        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
        attn_out = attn_func(xq,
                             xk,
                             xv,
                             q_cu_seqlens=cu_seqlens,
                             k_cu_seqlens=cu_seqlens,
                             max_seqlen_k=max_seqlen,
                             max_seqlen_q=max_seqlen,
                             deterministic=self.use_deterministic_attn)

        attn_out = self.wo(attn_out)
        return attn_out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        rope_freqs_cis: torch.Tensor | None = None,
    ):
        # Attention sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.norm0(hidden_states)

        hidden_states = self.attention_qkvpacked(hidden_states, cu_seqlens,
                                                 max_seqlen, rope_freqs_cis)
        hidden_states = residual + hidden_states

        # MLP sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
557
+
558
+
559
class MoonViT3dEncoder(nn.Module):
    """Stack of MoonViTEncoderLayer blocks over packed space-time tokens."""

    def __init__(self,
                 hidden_dim: int,
                 num_layers: int,
                 block_cfg: dict,
                 video_attn_type: str = 'spatial_temporal', use_deterministic_attn: bool = False) -> None:
        super().__init__()

        assert video_attn_type == 'spatial_temporal', f'video_attn_type must be "spatial_temporal", got {video_attn_type}'
        self.use_deterministic_attn = use_deterministic_attn
        self.video_attn_type = video_attn_type
        # Shared 2D rope table over per-head channels, sized for patch grids
        # up to 512x512.
        self.rope_2d = Rope2DPosEmbRepeated(
            block_cfg['hidden_dim'] // block_cfg['num_heads'], 512, 512)
        self.blocks = nn.ModuleList([
            MoonViTEncoderLayer(
                **block_cfg,
                use_deterministic_attn=self.use_deterministic_attn)
            for _ in range(num_layers)
        ])
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        """Encode packed patch tokens.

        Args:
            hidden_states: (sum(t*h*w), hidden_dim) packed tokens.
            grid_thws: (N, 3) per-grid (t, h, w).

        Returns:
            (sum(t*h*w), hidden_dim) encoded tokens.
        """
        rope_freqs_cis = self.rope_2d.get_freqs_cis(
            grid_thws=grid_thws, device=hidden_states.device)

        # Per-grid token counts with a leading 0 so the cumsum below yields
        # cu_seqlens boundaries directly.
        lengths = torch.cat((
            torch.zeros(1, dtype=grid_thws.dtype, device=grid_thws.device),
            grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2],
        ))

        # NOTE(review): max_seqlen stays a 0-dim tensor here while flash-attn
        # documents an int — presumably accepted via implicit conversion;
        # confirm if the flash path misbehaves.
        max_seqlen = lengths.max()
        cu_seqlens = lengths.to(hidden_states.device).cumsum(dim=0,
                                                             dtype=torch.int32)
        for block in self.blocks:
            hidden_states = block(hidden_states,
                                  cu_seqlens,
                                  max_seqlen,
                                  rope_freqs_cis=rope_freqs_cis)

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states
605
+
606
+
607
def tpool_patch_merger(
    x: torch.Tensor,
    grid_thws: torch.Tensor,
    merge_kernel_size: tuple[int, int] = (2, 2),
) -> list[torch.Tensor]:
    """Group patch tokens into spatial merge windows and average over time.

    Args:
        x: packed tokens of shape (sum(t*h*w), d_model); tokens for each
            (t, h, w) grid are stored contiguously, t-major then row-major.
        grid_thws: (N, 3) tensor of per-image (t, h, w) grids.
        merge_kernel_size: spatial window (kh, kw) merged into one group.

    Returns:
        One tensor per grid of shape (h//kh * w//kw, kh*kw, d_model): tokens
        averaged over the temporal axis and grouped per merge window.
    """
    d_model = x.size(-1)
    kernel_h, kernel_w = merge_kernel_size

    merged: list[torch.Tensor] = []
    offset = 0
    for t, h, w in grid_thws.tolist():
        token_count = t * h * w
        seq = x[offset:offset + token_count]
        offset += token_count

        # Split each spatial axis into (blocks, kernel) pairs.
        blocks_h, blocks_w = h // kernel_h, w // kernel_w
        windows = seq.view(t, blocks_h, kernel_h, blocks_w, kernel_w, d_model)
        # Gather the two block axes, then the two kernel axes, and average
        # across the leading temporal axis (temporal pooling).
        pooled = windows.permute(0, 1, 3, 2, 4, 5).contiguous().mean(dim=0)
        merged.append(
            pooled.view(blocks_h * blocks_w, kernel_h * kernel_w, -1))

    return merged
633
+
634
+
635
class MoonViT3dPretrainedModel(PreTrainedModel):
    """HF wrapper for the MoonViT 3D vision tower: patch embedding, encoder
    stack and spatial/temporal patch merging."""

    # NOTE(review): config_class is left as None; config resolution is
    # presumably handled by the composite KimiK25 model — confirm before
    # loading this tower standalone via from_pretrained.
    config_class = None
    model_type = 'moonvit3d'
    _no_split_modules = ['PackingTransformer']
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.use_deterministic_attn = False
        # Work on a copy so later mutation of the caller's config cannot
        # affect this module.
        config = deepcopy(config)
        self.merge_kernel_size = config.merge_kernel_size
        self.patch_size = config.patch_size
        self.merge_type = config.merge_type

        self.patch_embed = MoonVision3dPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
            pos_emb_time=config.init_pos_emb_time,
            pos_emb_type=config.pos_emb_type,
        )

        self.encoder = MoonViT3dEncoder(hidden_dim=config.hidden_size,
                                        num_layers=config.num_hidden_layers,
                                        block_cfg={
                                            'num_heads':
                                            config.num_attention_heads,
                                            'hidden_dim':
                                            config.hidden_size,
                                            'mlp_dim':
                                            config.intermediate_size,
                                            'activation':
                                            PytorchGELUTanh(),
                                            'attn_bias':
                                            True,
                                            'attn_implementation':
                                            config._attn_implementation,
                                        },
                                        video_attn_type=config.video_attn_type, use_deterministic_attn=self.use_deterministic_attn)

    def forward(self, pixel_values: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values (torch.Tensor): The input pixel values.
            grid_thws (torch.Tensor): Temporal, height and width.

        Returns:
            torch.Tensor: The output tokens.
        """
        # grid_thws = grid_thws.to('cpu')
        assert grid_thws.ndim == 2, f'grid_thws should be 2D, got {grid_thws.ndim}'
        assert grid_thws.size(1) == 3, f'No support for thw: {grid_thws}'
        hidden_states = self.patch_embed(pixel_values, grid_thws)
        hidden_states = self.encoder(hidden_states, grid_thws)
        if self.merge_type == 'sd2_tpool':  # spatial downsampling 2x with temporal pooling all
            hidden_states = tpool_patch_merger(
                hidden_states,
                grid_thws,
                merge_kernel_size=self.merge_kernel_size)
        else:
            raise NotImplementedError(f'Not support {self.merge_type}')

        return hidden_states
701
+
702
+
703
+ # ============================================================================
704
+ # MM Projector Helper Classes (from mm_projector/modeling_mm_projectors.py)
705
+ # ============================================================================
706
+
707
+
708
class IdentityMap(nn.Module):
    """Pass-through projector used when no projection is configured."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are accepted for API parity
        # with the other projectors and deliberately ignored.
        return x
715
+
716
+
717
+ class MLP(nn.Module):
718
+
719
+ def __init__(self, config):
720
+ super().__init__()
721
+ # TODO, use faster LayerNorm
722
+ self.pre_norm = nn.LayerNorm(config.mm_hidden_size)
723
+ self.proj = nn.Sequential(
724
+ nn.Linear(config.mm_hidden_size, config.hidden_size), nn.GELU(),
725
+ nn.Linear(config.hidden_size, config.hidden_size))
726
+
727
+ def forward(self, x, *args, **kwargs):
728
+ assert isinstance(x,
729
+ list | tuple), f'x is not a list or tuple: {type(x)}'
730
+ lengths = [item.shape[0] for item in x]
731
+ x = torch.cat(x, dim=0)
732
+ x = self.pre_norm(x)
733
+ x = self.proj(x)
734
+ x = torch.split(x, lengths, dim=0)
735
+
736
+ return x
737
+
738
+
739
class PatchMergerMLP(nn.Module):
    """Projector that fuses each (kh*kw)-token patch group into one token.

    Per group, the kh*kw vision features are layer-normed, flattened into a
    single vector of size mm_hidden_size * kh * kw, and mapped to the
    language-model hidden size by a two-layer GELU MLP.
    """

    def __init__(self, config):
        super().__init__()
        tokens_per_group = (config.merge_kernel_size[0] *
                            config.merge_kernel_size[1])
        self.hidden_size = config.mm_hidden_size * tokens_per_group
        self.pre_norm = nn.LayerNorm(config.mm_hidden_size,
                                     eps=config.projector_ln_eps)
        self.proj = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, config.hidden_size),
        )

    def forward(self, x, *args, **kwargs):
        if isinstance(x, (list, tuple)):
            # Per-image tensors of shape (num_groups, kh*kw, mm_hidden):
            # flatten each group after the norm, then project.
            return [
                self.proj(self.pre_norm(item).view(item.shape[0], -1))
                for item in x
            ]
        # Batched tensor input: (B, N, N_k, C) -> (B, N, hidden_size_out).
        batch = x.shape[0]
        return self.proj(self.pre_norm(x).view(batch, -1, self.hidden_size))
764
+
765
+
766
class KimiK25PreTrainedModel(PreTrainedModel):
    """Base class wiring KimiK25Config into the HF PreTrainedModel machinery
    (weight init, device-map split points, attention capability flags)."""

    config_class = KimiK25Config
    base_model_prefix = "model"
    # Modules that must not be sharded across devices by accelerate.
    _no_split_modules = [
        "MoonViT3dPretrainedModel",
        "MoonViTEncoderLayer",
        "DeepseekDecoderLayer",
        "PatchMergerMLP",
    ]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = False

    def _init_weights(self, module):
        # important: this ported version of Llava isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
        # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
        std = (self.config.initializer_range if hasattr(
            self.config, "initializer_range") else
               self.config.text_config.initializer_range)

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding token's embedding at zero after re-init.
                module.weight.data[module.padding_idx].zero_()
799
+
800
class VisionTowerConfig(PretrainedConfig):
    """Flat per-tower config derived from the composite KimiK25Config.

    Copies the vision-tower fields (and the ``vt_``-prefixed transformer
    sizes) into the attribute names MoonViT3dPretrainedModel expects.
    """

    model_type = 'moonvit3d'

    def __init__(self, config: KimiK25Config, **kwargs):
        super().__init__(**kwargs)
        # Fields whose names carry over unchanged from the parent config.
        for field in ('patch_size', 'init_pos_emb_height',
                      'init_pos_emb_width', 'init_pos_emb_time',
                      'pos_emb_type', 'merge_kernel_size', 'video_attn_type',
                      'merge_type', '_attn_implementation'):
            setattr(self, field, getattr(config, field))
        # Transformer dimensions live under a `vt_` prefix on the parent.
        self.num_attention_heads = config.vt_num_attention_heads
        self.num_hidden_layers = config.vt_num_hidden_layers
        self.hidden_size = config.vt_hidden_size
        self.intermediate_size = config.vt_intermediate_size
819
+
820
class ProjectorConfig:
    """Lightweight view of the projector-related fields of KimiK25Config,
    reshaped into the attribute names the mm_projector classes expect."""

    def __init__(self, config: KimiK25Config):
        # Fields whose names carry over unchanged.
        for field in ('mm_projector_type', 'mm_hidden_size',
                      'merge_kernel_size', 'projector_hidden_act',
                      'projector_ln_eps'):
            setattr(self, field, getattr(config, field))
        # The projector output width follows the language model hidden size.
        self.hidden_size = config.text_hidden_size
830
+
831
+ # ref https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/llava/modeling_llava.py#L240
832
+ class KimiK25ForConditionalGeneration(KimiK25PreTrainedModel):
833
+
834
    def __init__(self, config: KimiK25Config):
        """Assemble the vision tower, multimodal projector and language model
        from the composite config."""
        super().__init__(config)

        vt_config = VisionTowerConfig(config.vision_config)
        self.vision_tower = MoonViT3dPretrainedModel(vt_config)

        # Select the projector implementation from the config string.
        proj_config = ProjectorConfig(config.vision_config)
        if proj_config.mm_projector_type == 'identity':
            self.mm_projector = IdentityMap()
        elif proj_config.mm_projector_type == 'mlp':
            self.mm_projector = MLP(proj_config)
        elif proj_config.mm_projector_type == 'patchmerger':
            self.mm_projector = PatchMergerMLP(proj_config)
        else:
            raise ValueError(
                f"Unsupported mm_projector_type: {proj_config.mm_projector_type}"
            )

        self.language_model = DeepseekV3ForCausalLM(config.text_config)
        self.post_init()

        # Keep the vision modules in the same dtype as the language model so
        # projected features can be merged with the text embeddings.
        if hasattr(self.language_model, 'dtype'):
            target_dtype = self.language_model.dtype
            self.vision_tower = self.vision_tower.to(dtype=target_dtype)
            self.mm_projector = self.mm_projector.to(dtype=target_dtype)
859
+
860
    # Thin delegation layer: embedding, decoder and weight-tying access is
    # forwarded unchanged to the wrapped language model.

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()
880
+
881
    def resize_token_embeddings(self,
                                new_num_tokens: int | None = None,
                                pad_to_multiple_of=None) -> nn.Embedding:
        """Resize the LM token embeddings and keep the vocab-size bookkeeping
        on this wrapper and its text config in sync."""
        model_embeds = self.language_model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds
890
+
891
    def _merge_input_ids_with_image_features(
        self,
        image_features: list[torch.Tensor],
        inputs_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
    ):
        """
        Splice projected image features into the text embedding sequence,
        expanding each media-placeholder token to the length of its
        corresponding feature tensor.

        Args:
            image_features (list of :obj:`torch.Tensor`, each of shape :obj:`(num_image_tokens_i, embed_dim)`):
                The image features to merge with the input embeddings.
            inputs_embeds (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length, embed_dim)`):
                The input embeddings.
            input_ids (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`):
                The input ids.
            attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`):
                The attention mask.
            labels (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, *optional*):
                The labels.

        Returns:
            Tuple of (final_embedding, final_attention_mask, final_labels,
            position_ids), all expanded to the merged sequence length.
        """
        _, embed_dim = image_features[0].shape
        feature_lengths = [x.shape[0] for x in image_features]
        image_features = torch.cat(image_features, dim=0)

        image_token_index: int = self.config.media_placeholder_token_id
        pad_token_id: int = self.config.pad_token_id
        ignore_index: int = self.config.ignore_index

        batch_size, sequence_length = input_ids.shape
        # Left padding is assumed when no row ends in a pad token.
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(pad_token_id))

        # 1. Create a mask to know where special image tokens are
        # Each regular token occupies one slot in the merged sequence; each
        # placeholder occupies as many slots as its feature tensor has rows.
        _token_occupation_table = torch.ones_like(input_ids.flatten())
        _token_occupation_table[input_ids.flatten() ==
                                image_token_index] = torch.tensor(
                                    feature_lengths,
                                    dtype=torch.long,
                                    device=input_ids.device)
        _token_occupation_table = _token_occupation_table.reshape(
            input_ids.shape)

        max_embed_dim = _token_occupation_table.sum(-1).max().item()
        assert (
            max_embed_dim >= sequence_length
        ), f"The maximum embedding dimension ({max_embed_dim}) is less than the sequence length ({sequence_length})"
        batch_indices, non_image_indices = torch.where(
            input_ids != image_token_index)

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged image-text sequence.
        new_token_positions = torch.cumsum(_token_occupation_table, -1) - 1
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:,
                                                None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices,
                                                non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(batch_size,
                                           max_embed_dim,
                                           dtype=attention_mask.dtype,
                                           device=inputs_embeds.device)
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim),
                ignore_index,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask.
        final_embedding[batch_indices,
                        text_to_overwrite] = inputs_embeds[batch_indices,
                                                           non_image_indices]
        final_attention_mask[batch_indices,
                             text_to_overwrite] = attention_mask[
                                 batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices,
                         text_to_overwrite] = labels[batch_indices,
                                                     non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
        image_to_overwrite = torch.full((batch_size, max_embed_dim),
                                        True,
                                        dtype=torch.bool,
                                        device=inputs_embeds.device)
        image_to_overwrite[batch_indices, text_to_overwrite] = False
        # Exclude the left-padding slots from the image fill region.
        image_to_overwrite &= image_to_overwrite.cumsum(
            -1) - 1 >= nb_image_pad[:, None].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model are wrong. The number of image tokens is {image_to_overwrite.sum()} while"
                f" the number of image features given to the model is {image_features.shape[:-1].numel()}. "
                "This prevents correct indexing and breaks batch generation.")

        final_embedding[image_to_overwrite] = (
            image_features.contiguous().reshape(-1,
                                                embed_dim).to(target_device))
        final_attention_mask |= image_to_overwrite
        # Positions count only attended slots; masked slots get a dummy 1.
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
            (final_attention_mask == 0), 1)

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels, position_ids
1024
+
1025
def _extract_image_features(self, pixel_values: torch.Tensor,
                            grid_thws: torch.Tensor) -> list[torch.Tensor]:
    """
    Run the vision tower over preprocessed image patches.

    Args:
        pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
            The pixel values of the images processed by image processor.
        grid_thws (:obj:`torch.Tensor` of shape :obj:`(batch_size, 3)`):
            The grid, height, width of the images.

    Returns:
        selected_image_feature (:obj:`torch.FloatTensor` of shape :obj:`(num_image_tokens, embed_dim)`):
            The selected image features to use as input to the projector head.
    """
    # Cast the pixels to the dtype of the vision tower's patch-embedding
    # weights so the convolution runs in a single dtype.
    tower_dtype = self.vision_tower.patch_embed.proj.weight.dtype
    return self.vision_tower(pixel_values.to(tower_dtype), grid_thws)
1045
+
1046
def forward(
    self,
    input_ids: torch.LongTensor | None = None,
    pixel_values: torch.FloatTensor | list[torch.FloatTensor]
    | None = None,
    grid_thws: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: list[torch.FloatTensor] | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    labels: torch.LongTensor | None = None,
    use_cache: bool | None = None,
    output_attentions: bool | None = None,
    output_hidden_states: bool | None = None,
    return_dict: bool | None = None,
) -> tuple | LlavaCausalLMOutputWithPast:
    r"""
    Multimodal forward pass: embeds text tokens, splices in vision-tower
    features at the media-placeholder positions, then runs the language model.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    ```"""
    assert self.vision_tower is not None, "vision_tower is not loaded"
    # Fall back to config defaults for any output flag the caller left unset.
    output_attentions = (output_attentions if output_attentions is not None
                         else self.config.output_attentions)
    output_hidden_states = (output_hidden_states
                            if output_hidden_states is not None else
                            self.config.output_hidden_states)
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is None:
        # 1. Extract the input embeddings from the token ids.
        inputs_embeds = self.get_input_embeddings()(input_ids)

        # 2. Merge text and images. `input_ids.shape[1] != 1` distinguishes
        # the prefill step from single-token decode steps.
        if pixel_values is not None and len(
                pixel_values) > 0 and input_ids.shape[1] != 1:
            image_features = self._extract_image_features(
                pixel_values, grid_thws)
            if self.mm_projector:
                # Project vision features into the language-model embedding space.
                image_features = self.mm_projector(image_features)

            inputs_embeds = inputs_embeds.to(
                image_features[0].dtype)  # num_tokens, embed_dim
            # 3. Scatter image features into the placeholder positions;
            # this also rebuilds the attention mask / labels / position ids.
            inputs_embeds, attention_mask, labels, position_ids = (
                self._merge_input_ids_with_image_features(
                    image_features,
                    inputs_embeds,
                    input_ids,
                    attention_mask,
                    labels,
                ))

        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
        # generation with cache
        elif (past_key_values is not None and pixel_values is not None
              and input_ids.shape[1] == 1):
            # Retrieve the first layer to inspect the logits and mask out the hidden states
            # that are set to 0
            first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

            # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
            batch_index, non_attended_tokens = torch.where(
                first_layer_past_key_value.float().sum(-2) == 0)

            # Get the target length
            target_length = input_ids.shape[1]
            past_length = first_layer_past_key_value.shape[-1]

            # Start from an all-ones mask covering the cached prefix.
            extended_attention_mask = torch.ones(
                (attention_mask.shape[0], past_length),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )

            # Filter out only the tokens that can be un-attended, this can happen
            # if one uses Llava + Fused modules where the cache on the
            # first iteration is already big enough, or if one passes custom cache
            valid_indices = non_attended_tokens < extended_attention_mask.size(
                -1)
            new_batch_index = batch_index[valid_indices]
            new_non_attended_tokens = non_attended_tokens[valid_indices]

            # Zero-out the places where we don't need to attend
            extended_attention_mask[new_batch_index,
                                    new_non_attended_tokens] = 0

            # Prepend the rebuilt prefix mask to the current-step mask.
            attention_mask = torch.cat(
                (extended_attention_mask, attention_mask[:,
                                                         -target_length:]),
                dim=1)
            # Position of the new token = number of attended tokens - 1.
            position_ids = torch.sum(attention_mask,
                                     dim=1).unsqueeze(-1) - 1

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    logits = outputs[0]

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        if attention_mask is not None:
            # Only keep positions that are attended after the shift.
            shift_attention_mask = attention_mask[..., 1:]
            shift_logits = logits[..., :-1, :][shift_attention_mask.to(
                logits.device) != 0].contiguous()
            shift_labels = labels[..., 1:][shift_attention_mask.to(
                labels.device) != 0].contiguous()
        else:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1).to(shift_logits.device),
        )

    if not return_dict:
        # Tuple output: prepend the loss only when it was computed.
        output = (logits, ) + outputs[1:]
        return (loss, ) + output if loss is not None else output

    return LlavaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
1185
+
1186
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    inputs_embeds=None,
    pixel_values=None,
    grid_thws=None,
    attention_mask=None,
    **kwargs,
):
    """Trim `input_ids`/`attention_mask` against the KV cache and assemble the
    kwargs dict consumed by `forward` during generation."""
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            # `seen_tokens` may exceed `cache_length` for size-limited caches.
            past_length = getattr(past_key_values, 'seen_tokens',
                                  cache_length)
        else:
            # Legacy tuple cache: length is the key tensor's sequence dim.
            cache_length = past_length = past_key_values[0][0].shape[2]

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[
                1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] -
                                       past_length):]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
        # When media placeholders are present, keep only the last token
        # (image tokens were expanded into the cache during prefill).
        elif self.config.media_placeholder_token_id in input_ids:
            input_ids = input_ids[:, input_ids.shape[1] - 1:]
        # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
        # older attention values, as their corresponding values are not part of the input.
        if cache_length < past_length and attention_mask is not None:
            attention_mask = attention_mask[:, -(cache_length +
                                                 input_ids.shape[1]):]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        # Padding positions get a dummy id of 1; they are masked anyway.
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1]:]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update({
        "position_ids": position_ids,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
        "grid_thws": grid_thws,
    })
    return model_inputs
1248
+
1249
def _reorder_cache(self, *args, **kwargs):
    # Beam-search cache reordering is delegated entirely to the wrapped
    # language model, which owns the KV cache layout.
    return self.language_model._reorder_cache(*args, **kwargs)
1251
+
preprocessor_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "kimi_k25_processor.KimiK25Processor",
4
+ "AutoImageProcessor": "kimi_k25_vision_processing.KimiK25VisionProcessor"
5
+ },
6
+ "media_proc_cfg": {
7
+ "in_patch_limit": 16384,
8
+ "patch_size": 14,
9
+ "image_mean": [
10
+ 0.5,
11
+ 0.5,
12
+ 0.5
13
+ ],
14
+ "image_std": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "merge_kernel_size": 2,
20
+ "fixed_output_tokens": null,
21
+ "patch_limit_on_one_side": 512,
22
+ "in_patch_limit_each_frame": 4096,
23
+ "in_patch_limit_video": null,
24
+ "sample_fps": 2.0,
25
+ "max_num_frames_each_video": null,
26
+ "temporal_merge_kernel_size": 4,
27
+ "timestamp_mode": "hh:mm:ss.fff",
28
+ "config_type": "media_proc.processors.moonvit.MoonViTMediaProcessorConfig"
29
+ }
30
+ }
processor_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "kimi_k25_processor.KimiK25Processor"
4
+ },
5
+ "processor_class": "KimiK25Processor"
6
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_end|>",
4
+ "<|im_user|>",
5
+ "<|im_assistant|>",
6
+ "<|start_header_id|>",
7
+ "<|end_header_id|>",
8
+ "[EOT]",
9
+ "<|im_system|>",
10
+ "<|im_middle|>",
11
+ "<|media_begin|>",
12
+ "<|media_content|>",
13
+ "<|media_end|>",
14
+ "<|media_pad|>"
15
+ ],
16
+ "bos_token": {
17
+ "content": "[BOS]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "eos_token": {
24
+ "content": "[EOS]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": "[EOS]",
31
+ "unk_token": {
32
+ "content": "[UNK]",
33
+ "lstrip": false,
34
+ "normalized": false,
35
+ "rstrip": false,
36
+ "single_word": false
37
+ }
38
+ }
tiktoken.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6c497a7469b33ced9c38afb1ad6e47f03f5e5dc05f15930799210ec050c5103
3
+ size 2795286
tokenization_kimi.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from collections import OrderedDict
from logging import getLogger
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast

import tiktoken
from tiktoken.load import load_tiktoken_bpe
from tokenizers import AddedToken

# `bytes_to_unicode` moved between transformers releases; fall back to the
# new location. A bare `except:` would also swallow KeyboardInterrupt /
# SystemExit, so catch ImportError specifically.
try:
    from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
except ImportError:
    from transformers.convert_slow_tokenizer import bytes_to_unicode

from transformers.tokenization_utils import PreTrainedTokenizer

from .tool_declaration_ts import encode_tools_to_typescript_style

logger = getLogger(__name__)
# Filename expected inside the model repo for the tiktoken BPE ranks.
VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
23
+
24
+
25
class TikTokenTokenizer(PreTrainedTokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            The path to the Tiktoken model file.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|begin_of_text|>",`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|end_of_text|>"`):
            The end of sequence token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_249|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead. The second to last item in special_tokens.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_250|>"`):
            The token used for padding, for example when batching sequences of different lengths.
        additional_special_tokens (list of `str`, *optional*):
            A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
            skipped when decoding if `skip_special_tokens` is set to `True`.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    model_input_names = ["input_ids", "attention_mask"]

    # Mapping from special-token string to its tiktoken id; filled in __init__.
    special_tokens: Dict[str, int]

    # Number of ids reserved after the base BPE vocabulary for special tokens.
    num_reserved_special_tokens = 256

    # Pre-tokenization pattern; `&&[^\p{Han}]` excludes Han characters from
    # the Latin-style word branches so the first branch handles them.
    pat_str = "|".join([
        r"""[\p{Han}]+""",
        r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
        r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
        r"""\p{N}{1,3}""",
        r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
        r"""\s*[\r\n]+""",
        r"""\s+(?!\S)""",
        r"""\s+""",
    ])

    def __init__(
        self,
        vocab_file,
        bos_token: Union[str, AddedToken] = "[BOS]",
        eos_token: Union[str, AddedToken] = "[EOS]",
        unk_token: Union[str, AddedToken, None] = None,
        pad_token: Union[str, AddedToken, None] = None,
        additional_special_tokens: List[str] = None,
        added_tokens_decoder: Optional[dict] = None,
        **kwargs,
    ):
        assert os.path.isfile(vocab_file), vocab_file

        if additional_special_tokens is None:
            additional_special_tokens = [
                "<|im_end|>",
                "<|im_user|>",
                "<|im_assistant|>",
                "<|start_header_id|>",
                "<|end_header_id|>",
                "[EOT]",
                "<|im_system|>",
                "<|im_middle|>",
            ]

        # Map reserved ids -> token strings from the serialized tokenizer
        # config (if present), so saved special tokens round-trip.
        if added_tokens_decoder:
            special_tokens_mapping = {
                i: added_tokens_decoder[i].content
                for i in added_tokens_decoder
            }
        else:
            special_tokens_mapping = {}

        self.vocab_file = vocab_file
        mergeable_ranks = load_tiktoken_bpe(vocab_file)
        num_base_tokens = len(mergeable_ranks)
        # Reserved special ids follow immediately after the base BPE ranks;
        # unnamed slots get a "<|reserved_token_N|>" placeholder.
        self.special_tokens = {
            special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
            for i in range(num_base_tokens, num_base_tokens +
                           self.num_reserved_special_tokens)
        }

        self.model = tiktoken.Encoding(
            name=Path(vocab_file).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {vocab_file}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens[str(bos_token)]
        self.eos_id: int = self.special_tokens[str(eos_token)]
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

        # NOTE(review): if pad_token/unk_token is None, str(...) yields the
        # literal string "None" and this lookup raises KeyError — confirm all
        # callers always pass both tokens.
        self.pad_id: int = self.special_tokens[str(pad_token)]
        self.unk_id: int = self.special_tokens[str(unk_token)]

        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        # id -> printable token string, via the GPT-2 byte-to-unicode map.
        self.decoder = {}
        for i in range(self.n_words):
            # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
            decoding = ''.join([
                self.byte_encoder[ord(char)] for char in
                self.model.decode_single_token_bytes(i).decode('latin-1')
            ])
            self.decoder[i] = decoding

        # Reverse map: printable token string -> id.
        self.encoder = {}
        for i in range(self.n_words):
            if i in self.decoder:
                self.encoder[self.decoder[i]] = i

        self._token_config_cache = OrderedDict()
        self._cache_max_size = 128

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            added_tokens_decoder=added_tokens_decoder,
            **kwargs,
        )
        self.all_special_ids_set = set(self.all_special_ids)

    def encode(self,
               text: str,
               allow_special_tokens: bool = True,
               add_special_tokens: bool = True,
               truncation: Optional[bool] = None,
               max_length: Optional[int] = None,
               **kwargs) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            text (str): The input string to be encoded.
            allow_special_tokens (bool): Whether to allow special tokens
                in the input text (tiktoken-level).
            add_special_tokens (bool): HF-compatible param. Accepted for
                API compatibility; Kimi does not override
                build_inputs_with_special_tokens so this is a no-op.
            truncation (bool, optional): Whether to truncate to max_length.
            max_length (int, optional): Maximum token length when truncation
                is enabled.

        Returns:
            list[int]: A list of token IDs.
        """
        # If there are unknown args, fall back to super().encode which
        # handles the full HF protocol (padding, stride, etc.). Forward the
        # explicitly-named HF params too, so they are not silently dropped.
        if len(kwargs) > 0:
            return super().encode(text,
                                  add_special_tokens=add_special_tokens,
                                  truncation=truncation,
                                  max_length=max_length,
                                  **kwargs)

        assert type(text) is str

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        texts = self.pre_tokenizer_process(text)

        all_substrs = []
        for text in texts:
            substrs = (
                substr for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
                for substr in self._split_whitespaces_or_nonwhitespaces(
                    text[i:i +
                         TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS))
            all_substrs.extend(substrs)

        t: List[int] = []
        for substr in all_substrs:
            if allow_special_tokens:
                t.extend(
                    # we should consider special token as a common token
                    self.model.encode(
                        substr,
                        allowed_special="all",
                    ))
            else:
                t.extend(
                    # we should consider special token as a common token
                    self.model.encode(
                        substr,
                        disallowed_special=(),
                    ))

        if truncation and max_length is not None and len(t) > max_length:
            t = t[:max_length]

        return t

    def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            token_ids (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # If there are other args, we should call super().decode because there is a lot of code
        # to handle those args. super().decode finally calls convert_tokens_to_string and _convert_id_to_token.
        if len(kwargs) > 0:
            return super().decode(token_ids, **kwargs)

        if type(token_ids) is int:
            token_ids = [token_ids]

        return self.model.decode(cast(List[int], token_ids))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
            s: str, max_consecutive_slice_len: int) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            # A whitespace/non-whitespace boundary resets the run counter.
            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]

    def pre_tokenizer_process(self, text: str) -> List[str]:
        """
        pre-tokenizes the input text into a list of tokens.
        This method is used to split the input text into smaller chunks for internal processing.
        """
        return [text]

    """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """

    @property
    def vocab_size(self) -> int:
        return self.n_words

    def get_vocab(self) -> Dict[str, int]:
        return self.encoder

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        # Encode to ids with tiktoken, then map each id back to its
        # printable (byte-to-unicode) token string.
        return [self.decoder[t] for t in self.encode(text)]

    def _convert_token_to_id(self, token: str) -> int:
        return self.encoder.get(token, self.unk_id)

    def _convert_id_to_token(self, index: int) -> str:
        return self.decoder.get(index)

    @staticmethod
    def clean_up_tokenization(out_string: str) -> str:
        # Tiktoken round-trips exactly; no HF-style cleanup is needed.
        return out_string

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Reverse the byte-to-unicode mapping and decode as UTF-8,
        # replacing any invalid byte sequences.
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c]
                          for c in text]).decode('utf-8', 'replace')
        return text

    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Copy the tiktoken model file into *save_directory* and return its path."""
        if not os.path.isdir(save_directory):
            raise ValueError(
                f"vocabulary path ({save_directory}) should be a directory")
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") +
            VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(
                out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file, )

    def apply_chat_template(self,
                            conversation,
                            tools: Optional[list[dict]] = None,
                            tokenize: bool = False,
                            add_generation_prompt: bool = True,
                            thinking: bool = True,
                            **kwargs):
        """Render the chat template, pre-encoding `tools` into a
        TypeScript-style declaration string for the template to consume."""
        # Sort keys recursively so equivalent tool specs render identically.
        tools = deep_sort_dict(tools)

        # Convert tools to TypeScript style string if tools are provided
        tools_ts_str = None
        if tools:
            try:
                tools_ts_str = encode_tools_to_typescript_style(tools)

            except Exception as e:
                print(f"Failed to convert tools to TypeScript style: {e}")
                tools_ts_str = None

        # Store the TypeScript string in kwargs so it can be accessed by the template
        if tools_ts_str is not None:
            kwargs['tools_ts_str'] = tools_ts_str
        return super().apply_chat_template(
            conversation,
            tools=tools,
            tokenize=tokenize,
            add_generation_prompt=add_generation_prompt,
            thinking=thinking,
            **kwargs)
+ **kwargs)
386
+
387
+
388
+ def deep_sort_dict(obj: Any) -> Any:
389
+ if isinstance(obj, dict):
390
+ return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
391
+ if isinstance(obj, list):
392
+ return [deep_sort_dict(item) for item in obj]
393
+ return obj
tokenizer_config.json ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "163584": {
4
+ "content": "[BOS]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "163585": {
12
+ "content": "[EOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "163586": {
20
+ "content": "<|im_end|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "163587": {
28
+ "content": "<|im_user|>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "163588": {
36
+ "content": "<|im_assistant|>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "163590": {
44
+ "content": "<|start_header_id|>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "163591": {
52
+ "content": "<|end_header_id|>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "163593": {
60
+ "content": "[EOT]",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "163594": {
68
+ "content": "<|im_system|>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "163595": {
76
+ "content": "<|tool_calls_section_begin|>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "163596": {
84
+ "content": "<|tool_calls_section_end|>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "163597": {
92
+ "content": "<|tool_call_begin|>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "163598": {
100
+ "content": "<|tool_call_argument_begin|>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "163599": {
108
+ "content": "<|tool_call_end|>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "163601": {
116
+ "content": "<|im_middle|>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "163602": {
124
+ "content": "<|media_begin|>",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "163603": {
132
+ "content": "<|media_content|>",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ },
139
+ "163604": {
140
+ "content": "<|media_end|>",
141
+ "lstrip": false,
142
+ "normalized": false,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": true
146
+ },
147
+ "163605": {
148
+ "content": "<|media_pad|>",
149
+ "lstrip": false,
150
+ "normalized": false,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": true
154
+ },
155
+ "163606": {
156
+ "content": "<think>",
157
+ "lstrip": false,
158
+ "normalized": false,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": false
162
+ },
163
+ "163607": {
164
+ "content": "</think>",
165
+ "lstrip": false,
166
+ "normalized": false,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": false
170
+ },
171
+ "163838": {
172
+ "content": "[UNK]",
173
+ "lstrip": false,
174
+ "normalized": false,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": true
178
+ },
179
+ "163839": {
180
+ "content": "[PAD]",
181
+ "lstrip": false,
182
+ "normalized": false,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": true
186
+ }
187
+ },
188
+ "additional_special_tokens": [
189
+ "<|im_end|>",
190
+ "<|im_user|>",
191
+ "<|im_assistant|>",
192
+ "<|start_header_id|>",
193
+ "<|end_header_id|>",
194
+ "[EOT]",
195
+ "<|im_system|>",
196
+ "<|im_middle|>",
197
+ "<|media_begin|>",
198
+ "<|media_content|>",
199
+ "<|media_end|>",
200
+ "<|media_pad|>"
201
+ ],
202
+ "auto_map": {
203
+ "AutoTokenizer": [
204
+ "tokenization_kimi.TikTokenTokenizer",
205
+ null
206
+ ]
207
+ },
208
+ "bos_token": "[BOS]",
209
+ "clean_up_tokenization_spaces": false,
210
+ "eos_token": "[EOS]",
211
+ "extra_special_tokens": {},
212
+ "model_max_length": 1000000000000000019884624838656,
213
+ "pad_token": "[EOS]",
214
+ "tokenizer_class": "TikTokenTokenizer",
215
+ "unk_token": "[UNK]"
216
+ }
tool_declaration_ts.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Encode structured tool declaration to typescript style string.
3
+ """
4
+ import dataclasses
5
+ import json
6
+ import logging
7
+ from collections.abc import Sequence
8
+ from typing import Any
9
+
10
logger = logging.getLogger(__name__)

# Indentation unit for nested TypeScript output.
# NOTE(review): appears as a single space in this paste — possibly multiple
# spaces collapsed by whitespace mangling; confirm against the original file.
_TS_INDENT = " "
# Separator placed between emitted object fields / parameters.
_TS_FIELD_DELIMITER = ",\n"
14
+
15
+
16
+ class _SchemaRegistry:
17
+ """Registry for schema definitions to handle $ref resolution"""
18
+
19
+ def __init__(self):
20
+ self.definitions = {}
21
+ self.has_self_ref = False
22
+
23
+ def register_definitions(self, defs: dict[str, Any]):
24
+ """Register schema definitions from $defs section"""
25
+ if not defs:
26
+ return
27
+ for def_name, def_schema in defs.items():
28
+ self.definitions[def_name] = def_schema
29
+
30
+ def resolve_ref(self, ref: str) -> dict[str, Any]:
31
+ """Resolve a reference to its schema definition"""
32
+ if ref == "#":
33
+ self.has_self_ref = True
34
+ return {"$self_ref": True}
35
+ elif ref.startswith("#/$defs/"):
36
+ def_name = ref.split("/")[-1]
37
+ if def_name not in self.definitions:
38
+ raise ValueError(f"Reference not found: {ref}")
39
+ return self.definitions[def_name]
40
+ else:
41
+ raise ValueError(f"Unsupported reference format: {ref}")
42
+
43
+
44
+ def _format_description(description: str, indent: str = "") -> str:
45
+ return "\n".join([
46
+ f"{indent}// {line}" if line else ""
47
+ for line in description.split("\n")
48
+ ])
49
+
50
+
51
+ class _BaseType:
52
+ description: str
53
+ constraints: dict[str, Any]
54
+
55
+ def __init__(
56
+ self,
57
+ extra_props: dict[str, Any],
58
+ *,
59
+ allowed_constraint_keys: Sequence[str] = (),
60
+ ):
61
+ self.description = extra_props.get("description", "")
62
+ self.constraints = {
63
+ k: v
64
+ for k, v in extra_props.items() if k in allowed_constraint_keys
65
+ }
66
+
67
+ def to_typescript_style(self, indent: str = "") -> str:
68
+ raise NotImplementedError
69
+
70
+ def format_docstring(self, indent: str) -> str:
71
+ lines = []
72
+ if self.description:
73
+ lines.append(_format_description(self.description, indent))
74
+ if self.constraints:
75
+ constraints_str = ", ".join(f"{k}: {v}" for k, v in sorted(
76
+ self.constraints.items(), key=lambda kv: kv[0]))
77
+ lines.append(f"{indent}// {constraints_str}")
78
+
79
+ return "".join(x + "\n" for x in lines)
80
+
81
+
82
class _ParameterTypeScalar(_BaseType):
    """A primitive JSON-schema type (string, number, integer, boolean, ...)."""

    type: str

    def __init__(self, type: str, extra_props: dict[str, Any] | None = None):
        self.type = type

        # Surface only the constraint keywords meaningful for this scalar.
        if type == "string":
            keys: list[str] = ["maxLength", "minLength", "pattern"]
        elif type in ("number", "integer"):
            keys = ["maximum", "minimum"]
        else:
            keys = []

        super().__init__(extra_props or {}, allowed_constraint_keys=keys)

    def to_typescript_style(self, indent: str = "") -> str:
        # TypeScript has no integer type, so "integer" collapses to "number".
        return "number" if self.type == "integer" else self.type
102
+
103
+
104
class _ParameterTypeObject(_BaseType):
    """An "object" schema: named properties plus optional additionalProperties."""

    properties: list["_Parameter"]
    # Holds True/False passed through from the schema, a parsed
    # _ParameterType (when additionalProperties is itself a schema),
    # or None when the keyword is absent.
    additional_properties: Any | None = None

    def __init__(self,
                 json_schema_object: dict[str, Any],
                 registry: _SchemaRegistry | None = None):
        super().__init__(json_schema_object)

        self.properties = []
        self.additional_properties = None

        if not json_schema_object:
            return

        # Nested $defs become resolvable by every later $ref in this schema.
        if "$defs" in json_schema_object and registry:
            registry.register_definitions(json_schema_object["$defs"])

        self.additional_properties = json_schema_object.get(
            "additionalProperties")
        if isinstance(self.additional_properties, dict):
            self.additional_properties = _parse_parameter_type(
                self.additional_properties, registry)

        if "properties" not in json_schema_object:
            return

        # Properties not listed under "required" are optional.
        required_parameters = json_schema_object.get("required", [])
        optional_parameters = set(
            json_schema_object["properties"].keys()) - set(required_parameters)

        # NOTE(review): a JSON-schema default of null is indistinguishable
        # from "no default" here (both yield None) — confirm acceptable.
        self.properties = [
            _Parameter(
                name=name,
                type=_parse_parameter_type(prop, registry),
                optional=name in optional_parameters,
                default=prop.get("default")
                if isinstance(prop, dict) else None,
            ) for name, prop in json_schema_object["properties"].items()
        ]

    def to_typescript_style(self, indent: str = "") -> str:
        """Render ``{ ... }``: required fields first, each group name-sorted,
        plus an index signature when additionalProperties is present."""
        # sort by optional, make the required parameters first
        parameters = [p for p in self.properties if not p.optional]
        opt_params = [p for p in self.properties if p.optional]

        parameters = sorted(parameters, key=lambda p: p.name)
        parameters.extend(sorted(opt_params, key=lambda p: p.name))

        param_strs = []
        for p in parameters:
            one = p.to_typescript_style(indent=indent + _TS_INDENT)
            param_strs.append(one)

        # additionalProperties maps onto a TypeScript index signature:
        # True -> any, False -> never, schema -> its rendered type.
        if self.additional_properties is not None:
            ap_type_str = "any"
            if self.additional_properties is True:
                ap_type_str = "any"
            elif self.additional_properties is False:
                ap_type_str = "never"
            elif isinstance(self.additional_properties, _ParameterType):
                ap_type_str = self.additional_properties.to_typescript_style(
                    indent=indent + _TS_INDENT)
            else:
                raise ValueError(
                    f"Unknown additionalProperties: {self.additional_properties}"
                )
            param_strs.append(
                f"{indent + _TS_INDENT}[k: string]: {ap_type_str}")

        if not param_strs:
            return "{}"

        params_str = _TS_FIELD_DELIMITER.join(param_strs)
        if params_str:
            # add new line before and after
            params_str = f"\n{params_str}\n"
        # always wrap with object
        return f"{{{params_str}{indent}}}"
183
+
184
+
185
class _ParameterTypeArray(_BaseType):
    """An "array" schema; the element type defaults to "any" when absent."""

    item: "_ParameterType"

    def __init__(self,
                 json_schema_object: dict[str, Any],
                 registry: _SchemaRegistry | None = None):
        super().__init__(json_schema_object,
                         allowed_constraint_keys=("minItems", "maxItems"))
        items_schema = json_schema_object.get("items")
        if items_schema:
            self.item = _parse_parameter_type(items_schema, registry)
        else:
            self.item = _ParameterTypeScalar(type="any")

    def to_typescript_style(self, indent: str = "") -> str:
        """Render ``Array<T>``; multi-line when the element carries comments."""
        inner_indent = indent + _TS_INDENT
        docs = self.item.format_docstring(inner_indent)
        if not docs:
            return f"Array<{self.item.to_typescript_style(indent=indent)}>"
        body = self.item.to_typescript_style(indent=inner_indent)
        return f"Array<\n{docs}{inner_indent}{body}\n{indent}>"
207
+
208
+
209
class _ParameterTypeEnum(_BaseType):
    """An "enum" schema restricted to scalar members.

    Members may be strings, numbers, booleans, or None (JSON null).
    """

    # support scalar types only
    enum: list[str | int | float | bool | None]

    def __init__(self, json_schema_object: dict[str, Any]):
        super().__init__(json_schema_object)
        self.enum = json_schema_object["enum"]

        # Validate enum values against declared type if present
        if "type" in json_schema_object:
            typ = json_schema_object["type"]
            if isinstance(typ, list):
                # Only a single type, optionally paired with "null", is allowed.
                if len(typ) == 1:
                    typ = typ[0]
                elif len(typ) == 2:
                    if "null" not in typ:
                        raise ValueError(f"Enum type {typ} is not supported")
                    else:
                        typ = typ[0] if typ[0] != "null" else typ[1]
                else:
                    raise ValueError(f"Enum type {typ} is not supported")
            for val in self.enum:
                if val is None:
                    continue
                # NOTE(review): bool is a subclass of int, so True/False also
                # pass the "number"/"integer" checks — confirm this is intended.
                if typ == "string" and not isinstance(val, str):
                    raise ValueError(f"Enum value {val} is not a string")
                elif typ == "number" and not isinstance(val, (int, float)):
                    raise ValueError(f"Enum value {val} is not a number")
                elif typ == "integer" and not isinstance(val, int):
                    raise ValueError(f"Enum value {val} is not an integer")
                elif typ == "boolean" and not isinstance(val, bool):
                    raise ValueError(f"Enum value {val} is not a boolean")

    def to_typescript_style(self, indent: str = "") -> str:
        """Render the members as a TypeScript literal union.

        json.dumps yields a valid TS literal for every supported scalar:
        quoted/escaped strings, true/false, null, and plain numbers.
        (The previous str()-based rendering emitted Python's ``None`` and
        ``True``/``False``, which are not valid TypeScript, and did not
        escape quotes inside string members.)
        """
        return " | ".join(
            json.dumps(e, ensure_ascii=False) for e in self.enum)
245
+
246
+
247
class _ParameterTypeAnyOf(_BaseType):
    """An "anyOf" schema rendered as a TypeScript union of its branches."""

    types: list["_ParameterType"]

    def __init__(
        self,
        json_schema_object: dict[str, Any],
        registry: _SchemaRegistry | None = None,
    ):
        super().__init__(json_schema_object)
        branches = json_schema_object["anyOf"]
        self.types = [_parse_parameter_type(b, registry) for b in branches]

    def to_typescript_style(self, indent: str = "") -> str:
        # Join every branch's rendering with " | ".
        rendered = [
            branch.to_typescript_style(indent=indent)
            for branch in self.types
        ]
        return " | ".join(rendered)
264
+
265
+
266
class _ParameterTypeUnion(_BaseType):
    """A schema whose "type" is a list of primitive names (a type union)."""

    types: list[str]

    def __init__(self, json_schema_object: dict[str, Any]):
        super().__init__(json_schema_object)

        # JSON-schema primitive name -> TypeScript type expression.
        mapping = {
            "string": "string",
            "number": "number",
            "integer": "number",
            "boolean": "boolean",
            "null": "null",
            "object": "{}",
            "array": "Array<any>",
        }
        self.types = []
        for t in json_schema_object["type"]:
            if t not in mapping:
                # Raise a descriptive error instead of leaking a bare
                # KeyError; every other unsupported-schema path in this
                # module raises ValueError.
                raise ValueError(f"Unsupported type in union: {t}")
            self.types.append(mapping[t])

    def to_typescript_style(self, indent: str = "") -> str:
        return " | ".join(self.types)
285
+
286
+
287
class _ParameterTypeRef(_BaseType):
    """A "$ref" node; renders as the bare name of the referenced interface."""

    ref_name: str
    is_self_ref: bool = False

    def __init__(self, json_schema_object: dict[str, Any],
                 registry: _SchemaRegistry):
        super().__init__(json_schema_object)

        ref = json_schema_object["$ref"]
        target = registry.resolve_ref(ref)

        # A "#" self-reference renders as the synthetic root interface name.
        if target.get("$self_ref", False):
            self.is_self_ref = True
            self.ref_name = "parameters"
        else:
            self.ref_name = ref.split("/")[-1]

    def to_typescript_style(self, indent: str = "") -> str:
        return self.ref_name
306
+
307
+
308
# Closed union of every node kind _parse_parameter_type can return.
# Built with "|" applied to classes (PEP 604, Python 3.10+), so it also
# works as the second argument to isinstance() in _ParameterTypeObject.
_ParameterType = (_ParameterTypeScalar
                  | _ParameterTypeObject
                  | _ParameterTypeArray
                  | _ParameterTypeEnum
                  | _ParameterTypeAnyOf
                  | _ParameterTypeUnion
                  | _ParameterTypeRef)
315
+
316
+
317
@dataclasses.dataclass
class _Parameter:
    """
    A parameter in a function, or a field in a object.
    It consists of the type as well as the name.
    """

    type: _ParameterType
    name: str = "_"
    optional: bool = True
    default: Any | None = None

    @classmethod
    def parse_extended(cls, attributes: dict[str, Any]) -> "_Parameter":
        """Build a parameter from an extended attribute dict.

        Raises ValueError when *attributes* is empty.
        """
        if not attributes:
            raise ValueError("attributes is empty")

        # NOTE(review): an absent "optional" defaults to False here, while
        # the dataclass field default is True — confirm which is intended.
        return cls(
            name=attributes.get("name", "_"),
            type=_parse_parameter_type(attributes),
            optional=attributes.get("optional", False),
            default=attributes.get("default"),
        )

    def to_typescript_style(self, indent: str = "") -> str:
        """Render ``name?: Type`` preceded by doc comments and the default.

        A default of None (JSON null) is indistinguishable from "no default"
        and is therefore not emitted.
        """
        comments = self.type.format_docstring(indent)

        if self.default is not None:
            # json.dumps renders every scalar as its TypeScript/JSON literal,
            # including true/false — the previous repr() branch emitted
            # Python's True/False for booleans, which is not valid TS.
            default_repr = json.dumps(self.default, ensure_ascii=False)
            comments += f"{indent}// Default: {default_repr}\n"

        suffix = "?" if self.optional else ""
        type_str = self.type.to_typescript_style(indent=indent)
        return f"{comments}{indent}{self.name}{suffix}: {type_str}"
354
+
355
+
356
def _parse_parameter_type(
        json_schema_object: dict[str, Any] | bool,
        registry: _SchemaRegistry | None = None) -> _ParameterType:
    """Convert a JSON-schema fragment into a _ParameterType node.

    Dispatch order: boolean schema, $ref, anyOf, enum, type, empty schema.
    Raises ValueError for anything unrecognized.
    """
    if isinstance(json_schema_object, bool):
        # Boolean schemas: "true" accepts anything; "false" accepts nothing
        # and is approximated here as "null" (with a warning).
        if json_schema_object:
            return _ParameterTypeScalar(type="any")
        # Lazy %-formatting: the argument is only rendered if the record
        # is actually emitted (was an eager f-string).
        logger.warning(
            "Warning: Boolean value %s is not supported, use null instead.",
            json_schema_object)
        return _ParameterTypeScalar(type="null")

    if "$ref" in json_schema_object and registry:
        return _ParameterTypeRef(json_schema_object, registry)

    if "anyOf" in json_schema_object:
        return _ParameterTypeAnyOf(json_schema_object, registry)
    elif "enum" in json_schema_object:
        return _ParameterTypeEnum(json_schema_object)
    elif "type" in json_schema_object:
        typ = json_schema_object["type"]
        if isinstance(typ, list):
            return _ParameterTypeUnion(json_schema_object)
        elif typ == "object":
            return _ParameterTypeObject(json_schema_object, registry)
        elif typ == "array":
            return _ParameterTypeArray(json_schema_object, registry)
        else:
            return _ParameterTypeScalar(typ, json_schema_object)
    elif json_schema_object == {}:
        # A completely empty schema places no constraints -> "any".
        return _ParameterTypeScalar(type="any")
    else:
        # NOTE(review): a schema with only e.g. "description" lands here and
        # raises — confirm whether it should instead be treated as "any".
        raise ValueError(f"Invalid JSON Schema object: {json_schema_object}")
389
+
390
+
391
def _openai_function_to_typescript_style(function: dict[str, Any]) -> str:
    """Convert an OpenAI function definition (dict) to a TypeScript string.

    Emits, in order: an optional root "parameters" interface (only when the
    schema self-references "#"), one interface per $defs entry, the function
    description, and a ``type name = (_: Params) => any;`` declaration.
    """
    registry = _SchemaRegistry()
    parameters = function.get("parameters") or {}
    parsed = _ParameterTypeObject(parameters, registry)

    # Parse $defs entries BEFORE checking has_self_ref: a "#" reference
    # inside a definition is only discovered while parsing that definition,
    # and it must influence whether the root "parameters" interface gets
    # emitted. (Previously the check ran first, so such a ref rendered a
    # "parameters" name that was never declared.)
    # Copy first: parsing can register nested $defs into the registry.
    parsed_defs = [(name, schema, _parse_parameter_type(schema, registry))
                   for name, schema in dict(registry.definitions).items()]

    interfaces = []
    root_interface_name = None
    if registry.has_self_ref:
        root_interface_name = "parameters"
        # NOTE(review): this interface body lists properties only and drops
        # any additionalProperties index signature — confirm intended.
        params_str = _TS_FIELD_DELIMITER.join([
            p.to_typescript_style(indent=_TS_INDENT) for p in parsed.properties
        ])
        params_str = f"\n{params_str}\n" if params_str else ""
        interfaces.append(f"interface {root_interface_name} {{{params_str}}}")

    for def_name, def_schema, obj_type in parsed_defs:
        params_str = obj_type.to_typescript_style()

        description_part = ""
        if obj_description := def_schema.get("description", ""):
            description_part = _format_description(obj_description) + "\n"

        interfaces.append(
            f"{description_part}interface {def_name} {params_str}")

    interface_str = "\n".join(interfaces)
    function_name = function.get("name", "function")
    if root_interface_name:
        type_def = f"type {function_name} = (_: {root_interface_name}) => any;"
    else:
        # No self-reference: inline the parameter object in the signature.
        params_str = parsed.to_typescript_style()
        type_def = f"type {function_name} = (_: {params_str}) => any;"

    description = function.get("description")
    return "\n".join(
        filter(
            bool,
            [
                interface_str,
                ((description and _format_description(description)) or ""),
                type_def,
            ],
        ))
438
+
439
+
440
def encode_tools_to_typescript_style(tools: list[dict[str, Any]]) -> str:
    """
    Convert tools (list of dict) to TypeScript style string.

    Supports OpenAI format: {"type": "function", "function": {...}}

    Args:
        tools: List of tool definitions in dict format

    Returns:
        TypeScript style string representation of the tools
    """
    if not tools:
        return ""

    declarations = []
    for tool in tools:
        # Skip unsupported tool types (like "_plugin").
        if tool.get("type") != "function":
            continue
        func_def = tool.get("function", {})
        if func_def:
            declarations.append(_openai_function_to_typescript_style(func_def))

    if not declarations:
        return ""

    body = "\n".join(declarations)
    result = "# Tools\n\n"
    if body:
        result += f"## functions\nnamespace functions {{\n{body}\n}}\n"
    return result