yujiepan commited on
Commit
14f47b2
·
verified ·
1 Parent(s): d97d0a9

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ base_model:
4
+ - moonshotai/Kimi-K2.5
5
+ ---
6
+
7
+ This tiny model is intended for debugging. It is randomly initialized using the configuration adapted from [moonshotai/Kimi-K2.5](https://huggingface.co/moonshotai/Kimi-K2.5).
8
+
9
+ | File path | Size |
10
+ |------|------|
11
+ | model.safetensors | 6.19MB |
12
+
13
+
14
+ ### Example usage:
15
+
16
+ - vLLM
17
+
18
+ ```bash
19
+ vllm serve tiny-random/kimi-k2.5 --trust-remote-code
20
+ ```
21
+
22
+ - Transformers
23
+
24
+ ```python
25
+ import torch
26
+ from transformers import AutoModel, AutoProcessor
27
+
28
+ model_id = "tiny-random/kimi-k2.5"
29
+ messages = [
30
+ {
31
+ "role": "user",
32
+ "content": [
33
+ {
34
+ "type": "image",
35
+ "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"
36
+ },
37
+ {
38
+ "type": "text",
39
+ "text": "describe this image"
40
+ }
41
+ ],
42
+ }
43
+ ]
44
+ processor = AutoProcessor.from_pretrained(
45
+ model_id,
46
+ trust_remote_code=True,
47
+ )
48
+ model = AutoModel.from_pretrained(
49
+ model_id,
50
+ torch_dtype=torch.bfloat16,
51
+ device_map="cuda",
52
+ trust_remote_code=True,
53
+ )
54
+ inputs = processor.apply_chat_template(
55
+ messages,
56
+ tokenize=True,
57
+ add_generation_prompt=True,
58
+ return_dict=True,
59
+ return_tensors="pt"
60
+ ).to(model.device)
61
+ inputs.pop("token_type_ids", None)
62
+ generated_ids = model.generate(**inputs, max_new_tokens=16)
63
+ output_text = processor.decode(
64
+ generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
65
+ print(output_text)
66
+ ```
67
+
68
+ ### Codes to create this repo:
69
+
70
+ ```python
71
+ import json
72
+ from pathlib import Path
73
+
74
+ import accelerate
75
+ import torch
76
+ from huggingface_hub import file_exists, hf_hub_download, list_repo_files
77
+ from transformers import (
78
+ AutoConfig,
79
+ AutoModel,
80
+ AutoModelForCausalLM,
81
+ AutoProcessor,
82
+ AutoTokenizer,
83
+ GenerationConfig,
84
+ set_seed,
85
+ )
86
+
87
+ source_model_id = "moonshotai/Kimi-K2.5"
88
+ save_folder = "/tmp/tiny-random/kimi-k25"
89
+
90
+ Path(save_folder).mkdir(parents=True, exist_ok=True)
91
+
92
+ for f in list_repo_files(source_model_id, repo_type="model"):
93
+ if (f.endswith('.json') or f.endswith('.py') or f.endswith('.model') or f.endswith('.jinja')) and (
94
+ not f.endswith('.index.json')
95
+ ):
96
+ hf_hub_download(
97
+ repo_id=source_model_id,
98
+ filename=f,
99
+ repo_type="model",
100
+ local_dir=save_folder
101
+ )
102
+
103
+ def replace_file(filepath, old_string, new_string):
104
+ with open(filepath, 'r', encoding='utf-8') as f:
105
+ code = f.read()
106
+ code = code.replace(old_string, new_string)
107
+ with open(filepath, 'w', encoding='utf-8') as f:
108
+ f.write(code)
109
+
110
+ replace_file(f'{save_folder}/configuration_kimi_k25.py',
111
+ "from configuration_deepseek import DeepseekV3Config",
112
+ "from transformers import DeepseekV3Config")
113
+ replace_file(f'{save_folder}/modeling_kimi_k25.py',
114
+ "use_deterministic_attn=self.use_deterministic_attn",
115
+ "")
116
+ with open(f'{save_folder}/config.json') as f:
117
+ config_json = json.load(f)
118
+
119
+ config_json['text_config'].update({
120
+ 'first_k_dense_replace': 1,
121
+ 'num_hidden_layers': 2,
122
+ 'hidden_size': 8,
123
+ 'intermediate_size': 64,
124
+ 'kv_lora_rank': 384,
125
+ 'moe_intermediate_size': 64,
126
+ 'n_routed_experts': 32,
127
+ 'n_shared_experts': 1,
128
+ 'num_attention_heads': 1,
129
+ 'num_experts_per_tok': 8,
130
+ 'num_key_value_heads': 1,
131
+ 'q_lora_rank': 32,
132
+ 'qk_nope_head_dim': 64,
133
+ 'qk_rope_head_dim': 192,
134
+ 'v_head_dim': 64,
135
+ 'tie_word_embeddings': False,
136
+ })
137
+ del config_json['text_config']['quantization_config']
138
+ config_json['vision_config'].update({
139
+ 'mm_hidden_size': 64,
140
+ 'text_hidden_size': 8,
141
+ 'vt_hidden_size': 64,
142
+ 'vt_intermediate_size': 128,
143
+ 'vt_num_attention_heads': 2,
144
+ 'vt_num_hidden_layers': 2,
145
+ })
146
+ del config_json['vision_config']['_attn_implementation']
147
+ with open(f"{save_folder}/config.json", "w", encoding='utf-8') as f:
148
+ json.dump(config_json, f, indent=2)
149
+
150
+ config = AutoConfig.from_pretrained(
151
+ save_folder,
152
+ trust_remote_code=True,
153
+ )
154
+ print(config)
155
+ torch.set_default_dtype(torch.bfloat16)
156
+ model = AutoModel.from_config(config, trust_remote_code=True)
157
+ torch.set_default_dtype(torch.float32)
158
+ if file_exists(filename="generation_config.json", repo_id=source_model_id, repo_type='model'):
159
+ model.generation_config = GenerationConfig.from_pretrained(
160
+ source_model_id, trust_remote_code=True,
161
+ )
162
+ set_seed(42)
163
+ model = model.cpu()
164
+ with torch.no_grad():
165
+ for name, p in sorted(model.named_parameters()):
166
+ torch.nn.init.normal_(p, 0, 0.1)
167
+ print(name, p.shape)
168
+ model.save_pretrained(save_folder)
169
+ replace_file(f'{save_folder}/configuration_kimi_k25.py',
170
+ "from configuration_deepseek import DeepseekV3Config",
171
+ "from transformers import DeepseekV3Config")
172
+ replace_file(f'{save_folder}/modeling_kimi_k25.py',
173
+ "use_deterministic_attn=self.use_deterministic_attn",
174
+ "")
175
+ ```
176
+
177
+ ### Printing the model:
178
+
179
+ ```text
180
+ KimiK25ForConditionalGeneration(
181
+ (vision_tower): MoonViT3dPretrainedModel(
182
+ (patch_embed): MoonVision3dPatchEmbed(
183
+ (proj): Conv2d(3, 64, kernel_size=(14, 14), stride=(14, 14))
184
+ (pos_emb): Learnable2DInterpPosEmbDivided_fixed()
185
+ )
186
+ (encoder): MoonViT3dEncoder(
187
+ (rope_2d): Rope2DPosEmbRepeated(dim=32, max_height=512, max_width=512, theta_base=10000)
188
+ (blocks): ModuleList(
189
+ (0-1): 2 x MoonViTEncoderLayer(
190
+ (norm0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
191
+ (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
192
+ (mlp): MLP2(
193
+ (fc0): Linear(in_features=64, out_features=128, bias=True)
194
+ (fc1): Linear(in_features=128, out_features=64, bias=True)
195
+ (activation): PytorchGELUTanh()
196
+ )
197
+ (wqkv): Linear(in_features=64, out_features=192, bias=True)
198
+ (wo): Linear(in_features=64, out_features=64, bias=True)
199
+ )
200
+ )
201
+ (final_layernorm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
202
+ )
203
+ )
204
+ (mm_projector): PatchMergerMLP(
205
+ (pre_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
206
+ (proj): Sequential(
207
+ (0): Linear(in_features=256, out_features=256, bias=True)
208
+ (1): GELU(approximate='none')
209
+ (2): Linear(in_features=256, out_features=8, bias=True)
210
+ )
211
+ )
212
+ (language_model): DeepseekV3ForCausalLM(
213
+ (model): DeepseekV3Model(
214
+ (embed_tokens): Embedding(163840, 8, padding_idx=163839)
215
+ (layers): ModuleList(
216
+ (0): DeepseekV3DecoderLayer(
217
+ (self_attn): DeepseekV3Attention(
218
+ (q_a_proj): Linear(in_features=8, out_features=32, bias=False)
219
+ (q_a_layernorm): DeepseekV3RMSNorm()
220
+ (q_b_proj): Linear(in_features=32, out_features=256, bias=False)
221
+ (kv_a_proj_with_mqa): Linear(in_features=8, out_features=576, bias=False)
222
+ (kv_a_layernorm): DeepseekV3RMSNorm()
223
+ (kv_b_proj): Linear(in_features=384, out_features=128, bias=False)
224
+ (o_proj): Linear(in_features=64, out_features=8, bias=False)
225
+ (rotary_emb): DeepseekV3YarnRotaryEmbedding()
226
+ )
227
+ (mlp): DeepseekV3MLP(
228
+ (gate_proj): Linear(in_features=8, out_features=64, bias=False)
229
+ (up_proj): Linear(in_features=8, out_features=64, bias=False)
230
+ (down_proj): Linear(in_features=64, out_features=8, bias=False)
231
+ (act_fn): SiLU()
232
+ )
233
+ (input_layernorm): DeepseekV3RMSNorm()
234
+ (post_attention_layernorm): DeepseekV3RMSNorm()
235
+ )
236
+ (1): DeepseekV3DecoderLayer(
237
+ (self_attn): DeepseekV3Attention(
238
+ (q_a_proj): Linear(in_features=8, out_features=32, bias=False)
239
+ (q_a_layernorm): DeepseekV3RMSNorm()
240
+ (q_b_proj): Linear(in_features=32, out_features=256, bias=False)
241
+ (kv_a_proj_with_mqa): Linear(in_features=8, out_features=576, bias=False)
242
+ (kv_a_layernorm): DeepseekV3RMSNorm()
243
+ (kv_b_proj): Linear(in_features=384, out_features=128, bias=False)
244
+ (o_proj): Linear(in_features=64, out_features=8, bias=False)
245
+ (rotary_emb): DeepseekV3YarnRotaryEmbedding()
246
+ )
247
+ (mlp): DeepseekV3MoE(
248
+ (experts): ModuleList(
249
+ (0-31): 32 x DeepseekV3MLP(
250
+ (gate_proj): Linear(in_features=8, out_features=64, bias=False)
251
+ (up_proj): Linear(in_features=8, out_features=64, bias=False)
252
+ (down_proj): Linear(in_features=64, out_features=8, bias=False)
253
+ (act_fn): SiLU()
254
+ )
255
+ )
256
+ (gate): MoEGate()
257
+ (shared_experts): DeepseekV3MLP(
258
+ (gate_proj): Linear(in_features=8, out_features=64, bias=False)
259
+ (up_proj): Linear(in_features=8, out_features=64, bias=False)
260
+ (down_proj): Linear(in_features=64, out_features=8, bias=False)
261
+ (act_fn): SiLU()
262
+ )
263
+ )
264
+ (input_layernorm): DeepseekV3RMSNorm()
265
+ (post_attention_layernorm): DeepseekV3RMSNorm()
266
+ )
267
+ )
268
+ (norm): DeepseekV3RMSNorm()
269
+ )
270
+ (lm_head): Linear(in_features=8, out_features=163840, bias=False)
271
+ )
272
+ )
273
+ ```
chat_template.jinja ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- macro render_content(msg) -%}
2
+ {%- set c = msg.get('content') -%}
3
+ {%- if c is string -%}
4
+ {{ c }}
5
+ {%- elif c is not none -%}
6
+ {% for content in c -%}
7
+ {% if content['type'] == 'image' or content['type'] == 'image_url' -%}
8
+ <|media_start|>image<|media_content|><|media_pad|><|media_end|>
9
+ {% elif content['type'] == 'video' or content['type']== 'video_url'-%}
10
+ <|kimi_k25_video_placeholder|>
11
+ {% else -%}
12
+ {{ content['text'] }}
13
+ {%- endif -%}
14
+ {%- endfor -%}
15
+ {%- endif -%}
16
+ {%- endmacro -%}
17
+
18
+ {% macro set_roles(message) -%}
19
+ {%- set role_name = message.get('name') or message['role'] -%}
20
+ {%- if message['role'] == 'user' -%}
21
+ <|im_user|>{{role_name}}<|im_middle|>
22
+ {%- elif message['role'] == 'assistant' -%}
23
+ <|im_assistant|>{{role_name}}<|im_middle|>
24
+ {%- else -%}
25
+ <|im_system|>{{role_name}}<|im_middle|>
26
+ {%- endif -%}
27
+ {%- endmacro -%}
28
+
29
+
30
+ {%- macro render_toolcalls(message) -%}
31
+ <|tool_calls_section_begin|>
32
+ {%- for tool_call in message['tool_calls'] -%}
33
+ {%- set formatted_id = tool_call['id'] -%}
34
+ <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>
35
+ {%- endfor -%}
36
+ <|tool_calls_section_end|>
37
+ {%- endmacro -%}
38
+
39
+
40
+ {# Find the last non-tool-call assistant message #}
41
+ {%- set ns = namespace(last_non_tool_call_assistant_msg=-1) -%}
42
+ {%- for idx in range(messages|length-1, -1, -1) -%}
43
+ {%- if messages[idx]['role'] == 'assistant' and not messages[idx].get('tool_calls') -%}
44
+ {%- set ns.last_non_tool_call_assistant_msg = idx -%}
45
+ {%- break -%}
46
+ {%- endif -%}
47
+ {%- endfor -%}
48
+
49
+ {# Split all messages into history & suffix; reasoning_content in the suffix should be preserved. #}
50
+ {%- set hist_msgs = messages[:ns.last_non_tool_call_assistant_msg+1] -%}
51
+ {%- set suffix_msgs = messages[ns.last_non_tool_call_assistant_msg+1:] -%}
52
+
53
+ {%- if tools -%}
54
+ {%- if tools_ts_str -%}
55
+ <|im_system|>tool_declare<|im_middle|>{{ tools_ts_str }}<|im_end|>
56
+ {%- else -%}
57
+ <|im_system|>tool_declare<|im_middle|>{{ tools | tojson(separators=(',', ':')) }}<|im_end|>
58
+ {%- endif -%}
59
+ {%- endif -%}
60
+
61
+ {%- if messages|length == 0 or messages[0]['role'] != 'system' -%}
62
+ <|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|>
63
+ {%- endif -%}
64
+
65
+ {%- for message in hist_msgs -%}
66
+ {{set_roles(message)}}
67
+ {%- if message['role'] == 'assistant' -%}
68
+ <think></think>{{render_content(message)}}
69
+ {%- if message.get('tool_calls') -%}
70
+ {{render_toolcalls(message)}}
71
+ {%- endif -%}
72
+ {%- elif message['role'] == 'tool' -%}
73
+ {%- set tool_call_id = message.tool_call_id -%}
74
+ ## Return of {{ tool_call_id }}
75
+ {{render_content(message)}}
76
+ {%- elif message['content'] is not none -%}
77
+ {{render_content(message)}}
78
+ {%- endif -%}
79
+ <|im_end|>
80
+ {%- endfor -%}
81
+
82
+ {%- for message in suffix_msgs -%}
83
+ {{set_roles(message)}}
84
+ {%- if message['role'] == 'assistant' -%}
85
+ {%- if thinking is defined and thinking is false -%}
86
+ <think></think>{{render_content(message)}}
87
+ {%- else -%}
88
+ {%- set rc = message.get('reasoning_content', '') -%}
89
+ <think>{{rc}}</think>{{render_content(message)}}
90
+ {%- endif -%}
91
+ {%- if message.get('tool_calls') -%}
92
+ {{render_toolcalls(message)}}
93
+ {%- endif -%}
94
+ {%- elif message['role'] == 'tool' -%}
95
+ {%- set tool_call_id = message.tool_call_id -%}
96
+ ## Return of {{ tool_call_id }}
97
+ {{render_content(message)}}
98
+ {%- elif message['content'] is not none -%}
99
+ {{render_content(message)}}
100
+ {%- endif -%}
101
+ <|im_end|>
102
+ {%- endfor -%}
103
+
104
+
105
+ {%- if add_generation_prompt -%}
106
+ <|im_assistant|>assistant<|im_middle|>
107
+ {%- if thinking is defined and thinking is false -%}
108
+ <think></think>
109
+ {%- else -%}
110
+ <think>
111
+ {%- endif -%}
112
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "KimiK25ForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_kimi_k25.KimiK25Config",
7
+ "AutoModel": "modeling_kimi_k25.KimiK25ForConditionalGeneration",
8
+ "AutoModelForCausalLM": "modeling_kimi_k25.KimiK25ForConditionalGeneration"
9
+ },
10
+ "bos_token_id": 163584,
11
+ "dtype": "bfloat16",
12
+ "eos_token_id": 163585,
13
+ "ignore_index": -100,
14
+ "media_placeholder_token_id": 163605,
15
+ "model_type": "kimi_k25",
16
+ "pad_token_id": 163839,
17
+ "text_config": {
18
+ "_name_or_path": "",
19
+ "add_cross_attention": false,
20
+ "architectures": [
21
+ "DeepseekV3ForCausalLM"
22
+ ],
23
+ "attention_bias": false,
24
+ "attention_dropout": 0.0,
25
+ "auto_map": {
26
+ "AutoConfig": "configuration_deepseek.DeepseekV3Config",
27
+ "AutoModel": "modeling_deepseek.DeepseekV3Model",
28
+ "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
29
+ },
30
+ "aux_loss_alpha": 0.001,
31
+ "bad_words_ids": null,
32
+ "begin_suppress_tokens": null,
33
+ "bos_token_id": 163584,
34
+ "chunk_size_feed_forward": 0,
35
+ "cross_attention_hidden_size": null,
36
+ "decoder_start_token_id": null,
37
+ "diversity_penalty": 0.0,
38
+ "do_sample": false,
39
+ "dtype": "bfloat16",
40
+ "early_stopping": false,
41
+ "encoder_no_repeat_ngram_size": 0,
42
+ "eos_token_id": 163585,
43
+ "ep_size": 1,
44
+ "exponential_decay_length_penalty": null,
45
+ "finetuning_task": null,
46
+ "first_k_dense_replace": 1,
47
+ "forced_bos_token_id": null,
48
+ "forced_eos_token_id": null,
49
+ "head_dim": 192,
50
+ "hidden_act": "silu",
51
+ "hidden_size": 8,
52
+ "id2label": {
53
+ "0": "LABEL_0",
54
+ "1": "LABEL_1"
55
+ },
56
+ "initializer_range": 0.02,
57
+ "intermediate_size": 64,
58
+ "is_decoder": false,
59
+ "is_encoder_decoder": false,
60
+ "kv_lora_rank": 384,
61
+ "label2id": {
62
+ "LABEL_0": 0,
63
+ "LABEL_1": 1
64
+ },
65
+ "length_penalty": 1.0,
66
+ "max_length": 20,
67
+ "max_position_embeddings": 262144,
68
+ "min_length": 0,
69
+ "model_type": "deepseek_v3",
70
+ "moe_intermediate_size": 64,
71
+ "moe_layer_freq": 1,
72
+ "n_group": 1,
73
+ "n_routed_experts": 32,
74
+ "n_shared_experts": 1,
75
+ "no_repeat_ngram_size": 0,
76
+ "norm_topk_prob": true,
77
+ "num_attention_heads": 1,
78
+ "num_beam_groups": 1,
79
+ "num_beams": 1,
80
+ "num_experts_per_tok": 8,
81
+ "num_hidden_layers": 2,
82
+ "num_key_value_heads": 1,
83
+ "num_nextn_predict_layers": 0,
84
+ "num_return_sequences": 1,
85
+ "output_attentions": false,
86
+ "output_hidden_states": false,
87
+ "output_scores": false,
88
+ "pad_token_id": 163839,
89
+ "prefix": null,
90
+ "pretraining_tp": 1,
91
+ "problem_type": null,
92
+ "pruned_heads": {},
93
+ "q_lora_rank": 32,
94
+ "qk_head_dim": 256,
95
+ "qk_nope_head_dim": 64,
96
+ "qk_rope_head_dim": 192,
97
+ "remove_invalid_values": false,
98
+ "repetition_penalty": 1.0,
99
+ "return_dict": true,
100
+ "return_dict_in_generate": false,
101
+ "rms_norm_eps": 1e-05,
102
+ "rope_interleave": true,
103
+ "rope_scaling": {
104
+ "beta_fast": 32.0,
105
+ "beta_slow": 1.0,
106
+ "factor": 64.0,
107
+ "mscale": 1.0,
108
+ "mscale_all_dim": 1.0,
109
+ "original_max_position_embeddings": 4096,
110
+ "rope_type": "yarn",
111
+ "type": "yarn"
112
+ },
113
+ "rope_theta": 50000.0,
114
+ "routed_scaling_factor": 2.827,
115
+ "scoring_func": "sigmoid",
116
+ "sep_token_id": null,
117
+ "seq_aux": true,
118
+ "suppress_tokens": null,
119
+ "task_specific_params": null,
120
+ "temperature": 1.0,
121
+ "tf_legacy_loss": false,
122
+ "tie_encoder_decoder": false,
123
+ "tie_word_embeddings": false,
124
+ "tokenizer_class": null,
125
+ "top_k": 50,
126
+ "top_p": 1.0,
127
+ "topk_group": 1,
128
+ "topk_method": "noaux_tc",
129
+ "torchscript": false,
130
+ "typical_p": 1.0,
131
+ "use_bfloat16": false,
132
+ "use_cache": true,
133
+ "v_head_dim": 64,
134
+ "vocab_size": 163840
135
+ },
136
+ "tie_word_embeddings": false,
137
+ "transformers_version": "4.56.2",
138
+ "use_unified_vision_chunk": true,
139
+ "video_placeholder": "<|kimi_k25_video_placeholder|>",
140
+ "vision_config": {
141
+ "init_pos_emb_height": 64,
142
+ "init_pos_emb_time": 4,
143
+ "init_pos_emb_width": 64,
144
+ "merge_kernel_size": [
145
+ 2,
146
+ 2
147
+ ],
148
+ "merge_type": "sd2_tpool",
149
+ "mm_hidden_size": 64,
150
+ "mm_projector_type": "patchmerger",
151
+ "model_type": "",
152
+ "patch_size": 14,
153
+ "pos_emb_type": "divided_fixed",
154
+ "projector_hidden_act": "gelu",
155
+ "projector_ln_eps": 1e-05,
156
+ "text_hidden_size": 8,
157
+ "video_attn_type": "spatial_temporal",
158
+ "vt_hidden_size": 64,
159
+ "vt_intermediate_size": 128,
160
+ "vt_num_attention_heads": 2,
161
+ "vt_num_hidden_layers": 2
162
+ }
163
+ }
configuration_deepseek.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/configuration_deepseek.py
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+ logger = logging.get_logger(__name__)
7
+
8
+ DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
9
+
10
+
11
class DeepseekV3Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a
    DeepSeek model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a configuration similar to that of DeepSeek-V3.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 129280):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DeepseekV3Model`].
        hidden_size (`int`, *optional*, defaults to 7168):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 18432):
            Dimension of the dense MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MoE expert representations.
        num_hidden_layers (`int`, *optional*, defaults to 61):
            Number of hidden layers in the Transformer decoder.
        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
            Number of next-n (multi-token) prediction layers.
        num_attention_heads (`int`, *optional*, defaults to 128):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 128):
            Number of key/value heads used for Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA); otherwise GQA is used. For more
            details see [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default
            to `num_attention_heads`.
        n_shared_experts (`int`, *optional*, defaults to 1):
            Number of shared experts; `None` means a dense model.
        n_routed_experts (`int`, *optional*, defaults to 256):
            Number of routed experts; `None` means a dense model.
        ep_size (`int`, *optional*, defaults to 1):
            Expert-parallel world size.
        routed_scaling_factor (`float`, *optional*, defaults to 2.5):
            Scaling factor for the routed experts.
        kv_lora_rank (`int`, *optional*, defaults to 512):
            Rank of the key/value low-rank compression (MLA).
        q_lora_rank (`int`, *optional*, defaults to 1536):
            Rank of the query low-rank compression (MLA).
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            Per-head query/key dimension that carries rotary position information.
        v_head_dim (`int`, *optional*, defaults to 128):
            Per-head value dimension.
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            Per-head query/key dimension without rotary position information.
        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
            Top-k method used in the routed gate.
        n_group (`int`, *optional*, defaults to 8):
            Number of groups for the routed experts.
        topk_group (`int`, *optional*, defaults to 4):
            Number of selected groups for each token (for each token, the selected experts are only within
            `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of experts selected per token; `None` means a dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 3):
            Number of dense layers in the shallow layers
            (embed -> dense -> dense -> ... -> dense -> moe -> moe ... -> lm_head).
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
            Method of computing expert weights.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
        seq_aux (`bool`, *optional*, defaults to `True`):
            Whether to compute the auxiliary loss for each individual sample.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. This value is necessary to ensure
            exact reproducibility of the pretraining results. Please refer to
            [this issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import DeepseekV3Model, DeepseekV3Config

    >>> # Initializing a DeepSeek-V3 style configuration
    >>> configuration = DeepseekV3Config()

    >>> # Initializing a model from the configuration
    >>> model = DeepseekV3Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "deepseek_v3"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=129280,
        hidden_size=7168,
        intermediate_size=18432,
        moe_intermediate_size=2048,
        num_hidden_layers=61,
        num_nextn_predict_layers=1,
        num_attention_heads=128,
        num_key_value_heads=128,
        n_shared_experts=1,
        n_routed_experts=256,
        ep_size=1,
        routed_scaling_factor=2.5,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method='noaux_tc',
        n_group=8,
        topk_group=4,
        num_experts_per_tok=8,
        moe_layer_freq=1,
        first_k_dense_replace=3,
        norm_topk_prob=True,
        scoring_func='sigmoid',
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=0,
        eos_token_id=1,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
configuration_kimi_k25.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+ try:
4
+ from transformers import DeepseekV3Config
5
+ except ImportError:
6
+ from .configuration_deepseek import DeepseekV3Config
7
+
8
+
9
class KimiK25VisionConfig(PretrainedConfig):
    """Configuration for the Kimi-K2.5 vision tower and multimodal projector.

    Args:
        patch_size (int): Patch size of the vision tower.
        init_pos_emb_height (int): Initial position embedding height.
        init_pos_emb_width (int): Initial position embedding width.
        init_pos_emb_time (int): Initial position embedding time dimension.
        pos_emb_type (str): Type of position embedding.
        vt_num_attention_heads (int): Number of attention heads in the vision tower.
        vt_num_hidden_layers (int): Number of hidden layers in the vision tower.
        vt_hidden_size (int): Hidden size of the vision tower.
        vt_intermediate_size (int): Intermediate size of the vision tower FFN.
        merge_kernel_size (tuple): Kernel size for patch merging.
        video_attn_type (str): Type of video attention.
        merge_type (str): Type of merge operation.
        _attn_implementation (str): Attention implementation type.
        mm_projector_type (str): Type of multimodal projector.
        mm_hidden_size (int | None): Hidden size fed into the projector;
            falls back to ``vt_hidden_size`` when not given.
        projector_hidden_act (str): Activation function for the projector.
        projector_ln_eps (float): Layer norm epsilon for the projector.
        text_hidden_size (int): Hidden size of the text model the projector maps into.

    Note:
        ``ignore_index``, ``media_placeholder_token_id``, ``pad_token_id``,
        ``use_unified_vision_chunk`` and ``video_placeholder`` are accepted for
        signature compatibility with the top-level ``KimiK25Config`` but are not
        stored on this object (same as the original behavior).
    """

    def __init__(
            self,
            patch_size: int = 14,
            init_pos_emb_height: int = 64,
            init_pos_emb_width: int = 64,
            init_pos_emb_time: int = 4,
            pos_emb_type: str = 'divided_fixed',
            vt_num_attention_heads: int = 16,
            vt_num_hidden_layers: int = 27,
            vt_hidden_size: int = 1152,
            vt_intermediate_size: int = 4304,
            merge_kernel_size: tuple = (2, 2),
            video_attn_type: str = 'spatial_temporal',
            merge_type: str = 'sd2_tpool',
            _attn_implementation: str = 'flash_attention_2',
            # MM Projector parameters
            mm_projector_type: str = 'patchmerger',
            mm_hidden_size: int | None = None,
            projector_hidden_act: str = "gelu",
            projector_ln_eps: float = 1e-5,
            # Other parameters
            ignore_index: int = -100,
            media_placeholder_token_id: int = 163605,
            pad_token_id: int = 0,
            use_unified_vision_chunk: bool = True,
            video_placeholder="<|kimi_k25_video_placeholder|>",
            text_hidden_size=7168,
            **vision_config_kwargs):

        # Fix: the original never called super().__init__(), leaving the
        # PretrainedConfig base uninitialized and silently discarding
        # ``vision_config_kwargs``. Initialize the base first so that the
        # explicit assignments below take precedence over anything the base
        # class derives from the remaining kwargs (e.g. attn_implementation).
        super().__init__(**vision_config_kwargs)

        self.patch_size = patch_size
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        self.init_pos_emb_time = init_pos_emb_time
        self.pos_emb_type = pos_emb_type
        self.vt_num_attention_heads = vt_num_attention_heads
        self.vt_num_hidden_layers = vt_num_hidden_layers
        self.vt_hidden_size = vt_hidden_size
        self.vt_intermediate_size = vt_intermediate_size
        self.merge_kernel_size = merge_kernel_size
        self.video_attn_type = video_attn_type
        self.merge_type = merge_type
        self._attn_implementation = _attn_implementation

        # MM Projector config
        self.mm_projector_type = mm_projector_type
        self.mm_hidden_size = mm_hidden_size if mm_hidden_size is not None else vt_hidden_size
        self.projector_hidden_act = projector_hidden_act
        self.projector_ln_eps = projector_ln_eps
        self.text_hidden_size = text_hidden_size
60
+
61
+
62
class KimiK25Config(PretrainedConfig):
    """Kimi-K2.5 model configuration.

    Args:
        text_config (dict | DeepseekV3Config): Configuration for the text model.
            Instantiated with DeepSeek-V3 defaults when omitted.
        vision_config (dict | KimiK25VisionConfig): Configuration for the vision
            tower and multimodal projector. Instantiated with defaults when omitted.

        Other Parameters:
            ignore_index (int): The ignore index for the loss function.
            media_placeholder_token_id (int): The token ID to use for media placeholders.
            pad_token_id (int): The token ID to use for padding.
            use_unified_vision_chunk (bool): Whether to use the unified vision chunk scheme.
            video_placeholder (str): Temporal placeholder string for videos in prompts.

    See ``KimiK25VisionConfig`` for the vision tower / projector parameters
    (patch_size, vt_* sizes, merge settings, mm_projector_* options).
    """

    model_type = "kimi_k25"

    def __init__(
        self,
        text_config: dict | DeepseekV3Config = None,
        vision_config: dict | KimiK25VisionConfig = None,
        # Other parameters
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        use_unified_vision_chunk: bool = True,
        video_placeholder="<|kimi_k25_video_placeholder|>",
        **kwargs,
    ):
        # Fix: fall back to default sub-configs instead of leaving them None
        # (the original stored None, which breaks any attribute access on
        # self.text_config / self.vision_config downstream).
        if text_config is None:
            text_config = DeepseekV3Config()
        elif isinstance(text_config, dict):
            text_config = DeepseekV3Config(**text_config)
        if vision_config is None:
            vision_config = KimiK25VisionConfig()
        elif isinstance(vision_config, dict):
            vision_config = KimiK25VisionConfig(**vision_config)
        self.text_config = text_config
        self.vision_config = vision_config
        # Other config
        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id
        self.use_unified_vision_chunk = use_unified_vision_chunk
        self.video_placeholder = video_placeholder
        # Surface the text model's quantization config at the top level so
        # loading machinery that only inspects the root config can find it.
        if getattr(self.text_config, "quantization_config", None) is not None:
            self.quantization_config = self.text_config.quantization_config

        super().__init__(pad_token_id=pad_token_id, **kwargs)
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "eos_token_id": 163586,
3
+ "max_length": 262144,
4
+ "transformers_version": "4.56.2",
5
+ "trust_remote_code": true
6
+ }
kimi_k25_processor.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.feature_extraction_utils import BatchFeature
2
+ from transformers.processing_utils import ProcessorMixin
3
+ from transformers.utils import logging
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+
8
class KimiK25Processor(ProcessorMixin):
    r"""
    Constructs a KimiK25 processor which wraps a KimiK25 image processor and a tokenizer into a single processor.

    [`KimiK25Processor`] offers all the functionalities of [`KimiK25ImageProcessor`] and [`TikTokenTokenizer`]. See the
    [`~KimiK25Processor.__call__`] and [`~KimiK25Processor.decode`] for more information.

    Args:
        image_processor ([`KimiK25ImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`TikTokenTokenizer`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        **kwargs,
    ):
        super().__init__(image_processor,
                         tokenizer,
                         chat_template=chat_template)
        # Alias: the handling below covers both images and video chunks,
        # so the image processor is referred to as a generic media processor.
        self.media_processor = image_processor
        # A special temporal placeholder to be replaced by actual video placeholders
        self.video_placeholder = "<|kimi_k25_video_placeholder|>"

    def update_raw_text(self, text: str, video_prompts: list[str]) -> str:
        """Replace each video placeholder in ``text`` with its chunk prompts.

        Requires exactly one entry in ``video_prompts`` per placeholder
        occurrence; returns ``text`` unchanged when no placeholder is found.
        """
        # replace video prompt in text with video chunk prompts
        video_count = text.count(self.video_placeholder)
        if video_count == 0:
            return text
        assert video_count == len(video_prompts)
        text_parts = text.split(self.video_placeholder)
        assert len(text_parts) == len(video_prompts) + 1
        # Interleave text segments with their corresponding video prompts.
        text = "".join([
            text_parts[i] + video_prompts[i] for i in range(len(video_prompts))
        ])
        text += text_parts[-1]
        return text

    def preprocess_medias(
            self, medias: list[dict]) -> tuple[list[dict], list[str]]:
        """Expand videos into per-chunk media items.

        Images pass through untouched; each video is split into chunks by the
        media processor, and the concatenation of its chunk prompts is
        collected so the placeholder in the raw text can be substituted later.

        Returns:
            A tuple ``(updated_medias, video_prompts)`` where ``video_prompts``
            holds one joined prompt string per input video.
        """
        updated_medias = []
        video_prompts = []
        for media in medias:
            if media['type'] == 'image':
                updated_medias.append(media)
            elif media['type'] == 'video':
                video_chunks = self.media_processor.split_video_chunks(
                    media['video'])
                updated_medias.extend(video_chunks)
                video_prompts.append("".join(
                    [vc['prompt'] for vc in video_chunks]))
            else:
                raise ValueError(f"unsupported media type: {media['type']}")
        return updated_medias, video_prompts

    def __call__(self,
                 messages: list[dict] = None,
                 medias: list[dict] = None,
                 text: str = None,
                 return_tensors: str = "pt",
                 **kwargs) -> BatchFeature:
        """
        Process multimodal inputs for Kimi-K2.5 model.

        This processor accepts ordered messages and extracts both media and text in a single pass.
        text will be automatically updated if video input detected in messages

        Args:
            messages: List of message dicts with 'role' and 'content' fields.
                If provided, medias and text will be extracted automatically.
            medias: Pre-extracted list of media dicts. If None, extracted from messages.
            text: Pre-formatted text string. If None, generated via apply_chat_template.
            return_tensors: Format of returned tensors ('pt', 'np', 'tf'). Default: 'pt'.
            **kwargs: Additional arguments passed to tokenizer.apply_chat_template.

        Returns:
            BatchFeature with fields: input_ids, attention_mask, pixel_values, grid_thws.
        """
        if messages is None and (medias is None or text is None):
            raise ValueError(
                "Provide either 'messages' or both 'medias' and 'text'")

        # Fast path: both medias and text supplied explicitly — skip message
        # parsing and chat templating entirely.
        if medias is not None and text is not None:
            updated_medias, video_prompts = self.preprocess_medias(medias)
            preprocessed = self.media_processor.preprocess(
                updated_medias, return_tensors=return_tensors)
            text = self.update_raw_text(text, video_prompts)
            text_inputs = self.tokenizer(text, return_tensors=return_tensors)
            return BatchFeature(data={**text_inputs, **preprocessed.data})

        # Otherwise derive the missing pieces from the messages.
        if medias is None:
            medias = self._extract_medias_from_messages(messages)
        updated_medias, video_prompts = self.preprocess_medias(medias)
        preprocessed = self.media_processor.preprocess(
            updated_medias, return_tensors=return_tensors)

        # Generate text if not provided
        if text is None:
            text = self.tokenizer.apply_chat_template(messages, **kwargs)

        text = self.update_raw_text(text, video_prompts)

        text_inputs = self.tokenizer(text, return_tensors=return_tensors)
        return BatchFeature(data={**text_inputs, **preprocessed.data})

    @staticmethod
    def _extract_medias_from_messages(messages: list[dict]) -> list[dict]:
        """
        Extract media items from messages in a single pass.

        This is an optimized version that processes messages only once.
        Kept as internal method since external callers should use __call__.
        """
        medias = []
        for msg in messages:
            # Only user messages with non-empty content can carry media.
            if msg['role'] != 'user' or not msg.get('content'):
                continue

            for content_part in msg['content']:
                if not isinstance(content_part, dict):
                    continue

                content_type = content_part.get('type')
                if content_type in ['video_url', 'video']:
                    medias.append({
                        'type': 'video',
                        'video': content_part['video_url']['url'],
                        'first_frame_timestamp': 0.0
                    })
                elif content_type in ['image_url', 'image']:
                    # NOTE(review): for type == 'image' this still reads the
                    # 'image_url' key — confirm that image parts always carry
                    # an 'image_url' payload.
                    medias.append({
                        'type': 'image',
                        'image': content_part['image_url'],
                    })
        return medias

    def apply_chat_template(self, messages, **kwargs):
        """Delegate chat templating to the wrapped tokenizer."""
        return self.tokenizer.apply_chat_template(messages, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """Delegate batched decoding to the wrapped tokenizer."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate decoding to the wrapped tokenizer."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        # Combined text + vision inputs expected by the model forward pass.
        return ['input_ids', 'attention_mask', 'pixel_values', 'grid_thws']
kimi_k25_vision_processing.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image processor class for Kimi-K2.5.
2
+ """
3
+
4
+ import json
5
+ from typing import Any, Dict, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from transformers.image_processing_utils import (BaseImageProcessor,
11
+ BatchFeature)
12
+ from transformers.utils import TensorType
13
+
14
+ from .media_utils import (MediaInput, VideoChunkInput, _to_tensor,
15
+ ensure_media_type, get_video_meta, image_to_np,
16
+ navit_patchify, navit_resize_image,
17
+ navit_resize_video, normalize,
18
+ real_sample_fps_and_max_num_frames, timestamp_as_str)
19
+
20
+ try:
21
+ from mecord import VideoReader
22
+ except ImportError:
23
+ VideoReader = None
24
+
25
+
26
def resampling(video_bytes: bytes,
               sample_indices: list[int],
               key_indices=None,
               frame_time_info=None,
               num_threads: int = 4) -> list[Image.Image]:
    """Decode the frames at ``sample_indices`` from an in-memory video.

    Args:
        video_bytes: Raw encoded video data.
        sample_indices: Frame indices to extract, in the desired order.
        key_indices: Optional key-frame indices to speed up seeking.
        frame_time_info: Optional per-frame timing info to speed up decoding.
        num_threads: Decoder thread count.

    Returns:
        The extracted frames as PIL images (previously mis-annotated as
        ``-> str``).

    Raises:
        ImportError: If the optional ``mecord`` dependency is not installed.
    """
    # Fail with a clear message instead of "'NoneType' is not callable".
    if VideoReader is None:
        raise ImportError(
            "mecord is required for video decoding but is not installed")
    video = VideoReader(video_bytes,
                        num_threads=num_threads,
                        frame_time_info=frame_time_info,
                        key_indices=key_indices)
    # extract target frames
    frames = video[sample_indices]
    return [Image.fromarray(frame) for frame in frames]
39
+
40
+
41
class KimiK25VisionProcessor(BaseImageProcessor):
    """Vision preprocessor for Kimi-K2.5: resizes, pads, normalizes and
    patchifies images and video chunks into NaViT-style patch tensors.

    All behavior is driven by ``media_proc_cfg`` (patch size, merge kernel
    sizes, patch budgets, sampling fps, normalization stats, ...).
    """

    model_type = "kimi_k25"

    def __init__(
        self,
        media_proc_cfg: dict,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.media_proc_cfg = media_proc_cfg
        # Frames per video chunk == the temporal merge kernel size.
        self.num_frames_per_chunk = media_proc_cfg[
            'temporal_merge_kernel_size']

    def media_tokens_calculator(self, media: MediaInput):
        """Return the number of vision tokens this media item will produce."""
        media = ensure_media_type(media)
        ret = self.get_resize_config(media)
        return ret['num_tokens']

    @classmethod
    def make_chunk_prompt(cls, timestamp_text: str) -> str:
        """Build the text prompt that precedes one video chunk's tokens."""
        return f"{timestamp_text}<|media_begin|>video<|media_content|><|media_pad|><|media_end|>"

    def split_video_chunks(self,
                           video_url: str | bytes) -> list[VideoChunkInput]:
        """Sample frames from a video and group them into fixed-size chunks.

        Each chunk carries ``temporal_merge_kernel_size`` frames and a prompt
        containing the chunk's start timestamp. (Return annotation fixed:
        the elements are ``VideoChunkInput`` dicts, not lists of images.)
        """
        # video_url should be base64 str or bytes
        video_spec = get_video_meta(video_url)
        sample_fps = min(self.media_proc_cfg['sample_fps'], video_spec.fps)
        sampled_nframes = max(
            round(video_spec.num_frames * sample_fps / video_spec.fps), 1)
        # Evenly spaced frame indices over the whole video.
        frame_inds = np.linspace(0, video_spec.num_frames - 1,
                                 sampled_nframes).round().astype(int)
        frame_inds = frame_inds.tolist()
        sampled_frame_ids = []
        temporal_merge_kernel_size = self.media_proc_cfg[
            "temporal_merge_kernel_size"]
        num_chunks = 0
        chunk_timestamp = []
        # Walk the sampled indices chunk by chunk, recording each chunk's
        # start timestamp for its prompt.
        for i in range(0, len(frame_inds), temporal_merge_kernel_size):
            sampled_frame_ids.extend(frame_inds[i:i +
                                                temporal_merge_kernel_size])
            start_time = frame_inds[i] / float(video_spec.fps)
            timestamp_text = timestamp_as_str(
                start_time, self.media_proc_cfg["timestamp_mode"])
            chunk_timestamp.append(timestamp_text)
            num_chunks += 1

        # NOTE(review): `resampling` receives the raw (possibly base64 data
        # URI) source without the decoding applied in get_video_meta —
        # confirm VideoReader accepts it in that form.
        sampled_frames = resampling(video_url, sampled_frame_ids)
        chunks = []
        for chunk_id in range(num_chunks):
            chunk = sampled_frames[chunk_id *
                                   temporal_merge_kernel_size:(chunk_id + 1) *
                                   temporal_merge_kernel_size]
            chunks.append(
                VideoChunkInput(type="video_chunk",
                                video_chunk=chunk,
                                prompt=self.make_chunk_prompt(
                                    chunk_timestamp[chunk_id])))
        return chunks

    def get_resize_config(self, media_input: MediaInput) -> dict:
        """Compute the resize/pad geometry and token count for one media item.

        Delegates to ``navit_resize_image`` / ``navit_resize_video`` with the
        patch budgets from ``media_proc_cfg``.
        """
        if media_input['type'] == 'image':
            w, h = media_input['image'].size
            ret = navit_resize_image(
                w, h, self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                self.media_proc_cfg['in_patch_limit'],
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['fixed_output_tokens'])
            return ret
        elif media_input['type'] == 'video_chunk':
            # Geometry is derived from the first frame; all frames in a
            # chunk are resized identically.
            frame = media_input['video_chunk'][0]
            width, height = frame.size
            num_frames = len(media_input["video_chunk"])
            fps = 1.0

            sample_fps, max_num_frames_each_video = real_sample_fps_and_max_num_frames(
                media_input["type"],
                self.media_proc_cfg['sample_fps'],
                self.media_proc_cfg['max_num_frames_each_video'],
            )

            # Per-frame budget falls back to the global image budget.
            in_patch_limit_each_frame = self.media_proc_cfg[
                'in_patch_limit_each_frame']
            if in_patch_limit_each_frame is None:
                in_patch_limit_each_frame = self.media_proc_cfg[
                    'in_patch_limit']

            ret = navit_resize_video(
                width,
                height,
                num_frames,
                fps,
                sample_fps,
                self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                in_patch_limit_each_frame,
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['in_patch_limit_video'],
                max_num_frames_each_video,
                self.media_proc_cfg['fixed_output_tokens'],
            )
            return ret
        else:
            raise ValueError("Unsupported type: {}".format(
                media_input['type']))

    def resize_image(self, image: Image.Image, new_width: int, new_height: int,
                     pad_width: int, pad_height: int) -> np.ndarray:
        """Resize to (new_width, new_height) and zero-pad on the right/bottom."""
        image_np = image_to_np(image, (new_width, new_height), "resize")
        image_np = np.pad(
            image_np,
            ((0, pad_height), (0, pad_width), (0, 0)),
            mode="constant",
            constant_values=0,
        )
        return image_np

    def preprocess(
        self,
        medias: list[MediaInput],
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """
        Preprocess atomic vision inputs (images/video_chunk) into model-ready tensors.

        Args:
            medias: List of MediaInput.
            return_tensors: Desired output format ('pt', 'np', 'tf', or None).

        Returns:
            BatchFeature containing 'pixel_values' and 'grid_thws' tensors.
        """
        if not isinstance(medias, list):
            medias = [medias]
        if medias:
            pixel_values = []
            for item in medias:
                item = ensure_media_type(item)
                resize_config = self.get_resize_config(item)
                new_width, new_height, pad_width, pad_height = resize_config[
                    'new_width'], resize_config['new_height'], resize_config[
                        'pad_width'], resize_config['pad_height']
                if item['type'] == 'image':
                    # Images become single-frame stacks: (1, H, W, C).
                    image = item['image']
                    image_np = self.resize_image(image, new_width, new_height,
                                                 pad_width, pad_height)
                    pixel_values.append(np.expand_dims(image_np, axis=0))
                elif item['type'] == 'video_chunk':
                    # Video chunks become (T, H, W, C) frame stacks.
                    pixels = []
                    for frame in item['video_chunk']:
                        frame_np = self.resize_image(frame, new_width,
                                                     new_height, pad_width,
                                                     pad_height)
                        pixels.append(frame_np)
                    pixel_values.append(np.stack(pixels, axis=0))
                else:
                    raise ValueError("Unsupported type: {}".format(
                        item['type']))
            # Normalize each stack then split into NaViT patches.
            normalized_pixel_values = []
            image_std_inv = 1.0 / np.array(self.media_proc_cfg['image_std'])
            image_mean = np.array(self.media_proc_cfg['image_mean'])
            for pixels in pixel_values:
                pixels = normalize(pixels, image_mean, image_std_inv)
                pixels_and_thw = navit_patchify(
                    pixels,
                    self.media_proc_cfg['patch_size'],
                )
                normalized_pixel_values.append(pixels_and_thw)

            # Concatenate all patches across media items; grid_thws keeps one
            # (t, h, w) row per item so the model can split them back apart.
            pixel_values = torch.cat([
                _to_tensor(pixel_value['pixel_values'])
                for pixel_value in normalized_pixel_values
            ])
            grid_thws = torch.cat([
                _to_tensor(pixel_value['grid_thw'],
                           dtype=torch.int64).unsqueeze(0)
                for pixel_value in normalized_pixel_values
            ])

            data = {
                'pixel_values': pixel_values,
                'grid_thws': grid_thws,
            }

        else:
            data = {}

        return BatchFeature(data=data, tensor_type=return_tensors)

    def __repr__(self):
        return f"KimiK25VisionProcessor(media_proc_cfg={self.media_proc_cfg})"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the processor config, dropping non-serializable fields."""
        output = super().to_dict()
        output["media_proc_cfg"] = self.media_proc_cfg
        if "media_processor" in output:
            del output["media_processor"]
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        """Rebuild a processor from a ``to_dict`` payload."""
        config = config_dict.copy()
        media_proc_cfg = config.pop("media_proc_cfg", {})
        return cls(media_proc_cfg=media_proc_cfg, **config, **kwargs)

    def to_json_string(self):
        """JSON-serialize the config, converting array-likes via tolist()."""
        dictionary = self.to_dict()
        for key, value in dictionary.items():
            if hasattr(value, 'tolist'):
                dictionary[key] = value.tolist()
        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
media_utils.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import math
4
+ import os
5
+ from datetime import datetime, timezone
6
+ from typing import List, Literal, Optional, TypedDict
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+ from pydantic import BaseModel, Field
11
+
12
+ try:
13
+ from mecord import VideoReader
14
+ except ImportError:
15
+ VideoReader = None
16
+
17
+
18
class VideoSpec(BaseModel):
    """Metadata describing a decoded video stream."""

    # Bug fix: was `media_type: str = Literal['video']`, which made the
    # *default value* the typing object `Literal['video']` rather than the
    # string 'video'. Declare it as a Literal-typed field defaulting to the
    # actual string.
    media_type: Literal['video'] = 'video'
    height: int = Field(..., gt=0, description="video frame height")
    width: int = Field(..., gt=0, description="video frame width")
    num_frames: int = Field(..., gt=0, description="num frames")
    fps: float = Field(..., gt=0, description="average fps")

    # optional, help to accelerate video reading
    # (annotations made Optional to match their None defaults)
    key_indices: list[int] | None = Field(None, description="key indices")
    frame_time_info: dict | None = Field(None, description="frame time info")
28
+
29
+
30
class ImageInput(TypedDict):
    # Tagged dict describing a single still image.
    type: Literal['image']
    # The image payload; downstream code coerces it to a PIL image.
    image: Image.Image
33
+
34
+
35
class VideoChunkInput(TypedDict):
    # Tagged dict for a group of consecutive video frames plus the text
    # prompt that precedes the chunk's tokens.
    type: Literal['video_chunk']
    video_chunk: List[Image.Image]
    # Bug fix: TypedDict fields cannot declare defaults (PEP 589); the
    # previous `prompt: Optional[str] = None` was invalid. Producers
    # (split_video_chunks) always set `prompt`, so it stays a declared key.
    prompt: Optional[str]
39
+
40
+
41
# Union of the media payload shapes understood by the vision processor.
MediaInput = ImageInput | VideoChunkInput
42
+
43
+
44
def get_video_meta(video_src: bytes | str | os.PathLike,
                   accurate: bool = True) -> VideoSpec:
    """Probe a video and return its metadata.

    Args:
        video_src: Raw video bytes, a filesystem path, or a
            ``data:video/mp4;base64,...`` data URI.
        accurate: Whether to fully initialize the reader (``auto_init``) for
            exact metadata.

    Returns:
        A ``VideoSpec`` with dimensions, frame count, fps and decoding hints
        (previously mis-annotated as ``-> dict``).

    Raises:
        ImportError: If the optional ``mecord`` dependency is missing.
        AssertionError: If the probed metadata is invalid.
    """
    # Fail with a clear message instead of "'NoneType' is not callable".
    if VideoReader is None:
        raise ImportError(
            "mecord is required to read video metadata but is not installed")
    if isinstance(video_src, os.PathLike):
        video_src = str(video_src)
    # if b64 string, decode to bytes
    if isinstance(video_src,
                  str) and video_src.startswith('data:video/mp4;base64,'):
        video_src = base64.b64decode(video_src.split(',')[1])
    video = VideoReader(video_src, auto_init=accurate, num_threads=1)
    assert video.num_frames > 0, "Invalid video format."
    assert video.original_width > 0 and video.original_height > 0, (
        "Invalid video format.")
    assert video.avg_fps > 0, "Invalid video format."
    return VideoSpec(media_type='video',
                     height=video.original_height,
                     width=video.original_width,
                     num_frames=video.num_frames,
                     fps=video.avg_fps,
                     key_indices=video.key_indices,
                     frame_time_info=video.frame_time_info)
65
+
66
+
67
def timestamp_as_str(timestamp: float,
                     timestamp_mode: str = "hh:mm:ss.fff") -> str:
    """Render a timestamp in seconds as a clock string.

    Supported modes: "hh:mm:ss.fff", "mm:ss.fff", "mm:ss". Raises
    ValueError for anything else.
    """
    moment = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    # Fractional-second suffix, zero-padded to three digits.
    millis_suffix = f".{int((timestamp % 1) * 1000):03d}"
    if timestamp_mode == "hh:mm:ss.fff":
        return moment.strftime("%H:%M:%S") + millis_suffix
    if timestamp_mode == "mm:ss.fff":
        return moment.strftime("%M:%S") + millis_suffix
    if timestamp_mode == "mm:ss":
        return moment.strftime("%M:%S")
    raise ValueError(f"Invalid timestamp mode: {timestamp_mode}")
83
+
84
+
85
+ def navit_resize_image(
86
+ width: int,
87
+ height: int,
88
+ patch_size: int,
89
+ merge_kernel_size: int,
90
+ in_patch_limit: int,
91
+ patch_limit_on_one_side: int,
92
+ fixed_output_tokens: int | None,
93
+ ):
94
+ # Apply the patch limits.
95
+ s1 = math.sqrt(
96
+ in_patch_limit /
97
+ (max(1.0, width // patch_size) * max(1.0, height // patch_size)))
98
+ s2 = patch_limit_on_one_side * patch_size / width
99
+ s3 = patch_limit_on_one_side * patch_size / height
100
+ scale = min(1.0, s1, s2, s3)
101
+ new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale))
102
+ new_w = min(new_w, patch_limit_on_one_side * patch_size)
103
+ new_h = min(new_h, patch_limit_on_one_side * patch_size)
104
+
105
+ # Calculate the padding to make the height and width divisible by the merge kernel size and patch size.
106
+ factor = merge_kernel_size * patch_size
107
+
108
+ pad_height = (factor - new_h % factor) % factor
109
+ pad_width = (factor - new_w % factor) % factor
110
+
111
+ if fixed_output_tokens is not None:
112
+ num_tokens = fixed_output_tokens
113
+ else:
114
+ # Calculate new dimensions after padding and patching
115
+ token_height = (new_h + pad_height) // factor
116
+ token_width = (new_w + pad_width) // factor
117
+
118
+ assert token_height * merge_kernel_size <= patch_limit_on_one_side, (
119
+ f"token_height {token_height} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
120
+ )
121
+ assert token_width * merge_kernel_size <= patch_limit_on_one_side, (
122
+ f"token_width {token_width} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
123
+ )
124
+
125
+ num_tokens = token_height * token_width
126
+ return {
127
+ "num_tokens": num_tokens,
128
+ "new_width": new_w,
129
+ "new_height": new_h,
130
+ "pad_width": pad_width,
131
+ "pad_height": pad_height,
132
+ "sampled_nframes": 1,
133
+ }
134
+
135
+
136
def navit_resize_video(
    width: int,
    height: int,
    nframes: int,
    avg_fps: float,
    sample_fps: float,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit_each_frame: int,
    patch_limit_on_one_side: int,
    in_patch_limit_total: int | None,
    max_num_frames_each_video: int | None,
    fixed_output_tokens_each_frame: int | None,
):
    """Compute per-frame resize geometry and frame count for a video.

    Determines how many frames to sample at ``sample_fps`` (capped by the
    source fps and by ``max_num_frames_each_video``), shrinks the per-frame
    patch budget so the whole video fits ``in_patch_limit_total`` when given,
    then reuses the image geometry computation for a single frame.
    """
    # Never sample faster than the source frame rate.
    effective_fps = min(sample_fps, avg_fps)
    # Calculate the number of frames to sample based on target FPS
    sampled_nframes = max(round(nframes * effective_fps / avg_fps), 1)
    if max_num_frames_each_video is not None:
        sampled_nframes = min(sampled_nframes, max_num_frames_each_video)

    if in_patch_limit_total is not None:
        # Split the total budget evenly across sampled frames.
        per_frame_budget = round(in_patch_limit_total / sampled_nframes)
        in_patch_limit_each_frame = min(per_frame_budget,
                                        in_patch_limit_each_frame)

    result = navit_resize_image(
        width,
        height,
        patch_size,
        merge_kernel_size,
        in_patch_limit_each_frame,
        patch_limit_on_one_side,
        fixed_output_tokens_each_frame,
    )
    result["sampled_nframes"] = sampled_nframes
    return result
172
+
173
+
174
+ def real_sample_fps_and_max_num_frames(
175
+ type_name: Literal["video", "video_chunk"],
176
+ sample_fps: float,
177
+ max_num_frames_each_video: int | None,
178
+ ) -> tuple[int, int | None]:
179
+ if type_name == "video":
180
+ return sample_fps, max_num_frames_each_video
181
+ elif type_name == "video_chunk":
182
+ max_num_frames_each_video = None
183
+ sample_fps = math.inf
184
+ return sample_fps, max_num_frames_each_video
185
+ else:
186
+ return math.inf, None
187
+
188
+
189
def _to_pil(data: Image.Image | str | bytes) -> Image.Image:
    """Coerce an image payload to an RGB PIL image.

    Accepts an existing PIL image, a ``data:`` base64 URI, a filesystem path,
    or raw encoded bytes. (Annotation fixed: the first branch handles
    ``Image.Image``, which the old ``str | bytes`` hint omitted.)

    Raises:
        ValueError: For any other payload type.
    """
    if isinstance(data, Image.Image):
        return data.convert("RGB")
    elif isinstance(data, str):
        if data.startswith("data:"):
            # Strip the "data:...;base64," prefix and decode the payload.
            raw_base64 = data.split(",")[1]
            return Image.open(io.BytesIO(
                base64.b64decode(raw_base64))).convert("RGB")
        else:
            # Treat any other string as a filesystem path.
            return Image.open(data).convert("RGB")
    elif isinstance(data, bytes):
        return Image.open(io.BytesIO(data)).convert("RGB")
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")
204
+
205
+
206
def ensure_media_type(media: MediaInput) -> MediaInput:
    """Normalize the payload of *media* to PIL images, mutating it in place.

    Returns the same dict for convenience; raises ValueError for unknown
    media types.
    """
    media_kind = media['type']
    if media_kind == 'image':
        media['image'] = _to_pil(media['image'])
    elif media_kind == 'video_chunk':
        media['video_chunk'] = [_to_pil(f) for f in media['video_chunk']]
    else:
        raise ValueError(f"Unsupported media type: {media['type']}")
    return media
217
+
218
+
219
def image_to_np(
    image: Image.Image,
    resize_to: tuple[int, int] | None = None,
    mode: str = "resize",
    raise_error_for_ill_resize: bool = True,
) -> np.ndarray:
    """Convert an image to a numpy array, optionally resizing first.

    Args:
        image: The image to convert.
        resize_to: Target (width, height); no resizing when None.
        mode: One of "resize" (plain bicubic resize), "rescale_and_pad_to_center"
            (shrink preserving aspect ratio, zero-pad equally around), or
            "rescale_and_pad_to_rightbottom" (shrink preserving aspect ratio,
            zero-pad on the right/bottom only).
        raise_error_for_ill_resize: Whether to raise when the rescale would
            collapse a side to zero pixels; if False, an all-zero array of the
            target size is returned instead.

    Returns:
        A numpy array of shape (resize_to[1], resize_to[0], 3) when resizing,
        or the image's natural shape otherwise.
    """
    assert isinstance(image, Image.Image), "image must be a PIL Image"
    if resize_to is not None:
        if mode == "resize":
            # Direct resize; aspect ratio may change.
            image = image.resize(resize_to, resample=Image.Resampling.BICUBIC)

        elif mode == "rescale_and_pad_to_center":
            # Shrink-only scale preserving aspect ratio (never upscale).
            scale = min(resize_to[0] / image.width,
                        resize_to[1] / image.height, 1.0)
            new_width = round(image.width * scale)
            new_height = round(image.height * scale)
            if new_width == 0 or new_height == 0:
                if raise_error_for_ill_resize:
                    raise ValueError(
                        f"Invalid resize to: {resize_to}, from image size: {image.size}"
                    )
                else:
                    # Degenerate input: return a black canvas of target size.
                    return np.zeros((resize_to[1], resize_to[0], 3),
                                    dtype=np.uint8)

            image = image.resize((new_width, new_height),
                                 resample=Image.Resampling.BICUBIC)
            # Split the leftover space evenly on both sides (extra pixel
            # goes to the right/bottom).
            padding_left = (resize_to[0] - new_width) // 2
            padding_right = resize_to[0] - new_width - padding_left
            padding_top = (resize_to[1] - new_height) // 2
            padding_bottom = resize_to[1] - new_height - padding_top
            image = np.asarray(image)
            image = np.pad(
                image,
                ((padding_top, padding_bottom), (padding_left, padding_right),
                 (0, 0)),
                mode="constant",
                constant_values=0,
            )
            assert image.shape == (resize_to[1], resize_to[0], 3)

        elif mode == "rescale_and_pad_to_rightbottom":
            # Same shrink-only scaling, but all padding goes right/bottom.
            scale = min(resize_to[0] / image.width,
                        resize_to[1] / image.height, 1.0)
            new_width = round(image.width * scale)
            new_height = round(image.height * scale)
            if new_width == 0 or new_height == 0:
                if raise_error_for_ill_resize:
                    raise ValueError(
                        f"Invalid resize to: {resize_to}, from image size: {image.size}"
                    )
                else:
                    return np.zeros((resize_to[1], resize_to[0], 3),
                                    dtype=np.uint8)

            image = image.resize((new_width, new_height),
                                 resample=Image.Resampling.BICUBIC)
            padding_right = resize_to[0] - new_width
            padding_bottom = resize_to[1] - new_height
            image = np.asarray(image)
            image = np.pad(
                image,
                ((0, padding_bottom), (0, padding_right), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            assert image.shape == (resize_to[1], resize_to[0], 3)

        else:
            raise ValueError(f"Invalid mode: {mode}")

    # The "resize" branch leaves a PIL image; the pad branches already
    # produced arrays.
    if isinstance(image, Image.Image):
        return np.asarray(image)
    else:
        return image
305
+
306
+
307
def navit_patchify(pixel_values: np.ndarray,
                   patch_size: int) -> dict[str, np.ndarray]:
    """Split frames into non-overlapping square patches (NaViT layout).

    Args:
        pixel_values: Array of shape (t, h, w, c) with c == 3; h and w must
            be divisible by ``patch_size``.
        patch_size: Side length of each square patch.

    Returns:
        dict with:
        - "pixel_values": patches of shape
          (t * h//patch_size * w//patch_size, c, patch_size, patch_size)
        - "grid_thw": np.ndarray [t, h//patch_size, w//patch_size]
    """
    num_frames, height, width, channels = pixel_values.shape
    assert channels == 3, "pixel_values must have 3 channels"

    grid_h = height // patch_size
    grid_w = width // patch_size
    # (T, gh, p, gw, p, C) -> (T, gh, gw, C, p, p) -> flat list of patches.
    tiled = pixel_values.reshape(num_frames, grid_h, patch_size, grid_w,
                                 patch_size, channels)
    tiled = tiled.transpose(0, 1, 3, 5, 2, 4)
    patches = tiled.reshape(-1, channels, patch_size, patch_size)
    grid_thw = np.array([num_frames, grid_h, grid_w])
    return {"pixel_values": patches, "grid_thw": grid_thw}
330
+
331
+
332
def normalize(x: np.ndarray,
              mean,
              std_inv,
              pixels_dtype: np.dtype = np.float32) -> np.ndarray:
    """Rescale uint8 pixels to [0, 1], then apply (x - mean) * std_inv.

    Args:
        x: Pixel array of shape (..., 3), dtype uint8, values in [0, 255].
        mean: Per-channel mean subtracted after the /255 rescale.
        std_inv: Per-channel reciprocal of the std, multiplied in.
        pixels_dtype: dtype of the returned array.

    Returns:
        The normalized array of shape (..., 3) with dtype ``pixels_dtype``.
    """
    scaled = (x / 255.0).astype(pixels_dtype)
    # In-place ops keep the requested dtype and avoid extra copies.
    scaled -= mean
    scaled *= std_inv
    return scaled
350
+
351
+
352
+ def _to_tensor(data, **kwargs):
353
+ import torch
354
+
355
+ if isinstance(data, np.ndarray):
356
+ return torch.from_numpy(data).to(**kwargs)
357
+ elif isinstance(data, torch.Tensor):
358
+ return data.to(**kwargs)
359
+ elif isinstance(data, list):
360
+ return [_to_tensor(item, **kwargs) for item in data]
361
+ elif isinstance(data, tuple):
362
+ return tuple(_to_tensor(item, **kwargs) for item in data)
363
+ elif isinstance(data, dict):
364
+ return {k: _to_tensor(v, **kwargs) for k, v in data.items()}
365
+ elif data is None:
366
+ return None
367
+ else:
368
+ raise ValueError(f"Unsupported data type: {type(data)}")
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80a09d95543a7a7eee8f6ca664534a52d2ee476627264a81bf48da66e9d5df45
3
+ size 6489880
modeling_deepseek.py ADDED
@@ -0,0 +1,1808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch DeepSeek model."""
21
+ import math
22
+ import warnings
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.distributed as dist
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+ from transformers.modeling_attn_mask_utils import \
35
+ _prepare_4d_causal_attention_mask
36
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
37
+ CausalLMOutputWithPast,
38
+ SequenceClassifierOutputWithPast)
39
+ from transformers.modeling_utils import PreTrainedModel
40
+ from transformers.pytorch_utils import (ALL_LAYERNORM_LAYERS,
41
+ is_torch_greater_or_equal_than_1_13)
42
+ from transformers.utils import (add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ is_flash_attn_2_available,
45
+ is_flash_attn_greater_or_equal_2_10, logging,
46
+ replace_return_docstrings)
47
+ from transformers.utils.import_utils import is_torch_fx_available
48
+
49
+ from .configuration_deepseek import DeepseekV3Config
50
+
51
+ if is_flash_attn_2_available():
52
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
53
+ from flash_attn.bert_padding import pad_input # noqa
54
+ from flash_attn.bert_padding import index_first_axis, unpad_input
55
+
56
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
57
+ # It means that the function will not be traced through and simply appear as a node in the graph.
58
+ if is_torch_fx_available():
59
+ if not is_torch_greater_or_equal_than_1_13:
60
+ import torch.fx
61
+
62
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(
63
+ _prepare_4d_causal_attention_mask)
64
+
65
+ logger = logging.get_logger(__name__)
66
+
67
+ _CONFIG_FOR_DOC = "DeepseekV3Config"
68
+
69
+
70
+ def _get_unpad_data(attention_mask):
71
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
72
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
73
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
74
+ cu_seqlens = F.pad(
75
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
76
+ return (
77
+ indices,
78
+ cu_seqlens,
79
+ max_seqlen_in_batch,
80
+ )
81
+
82
+
83
# code modified from transformers 4.48.3 to amend breaks in newer transformers versions
def get_usable_length(past_key_value,
                      new_seq_length: int,
                      layer_idx: Optional[int] = 0) -> int:
    """Number of already-cached positions that remain usable before
    appending `new_seq_length` fresh tokens.

    Re-implements `Cache.get_usable_length` from transformers 4.48.3, which
    later releases removed. For unbounded caches this is simply the current
    cached length; for a bounded cache that would overflow, the value is
    clipped so the new tokens still fit.
    """
    capacity = past_key_value.get_max_cache_shape()
    cached_length = past_key_value.get_seq_length(layer_idx)
    would_overflow = (capacity is not None and capacity > 0
                      and cached_length + new_seq_length > capacity)
    return capacity - new_seq_length if would_overflow else cached_length
92
+
93
+
94
class DeepseekV3RMSNorm(nn.Module):
    """Root-mean-square layer norm, equivalent to T5LayerNorm
    (no mean subtraction, no bias)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learned per-channel scale, initialised to the identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        # Normalise in fp32 for numerical stability, then cast back.
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * x.to(original_dtype)
111
+
112
+
113
+ ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)
114
+
115
+
116
class DeepseekV3RotaryEmbedding(nn.Module):
    """Vanilla rotary position embedding with a lazily grown cos/sin cache."""

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, dim, 2).float().to(device) / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents),
                             persistent=False)

        # Pre-build the tables so `torch.jit.trace` sees the buffers.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # Reset the watermark so the first forward rebuilds the cache on the
        # runtime device/dtype of the activations.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached,
                                 device=device,
                                 dtype=self.inv_freq.dtype)

        angles = torch.outer(positions, self.inv_freq.to(positions.device))
        # Different from paper: the table concatenates the two halves
        # instead of interleaving them, matching the (equivalent)
        # permutation used by `apply_rotary_pos_emb`.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached",
                             table.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             table.sin().to(dtype),
                             persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len,
                                    device=x.device,
                                    dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
167
+
168
+
169
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3
170
class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the first cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # Linear scaling simply stretches the position index.
        positions = torch.arange(self.max_seq_len_cached,
                                 device=device,
                                 dtype=self.inv_freq.dtype) / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Concatenated (not interleaved) layout; see the base class.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached",
                             table.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             table.sin().to(dtype),
                             persistent=False)
200
+
201
+
202
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3
203
class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the first cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # Beyond the trained context, grow the effective base with the
            # sequence length (NTK-aware interpolation) instead of
            # stretching the position index.
            ratio = ((self.scaling_factor * seq_len /
                      self.max_position_embeddings) -
                     (self.scaling_factor - 1))
            adjusted_base = self.base * ratio ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim,
                                     2).float().to(device) / self.dim
            self.register_buffer("inv_freq",
                                 1.0 / (adjusted_base ** exponents),
                                 persistent=False)

        positions = torch.arange(self.max_seq_len_cached,
                                 device=device,
                                 dtype=self.inv_freq.dtype)

        angles = torch.outer(positions, self.inv_freq)
        # Concatenated (not interleaved) layout; see the base class.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached",
                             table.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             table.sin().to(dtype),
                             persistent=False)
242
+
243
+
244
+ # Inverse dim formula to find dim based on number of rotations
245
+ def yarn_find_correction_dim(num_rotations,
246
+ dim,
247
+ base=10000,
248
+ max_position_embeddings=2048):
249
+ return (dim * math.log(max_position_embeddings /
250
+ (num_rotations * 2 * math.pi))) / (2 *
251
+ math.log(base))
252
+
253
+
254
+ # Find dim range bounds based on rotations
255
+ def yarn_find_correction_range(low_rot,
256
+ high_rot,
257
+ dim,
258
+ base=10000,
259
+ max_position_embeddings=2048):
260
+ low = math.floor(
261
+ yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
262
+ high = math.ceil(
263
+ yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
264
+ return max(low, 0), min(high, dim - 1) # Clamp values just in case
265
+
266
+
267
def yarn_get_mscale(scale=1, mscale=1):
    """Attention-magnitude correction used by YaRN; identity when the
    context is not extended (scale <= 1)."""
    if scale <= 1:
        return 1.0
    return 1.0 + 0.1 * mscale * math.log(scale)
271
+
272
+
273
def yarn_linear_ramp_mask(min, max, dim):
    """Ramp linearly from 0 to 1 across `dim` entries between indices
    `min` and `max`, clamped flat outside that window."""
    if min == max:
        max += 0.001  # Prevent singularity

    ramp = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    return torch.clamp(ramp, 0, 1)
280
+
281
+
282
class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """Rotary embedding with YaRN context extension: NTK-by-parts frequency
    blending plus an attention-magnitude rescaling of the cos/sin tables."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # All YaRN hyper-parameters must exist before super().__init__
        # triggers the first `_set_cos_sin_cache` call.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        exponents = torch.arange(0, dim, 2, dtype=torch.float32,
                                 device=device) / dim
        # Extrapolated (original) and interpolated (scaled) frequencies.
        freq_extra = 1.0 / self.base ** exponents
        freq_inter = 1.0 / (self.scaling_factor * self.base ** exponents)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        # Blend per frequency: dimensions below `low` keep the original
        # (extrapolated) frequency, those above `high` use the interpolated
        # one, with a linear ramp in between.
        extra_weight = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32)
        inv_freq = (freq_inter * (1 - extra_weight) +
                    freq_extra * extra_weight)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)

        angles = torch.outer(positions, inv_freq)

        # Magnitude correction ratio between the per-head and the
        # all-dimension mscale settings.
        magnitude = float(
            yarn_get_mscale(self.scaling_factor, self.mscale) /
            yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))

        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached",
                             (table.cos() * magnitude).to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             (table.sin() * magnitude).to(dtype),
                             persistent=False)
340
+
341
+
342
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
343
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
348
+
349
+
350
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
351
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            Position indices of the query/key tokens; allows offset ids when
            a KV-cache is in use.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which cos[position_ids] / sin[position_ids] are
            unsqueezed so they broadcast against q and k. Use 1 for
            [batch, heads, seq, head_dim] tensors, 2 for
            [batch, seq, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _deinterleave(t):
        # Reorder (x0, x1, x2, ...) -> (even dims, then odd dims): the
        # cached cos/sin tables are laid out as two concatenated halves
        # rather than interleaved pairs.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q = _deinterleave(q)
    k = _deinterleave(k)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
384
+
385
+
386
class DeepseekV3MLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        # Fall back to the model-wide sizes unless explicit ones are given
        # (MoE experts pass a smaller intermediate size).
        self.hidden_size = (hidden_size
                            if hidden_size is not None else config.hidden_size)
        self.intermediate_size = (intermediate_size
                                  if intermediate_size is not None else
                                  config.intermediate_size)

        self.gate_proj = nn.Linear(self.hidden_size,
                                   self.intermediate_size,
                                   bias=False)
        self.up_proj = nn.Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False)
        self.down_proj = nn.Linear(self.intermediate_size,
                                   self.hidden_size,
                                   bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
410
+
411
+
412
class MoEGate(nn.Module):
    """Expert router: scores every routed expert per token and selects the
    top-k using DeepSeek-V3's group-limited, bias-corrected selection.

    Only `scoring_func == "sigmoid"` and `topk_method == "noaux_tc"` are
    implemented; any other configuration raises NotImplementedError.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Number of experts each token is routed to.
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        # Experts are partitioned into `n_group` groups; only the best
        # `topk_group` groups are eligible for the final top-k.
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim)))
        if self.topk_method == "noaux_tc":
            # Load-balancing correction bias: added only when *choosing*
            # experts, never when weighting their outputs.
            # NOTE(review): created with torch.empty and not initialised in
            # reset_parameters — presumably always loaded from a checkpoint;
            # confirm before training from scratch.
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((self.n_routed_experts)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Kaiming-initialise the gating weight (bias is left untouched)."""
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route a [bsz, seq_len, hidden] batch.

        Returns `(topk_idx, topk_weight)`, each [bsz * seq_len, top_k]:
        chosen expert indices and their (scaled) gate weights.
        """
        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        # Score in fp32 for stability regardless of the activation dtype.
        logits = F.linear(hidden_states.type(torch.float32),
                          self.weight.type(torch.float32), None)
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        ### select top-k experts
        if self.topk_method == "noaux_tc":
            # Aux-loss-free routing is an inference-only path.
            assert not self.training
            scores_for_choice = scores.view(
                bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
            # A group's score is the sum of its two best corrected scores.
            group_scores = (scores_for_choice.view(
                bsz * seq_len, self.n_group,
                -1).topk(2, dim=-1)[0].sum(dim=-1))  # [n, n_group]
            group_idx = torch.topk(group_scores,
                                   k=self.topk_group,
                                   dim=-1,
                                   sorted=False)[1]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            # Expand the winning-group mask to a per-expert mask.
            score_mask = (group_mask.unsqueeze(-1).expand(
                bsz * seq_len, self.n_group,
                self.n_routed_experts // self.n_group).reshape(
                    bsz * seq_len, -1))  # [n, e]
            # Experts outside the winning groups are masked to 0 before the
            # final per-token top-k.
            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(),
                                                       0.0)  # [n, e]
            _, topk_idx = torch.topk(tmp_scores,
                                     k=self.top_k,
                                     dim=-1,
                                     sorted=False)
            # Gate weights come from the *uncorrected* sigmoid scores.
            topk_weight = scores.gather(1, topk_idx)
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor

        return topk_idx, topk_weight
491
+
492
+
493
class DeepseekV3MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Routed experts are chosen per token by `MoEGate`; the shared experts
    (if configured) run on every token and their output is added to the
    routed result. Supports expert parallelism (`ep_size` in the config),
    where each rank only materialises its own slice of the expert list.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        if hasattr(config, "ep_size") and config.ep_size > 1:
            assert config.ep_size == dist.get_world_size()
            self.ep_size = config.ep_size
            self.experts_per_rank = config.n_routed_experts // config.ep_size
            self.ep_rank = dist.get_rank()
            # Only instantiate the experts owned by this rank; other slots
            # stay None so global expert indices remain valid.
            self.experts = nn.ModuleList([
                (DeepseekV3MLP(config,
                               intermediate_size=config.moe_intermediate_size)
                 if i >= self.ep_rank * self.experts_per_rank
                 and i < (self.ep_rank + 1) * self.experts_per_rank else None)
                for i in range(config.n_routed_experts)
            ])
        else:
            # Single-rank case: all routed experts live locally.
            self.ep_size = 1
            self.experts_per_rank = config.n_routed_experts
            self.ep_rank = 0
            self.experts = nn.ModuleList([
                DeepseekV3MLP(config,
                              intermediate_size=config.moe_intermediate_size)
                for i in range(config.n_routed_experts)
            ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            # Shared experts are fused into one MLP with a proportionally
            # larger intermediate size.
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV3MLP(
                config=config, intermediate_size=intermediate_size)

    def forward(self, hidden_states):
        # NOTE(review): only the inference branch assigns `y`; calling this
        # module in training mode would hit an unbound local (the gate also
        # asserts eval mode for "noaux_tc" routing).
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)  # NOTE(review): currently unused
        if not self.training:
            y = self.moe_infer(hidden_states, topk_idx,
                               topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Dispatch tokens to their selected experts and recombine.

        x: [num_tokens, hidden]; topk_ids / topk_weight: [num_tokens, top_k].
        Returns a [num_tokens, hidden] tensor of gate-weighted expert sums.
        """
        # Count how many tokens each expert was assigned.
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort the flattened (token, expert) assignments by expert id so
        # each expert consumes one contiguous slice of `sorted_tokens`.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        sorted_tokens_shape = sorted_tokens.shape
        if self.ep_size > 1:
            # Expert parallelism: first exchange per-expert counts, then the
            # tokens themselves, so every rank receives exactly the tokens
            # destined for its local experts.
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size,
                                                        -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
            output_splits = (tokens_per_expert_group.view(
                self.ep_size, -1).sum(1).cpu().numpy().tolist())
            gathered_tokens = sorted_tokens.new_empty(
                tokens_per_expert_group.sum(dim=0).cpu().item(),
                sorted_tokens.shape[1])
            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
            dist.all_to_all(
                list(gathered_tokens.split(output_splits)),
                list(sorted_tokens.split(input_split_sizes)),
            )
            tokens_per_expert_post_gather = tokens_per_expert_group.view(
                self.ep_size, self.experts_per_rank).sum(dim=0)
            # Re-sort the received tokens by local expert index.
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0], ),
                                    dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s:s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        # Run each local expert over its contiguous slice of tokens.
        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs,
                         dim=0) if len(outputs) else sorted_tokens.new_empty(0)
        if self.ep_size > 1:
            # Undo the local re-sort, then return results to their
            # originating ranks with the inverse all_to_all.
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
            dist.all_to_all(
                list(gathered_tokens.split(input_split_sizes)),
                list(new_x.split(output_splits)),
            )
            outs = gathered_tokens

        # Scatter expert outputs back to the pre-sort order and combine the
        # top-k contributions of each token with its gate weights.
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (new_x.view(
            *topk_ids.shape, -1).type(topk_weight.dtype).mul_(
                topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
        return final_out
610
+
611
+
612
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
613
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # expand() creates a broadcast view; the reshape materialises the copies.
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep,
                                                      seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
627
+
628
+
629
+ # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3
630
class DeepseekV3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Uses DeepSeek's low-rank attention layout: queries are optionally
    compressed through a `q_lora_rank` bottleneck, while keys/values are
    derived from a `kv_lora_rank` compression plus a small shared rotary
    component (`qk_rope_head_dim`) that carries position information. Each
    head's query/key dimension is `qk_nope_head_dim + qk_rope_head_dim`;
    values use a separate `v_head_dim`.
    """

    def __init__(self,
                 config: DeepseekV3Config,
                 layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            # FIX: the original message read "and will to errors" — the
            # missing word made the warning ungrammatical.
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class.")

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Per-head query/key width = non-rotary part + rotary part.
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            # Full-rank query projection.
            self.q_proj = nn.Linear(self.hidden_size,
                                    self.num_heads * self.q_head_dim,
                                    bias=False)
        else:
            # Low-rank query path: down-project, normalise, up-project.
            self.q_a_proj = nn.Linear(self.hidden_size,
                                      config.q_lora_rank,
                                      bias=config.attention_bias)
            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank,
                                      self.num_heads * self.q_head_dim,
                                      bias=False)

        # Produces the compressed KV representation plus one shared
        # (multi-query) rotary key slice.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
        # Expands the compressed KV into per-head non-rotary keys + values.
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads *
            (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim**(-0.5)
        if self.config.rope_scaling is not None:
            # YaRN rescales attention magnitude; fold it into the softmax
            # scale once instead of scaling activations every step.
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Instantiate the rotary embedding variant named in the config
        (plain, linear, dynamic NTK, or YaRN)."""
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekV3RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Forward only the YaRN keys actually present in the config.
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ] if key in self.config.rope_scaling
                }
                self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape [bsz, seq, heads * v_head_dim] -> [bsz, heads, seq,
        # v_head_dim]. Helper retained from the Llama template; note it uses
        # v_head_dim (values), not q_head_dim.
        return (tensor.view(bsz, seq_len, self.num_heads,
                            self.v_head_dim).transpose(1, 2).contiguous())

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Eager attention forward pass.

        Returns `(attn_output, attn_weights, past_key_value)`;
        `attn_weights` is None unless `output_attentions` is True.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        # --- Query projection (optionally through the LoRA bottleneck) ---
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        # Split each head's query into non-rotary and rotary parts.
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # --- Compressed KV: latent part + one shared rotary key slice ---
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
            bsz, q_len, self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index.")
            # Total key length = cached positions + new tokens.
            kv_seq_len += get_usable_length(past_key_value, kv_seq_len,
                                            self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Re-assemble full per-head queries/keys: [non-rotary | rotary].
        # The shared k_pe (1 head) broadcasts across all heads on assignment.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                      self.q_head_dim)
        query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                    self.q_head_dim)
        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) *
            self.softmax_scale)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}")
        # NOTE(review): this assert makes the mask mandatory even though the
        # signature marks it Optional; the following `if` is then redundant.
        assert attention_mask is not None
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(
                                                 query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.attention_dropout,
                                             training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len,
                                          self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
854
+
855
+
856
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3
class DeepseekV3FlashAttention2(DeepseekV3Attention):
    """
    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Flash-attention forward over the MLA (multi-head latent attention)
        projections: queries/keys are split into a non-rotary part and a
        rotary (RoPE) part, keys/values are reconstructed from a compressed
        latent, then the flash-attn kernels compute the attention output.

        Returns `(attn_output, attn_weights, past_key_value)`; `attn_weights`
        is always `None` because the flash-attn kernel does not expose
        attention probabilities (see `output_attentions = False` below).
        """
        # DeepseekV3FlashAttention2 attention does not support output_attentions
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        # Query path: either a direct projection or a low-rank (LoRA-style)
        # down-projection -> layernorm -> up-projection, per config.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        # Split each query head into its non-rotary and rotary slices.
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        # Key/value path: a single compressed latent plus a shared (1-head)
        # rotary key, expanded back to per-head keys/values via kv_b_proj.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
            bsz, q_len, self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        kv_seq_len = value_states.shape[-2]

        # NOTE(review): this assignment duplicates the line above — harmless,
        # but one of the two could be removed.
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            # Extend by the number of already-cached positions for this layer.
            kv_seq_len += get_usable_length(past_key_value, kv_seq_len,
                                            self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Reassemble full-width query/key heads: [non-rotary | rotary] slices.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                      self.q_head_dim)
        query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                    self.q_head_dim)
        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

        # flash-attn expects q/k/v to share one head dim; pad values up to
        # q_head_dim here and trim the output back down after the kernel.
        if self.q_head_dim != self.v_head_dim:
            value_states = F.pad(value_states,
                                 [0, self.q_head_dim - self.v_head_dim])

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Dropout only applies during training.
        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DeepseekV3RMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                # Fall back to the dtype of whichever query projection exists.
                target_dtype = (self.q_proj.weight.dtype if self.q_lora_rank
                                is None else self.q_a_proj.weight.dtype)

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}.")

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            softmax_scale=self.softmax_scale,
        )
        # Undo the value padding applied before the kernel (see above).
        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, :self.v_head_dim]

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads *
                                          self.v_head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        # output_attentions is forced False above, so attn_weights is always None.
        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            # Padded batch: unpad to a ragged layout, run the varlen kernel,
            # then re-pad the output back to (batch, seq, heads, dim).
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(query_states, key_states, value_states,
                                 attention_mask, query_length)

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
                                    query_length)
        else:
            # No padding anywhere: the dense kernel is cheaper.
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
                    query_length):
        """Strip padding positions from q/k/v ahead of the varlen flash-attn
        kernel, returning the flattened tensors plus the cumulative-sequence
        and max-sequence metadata the kernel needs."""
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
            attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                              head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                                head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            # Prefill: queries cover the same positions as keys, reuse indices.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
                                    head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decode: one query per batch element.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
1126
+
1127
+
1128
# Maps `config._attn_implementation` to the attention backend class used by
# DeepseekV3DecoderLayer.
ATTENTION_CLASSES = {
    impl_name: impl_cls
    for impl_name, impl_cls in (
        ("eager", DeepseekV3Attention),
        ("flash_attention_2", DeepseekV3FlashAttention2),
    )
}
1132
+
1133
+
1134
class DeepseekV3DecoderLayer(nn.Module):
    """One DeepseekV3 transformer block: self-attention followed by a
    feed-forward network (dense MLP or MoE), each in a pre-RMSNorm
    residual branch."""

    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Attention backend (eager / flash-attention-2) is chosen by config.
        attn_cls = ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attn_cls(config=config, layer_idx=layer_idx)

        # MoE replaces the dense MLP only from `first_k_dense_replace`
        # onwards, at every `moe_layer_freq`-th layer.
        is_moe_layer = (config.n_routed_experts is not None
                        and layer_idx >= config.first_k_dense_replace
                        and layer_idx % config.moe_layer_freq == 0)
        self.mlp = DeepseekV3MoE(config) if is_moe_layer else DeepseekV3MLP(
            config)

        self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
                                                 torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*):
                `(batch_size, sequence_length)` for flash attention, or
                `(batch_size, 1, query_sequence_length, key_sequence_length)` for eager attention.
            position_ids (`torch.LongTensor`, *optional*): RoPE position indices.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key/value projection states.
            output_attentions (`bool`, *optional*): also return the attention weights.
            use_cache (`bool`, *optional*): also return the updated key/value cache.

        Returns:
            Tuple of `(hidden_states[, attn_weights][, present_key_value])`,
            with the optional entries present per the corresponding flags.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # --- self-attention sub-block (pre-norm residual) ---
        normed = self.input_layernorm(hidden_states)
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # --- feed-forward sub-block (pre-norm residual) ---
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + ffn_out

        outputs = (hidden_states, )
        if output_attentions:
            outputs += (self_attn_weights, )
        if use_cache:
            outputs += (present_key_value, )

        return outputs
1213
+
1214
+
1215
# Shared docstring prologue injected into model classes via
# `@add_start_docstrings`; describes the PreTrainedModel/nn.Module contract
# and the `config` constructor argument.
DeepseekV3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DeepseekV3Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
1230
+
1231
+
1232
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3PreTrainedModel(PreTrainedModel):
    """Base class wiring DeepseekV3 models into the transformers
    load/save/device-placement machinery and declaring feature support."""

    config_class = DeepseekV3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range);
        biases and padding embeddings are zeroed."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
1255
+
1256
+
1257
# Shared forward-signature docstring injected via
# `@add_start_docstrings_to_model_forward` on `DeepseekV3Model.forward` (and
# reused by the causal-LM head).
DeepseekV3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1325
+
1326
+
1327
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]

    Args:
        config: DeepseekV3Config
    """

    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
                                         self.padding_idx)
        self.layers = nn.ModuleList([
            DeepseekV3DecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        # Cached once so forward() can branch between the 2D (flash) and 4D
        # (eager) attention-mask formats without re-reading the config.
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = DeepseekV3RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)

        # NOTE: set here but not consulted in this forward() — there is no
        # gradient-checkpointing branch in the decoder loop below.
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Token embedding table (also the tied lm_head source when tied).
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack and return the final hidden states (plus
        optional per-layer hidden states / attentions and the updated KV
        cache). Output flags fall back to the config when not given."""
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0
        if use_cache:
            # Accept both cache formats: wrap legacy tuple caches in a
            # DynamicCache, remembering the input format so the output
            # matches it (see `next_cache` below).
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(
                    past_key_values)
            past_key_values_length = get_usable_length(past_key_values,
                                                       seq_length)

        if position_ids is None:
            # Default positions continue from the cached prefix length.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            # (kept only when it actually contains padding zeros; otherwise
            # the flash path runs with no mask at all).
            attention_mask = (attention_mask if
                              (attention_mask is not None
                               and 0 in attention_mask) else None)
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache position in the layer's output tuple depends on
                # whether attention weights were also requested.
                next_decoder_cache = layer_outputs[
                    2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = None
        if use_cache:
            # Return the cache in the same format the caller provided.
            next_cache = (next_decoder_cache.to_legacy_cache()
                          if use_legacy_cache else next_decoder_cache)
        if not return_dict:
            return tuple(
                v for v in
                [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1487
+
1488
+
1489
+ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
1490
+ _tied_weights_keys = ["lm_head.weight"]
1491
+
1492
+ def __init__(self, config):
1493
+ super().__init__(config)
1494
+ self.model = DeepseekV3Model(config)
1495
+ self.vocab_size = config.vocab_size
1496
+ self.lm_head = nn.Linear(config.hidden_size,
1497
+ config.vocab_size,
1498
+ bias=False)
1499
+
1500
+ # Initialize weights and apply final processing
1501
+ self.post_init()
1502
+
1503
+ def get_input_embeddings(self):
1504
+ return self.model.embed_tokens
1505
+
1506
+ def set_input_embeddings(self, value):
1507
+ self.model.embed_tokens = value
1508
+
1509
+ def get_output_embeddings(self):
1510
+ return self.lm_head
1511
+
1512
+ def set_output_embeddings(self, new_embeddings):
1513
+ self.lm_head = new_embeddings
1514
+
1515
+ def set_decoder(self, decoder):
1516
+ self.model = decoder
1517
+
1518
+ def get_decoder(self):
1519
+ return self.model
1520
+
1521
+ @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
1522
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast,
1523
+ config_class=_CONFIG_FOR_DOC)
1524
+ def forward(
1525
+ self,
1526
+ input_ids: torch.LongTensor = None,
1527
+ attention_mask: Optional[torch.Tensor] = None,
1528
+ position_ids: Optional[torch.LongTensor] = None,
1529
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1530
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1531
+ labels: Optional[torch.LongTensor] = None,
1532
+ use_cache: Optional[bool] = None,
1533
+ output_attentions: Optional[bool] = None,
1534
+ output_hidden_states: Optional[bool] = None,
1535
+ return_dict: Optional[bool] = None,
1536
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1537
+ r"""
1538
+ Args:
1539
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1540
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
1541
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1542
+ (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
1543
+
1544
+ Returns:
1545
+
1546
+ Example:
1547
+
1548
+ ```python
1549
+ >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
1550
+
1551
+ >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1552
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1553
+
1554
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1555
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1556
+
1557
+ >>> # Generate
1558
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1559
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1560
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1561
+ ```"""
1562
+ output_attentions = (output_attentions if output_attentions is not None
1563
+ else self.config.output_attentions)
1564
+ output_hidden_states = (output_hidden_states
1565
+ if output_hidden_states is not None else
1566
+ self.config.output_hidden_states)
1567
+ return_dict = (return_dict if return_dict is not None else
1568
+ self.config.use_return_dict)
1569
+
1570
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1571
+ outputs = self.model(
1572
+ input_ids=input_ids,
1573
+ attention_mask=attention_mask,
1574
+ position_ids=position_ids,
1575
+ past_key_values=past_key_values,
1576
+ inputs_embeds=inputs_embeds,
1577
+ use_cache=use_cache,
1578
+ output_attentions=output_attentions,
1579
+ output_hidden_states=output_hidden_states,
1580
+ return_dict=return_dict,
1581
+ )
1582
+
1583
+ hidden_states = outputs[0]
1584
+ logits = self.lm_head(hidden_states)
1585
+ logits = logits.float()
1586
+
1587
+ loss = None
1588
+ if labels is not None:
1589
+ # Shift so that tokens < n predict n
1590
+ shift_logits = logits[..., :-1, :].contiguous()
1591
+ shift_labels = labels[..., 1:].contiguous()
1592
+ # Flatten the tokens
1593
+ loss_fct = CrossEntropyLoss()
1594
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1595
+ shift_labels = shift_labels.view(-1)
1596
+ # Enable model parallelism
1597
+ shift_labels = shift_labels.to(shift_logits.device)
1598
+ loss = loss_fct(shift_logits, shift_labels)
1599
+
1600
+ if not return_dict:
1601
+ output = (logits, ) + outputs[1:]
1602
+ return (loss, ) + output if loss is not None else output
1603
+
1604
+ return CausalLMOutputWithPast(
1605
+ loss=loss,
1606
+ logits=logits,
1607
+ past_key_values=outputs.past_key_values,
1608
+ hidden_states=outputs.hidden_states,
1609
+ attentions=outputs.attentions,
1610
+ )
1611
+
1612
+ def prepare_inputs_for_generation(
1613
+ self,
1614
+ input_ids,
1615
+ past_key_values=None,
1616
+ attention_mask=None,
1617
+ inputs_embeds=None,
1618
+ **kwargs,
1619
+ ):
1620
+ if past_key_values is not None:
1621
+ if isinstance(past_key_values, Cache):
1622
+ cache_length = past_key_values.get_seq_length()
1623
+ # seen_tokens 可能在某些 transformers 版本中不存在,使用 getattr 安全访问
1624
+ past_length = getattr(past_key_values, 'seen_tokens',
1625
+ cache_length)
1626
+ max_cache_length = past_key_values.get_max_length()
1627
+ else:
1628
+ cache_length = past_length = past_key_values[0][0].shape[2]
1629
+ max_cache_length = None
1630
+
1631
+ # Keep only the unprocessed tokens:
1632
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1633
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1634
+ # input)
1635
+ if (attention_mask is not None
1636
+ and attention_mask.shape[1] > input_ids.shape[1]):
1637
+ input_ids = input_ids[:, -(attention_mask.shape[1] -
1638
+ past_length):]
1639
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1640
+ # input_ids based on the past_length.
1641
+ elif past_length < input_ids.shape[1]:
1642
+ input_ids = input_ids[:, past_length:]
1643
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1644
+
1645
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1646
+ if (max_cache_length is not None and attention_mask is not None
1647
+ and cache_length + input_ids.shape[1] > max_cache_length):
1648
+ attention_mask = attention_mask[:, -max_cache_length:]
1649
+
1650
+ position_ids = kwargs.get("position_ids", None)
1651
+ if attention_mask is not None and position_ids is None:
1652
+ # create position_ids on the fly for batch generation
1653
+ position_ids = attention_mask.long().cumsum(-1) - 1
1654
+ position_ids.masked_fill_(attention_mask == 0, 1)
1655
+ if past_key_values:
1656
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1657
+
1658
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1659
+ if inputs_embeds is not None and past_key_values is None:
1660
+ model_inputs = {"inputs_embeds": inputs_embeds}
1661
+ else:
1662
+ model_inputs = {"input_ids": input_ids}
1663
+
1664
+ model_inputs.update({
1665
+ "position_ids": position_ids,
1666
+ "past_key_values": past_key_values,
1667
+ "use_cache": kwargs.get("use_cache"),
1668
+ "attention_mask": attention_mask,
1669
+ })
1670
+ return model_inputs
1671
+
1672
+ @staticmethod
1673
+ def _reorder_cache(past_key_values, beam_idx):
1674
+ reordered_past = ()
1675
+ for layer_past in past_key_values:
1676
+ reordered_past += (tuple(
1677
+ past_state.index_select(0, beam_idx.to(past_state.device))
1678
+ for past_state in layer_past), )
1679
+ return reordered_past
1680
+
1681
+
1682
@add_start_docstrings(
    """
    The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).

    [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
    # Backbone DeepseekV3Model plus a bias-free linear scoring head that is
    # applied to the last non-padding token of each sequence.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeepseekV3Model(config)
        # Scoring head: hidden_size -> num_labels, no bias.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The token embedding table lives on the backbone model.
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Batched inputs require a pad token to locate each row's last
        # real token.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the token just before the first pad in each row.
                sequence_lengths = (torch.eq(
                    input_ids, self.config.pad_token_id).int().argmax(-1) -
                                    1).to(logits.device)
            else:
                # inputs_embeds gives no way to detect padding: take the
                # last position.
                sequence_lengths = -1

        # Pool one logit vector per sequence (the last-token position).
        pooled_logits = logits[torch.arange(batch_size, device=logits.device),
                               sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once when the config does not specify it.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long
                                              or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels),
                                labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits, ) + transformer_outputs[1:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
modeling_kimi_k25.py ADDED
@@ -0,0 +1,1248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025-2026 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for Kimi-K2.5.
5
+ #
6
+ # Licensing Information:
7
+ # - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
8
+ # - Other parts of the code are licensed under the MIT License.
9
+ #
10
+ # Apache License, Version 2.0:
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ #
23
+ # MIT License:
24
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
25
+ # of this software and associated documentation files (the "Software"), to deal
26
+ # in the Software without restriction, including without limitation the rights
27
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
28
+ # copies of the Software, and to permit persons to whom the Software is
29
+ # furnished to do so, subject to the following conditions:
30
+ #
31
+ # The above copyright notice and this permission notice shall be included in all
32
+ # copies or substantial portions of the Software.
33
+ #
34
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
35
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
36
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
37
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
38
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
39
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40
+ # SOFTWARE.
41
+ import math
42
+ from collections.abc import Sequence
43
+ from copy import deepcopy
44
+ from typing import Optional
45
+
46
+ import numpy as np
47
+ import torch
48
+ import torch.nn as nn
49
+ import torch.nn.functional as F
50
+ from transformers import activations
51
+
52
+ try:
53
+ from transformers.activations import PytorchGELUTanh
54
+ except ImportError:
55
+ from transformers.activations import GELUTanh
56
+ activations.PytorchGELUTanh = GELUTanh
57
+ PytorchGELUTanh = GELUTanh
58
+ from transformers.activations import PytorchGELUTanh
59
+ from transformers.cache_utils import Cache
60
+ from transformers.configuration_utils import PretrainedConfig
61
+ from transformers.modeling_utils import PreTrainedModel
62
+ from transformers.models.llava.modeling_llava import \
63
+ LlavaCausalLMOutputWithPast
64
+ from transformers.utils import is_flash_attn_2_available
65
+
66
+ from .configuration_kimi_k25 import KimiK25Config
67
+ from .modeling_deepseek import DeepseekV3ForCausalLM
68
+
69
+ # Flash attention imports
70
+ if is_flash_attn_2_available():
71
+ from flash_attn import flash_attn_varlen_func
72
+ else:
73
+ flash_attn_varlen_func = None
74
+
75
+
76
def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: torch.Tensor | None = None,
    k_cu_seqlens: torch.Tensor | None = None,
    max_seqlen_q: int | None = None,
    max_seqlen_k: int | None = None,
    deterministic: bool = False,
):
    """Bidirectional multi-head attention backed by flash-attention 2.

    Args:
        q, k, v: tensors of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) when sequences are packed.
        q_cu_seqlens: cumulative query sequence lengths; the first element
            should be 0 and the last should be q.shape[0].
        k_cu_seqlens: cumulative key sequence lengths; the first element
            should be 0 and the last should be k.shape[0].
        max_seqlen_q / max_seqlen_k: length of the longest single sequence
            in q / k.
        deterministic: request a deterministic backward pass.

    Returns:
        Tensor of shape (batch_size, seqlen, dim) or (tot_seqlens, dim) when
        packing, where dim = num_heads * head_dim.
    """
    out = flash_attn_varlen_func(
        q,
        k,
        v,
        q_cu_seqlens,
        k_cu_seqlens,
        max_seqlen_q,
        max_seqlen_k,
        causal=False,
        deterministic=deterministic,
    )
    # Some flash-attn versions return auxiliary values; keep the output only.
    if isinstance(out, tuple):
        out = out[0]

    # Merge (num_heads, head_dim) into one hidden dimension.
    return out.flatten(start_dim=-2)
117
+
118
+
119
def eager_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Pure-PyTorch fallback for packed (varlen) bidirectional attention.

    Args:
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim).
        q_cu_seqlens: cumulative sequence lengths delimiting the packed
            sequences; tokens may only attend within their own sequence.
        k_cu_seqlens: unused here (q and k share the same packing).
        **kwargs: absorbs extra arguments of the flash-attention path
            (e.g. `deterministic`, `max_seqlen_*`).

    Returns:
        Tensor of shape (tot_seqlens, num_heads * head_dim).
    """
    seq_length = q.shape[0]
    # Bug fix: the previous version built a *boolean* mask and added it to
    # the attention scores, which merely added +1 inside each block instead
    # of masking cross-sequence positions (softmax shift-invariance makes a
    # constant offset a no-op, so sequences leaked into each other). Build a
    # proper additive mask: 0 inside each sequence block, -inf elsewhere.
    attention_mask = torch.full([1, seq_length, seq_length],
                                torch.finfo(q.dtype).min,
                                device=q.device,
                                dtype=q.dtype)
    for i in range(1, len(q_cu_seqlens)):
        start, end = q_cu_seqlens[i - 1], q_cu_seqlens[i]
        attention_mask[..., start:end, start:end] = 0

    # (seq, heads, dim) -> (heads, seq, dim)
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)

    attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    attn_weight = attn_weight + attention_mask
    # Softmax in float32 for numerical stability, then cast back.
    attn_weight = torch.softmax(attn_weight, dim=-1,
                                dtype=torch.float32).to(q.dtype)

    attn_output = attn_weight @ v
    attn_output = attn_output.transpose(0, 1)
    # Merge heads: (seq, heads * dim).
    attn_output = attn_output.reshape(seq_length, -1)
    return attn_output
150
+
151
+
152
# Dispatch table from the configured attention-implementation name to the
# vision attention callable; both entries share the same calling convention.
VL_VISION_ATTENTION_FUNCTIONS = {
    "flash_attention_2": multihead_attention,
    "eager": eager_attention,
}
156
+
157
+
158
+ def _apply_rope_input_validation(x, freqs_cis):
159
+ assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
160
+ assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
161
+ assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
162
+ assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
163
+
164
+
165
def get_rope_shape_decorate(func):
    """Warm up `func` once per (grad-state, interpolation-mode) combination.

    Before the first real call for a given (requires_grad, grad-enabled,
    interpolation_mode) key, `func` is invoked once with a fixed (64, 64)
    shape so any compilation/specialization happens up front; the warm-up
    result is discarded.
    """
    warmed_up = set()

    def wrapper(org, interpolation_mode, shape):
        state = (org.requires_grad, torch.is_grad_enabled(),
                 interpolation_mode)
        if state not in warmed_up:
            warmed_up.add(state)
            func(org, interpolation_mode, shape=(64, 64))
        return func(org, interpolation_mode, shape)

    return wrapper
176
+
177
+
178
# The warm-up decorator must wrap the compiled function so its priming call
# with shape (64, 64) exercises the torch.compile graph.
@get_rope_shape_decorate
@torch.compile(dynamic=True)
def get_rope_shape(org, interpolation_mode, shape):
    """Resize an (H, W, C) positional-embedding grid to `shape` = (h, w).

    Interpolates channel-first via F.interpolate, restores (h, w, C) layout,
    then flattens the spatial axes to return an (h * w, C) tensor.
    """
    return (F.interpolate(
        org.permute((2, 0, 1)).unsqueeze(0),
        size=shape,
        mode=interpolation_mode,
    ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))
186
+
187
+
188
def apply_rope(xq: torch.Tensor, xk: torch.Tensor,
               freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to query and key tensors.

    Args: (the leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads, head_dim)
        xk: key, tensor of shape (..., num_heads, head_dim)
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64.
            It contains the precomputed cis(freqs) for each position.

    Returns:
        xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
    """

    def _validate(x: torch.Tensor) -> None:
        # Shape/dtype sanity checks (inlined so the function is
        # self-contained).
        assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
        assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape,
                                                      freqs_cis.shape)
        assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape,
                                                        freqs_cis.shape)
        assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype

    _validate(xq)
    _validate(xk)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2
    # View the trailing head_dim as complex pairs: ..., num_heads, head_dim/2.
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    # Bug fix: the original reshaped xk with *xq*'s shape; use xk's own shape
    # so query/key tensors with different head counts are handled correctly.
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)
210
+
211
+
212
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Build fixed sin/cos positional embeddings for a 1-D grid.

    Adapted from InternVideo2's pos_embed utilities:
    https://github.com/OpenGVLab/InternVideo/blob/421f6d2361fc8f61a3394244571f2601a4e99e29/InternVideo2/multi_modality/models/backbones/internvideo2/pos_embed.py#L86

    Args:
        embed_dim: output dimension for each position (must be even).
        pos: array of positions to encode, flattened to M elements.

    Returns:
        np.ndarray of shape (M, embed_dim): concatenated [sin | cos] halves.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Frequencies 1 / 10000^(i / (embed_dim/2)), i in [0, embed_dim/2).
    exponents = np.arange(half_dim, dtype=np.float32) / (embed_dim / 2.0)
    freqs = 1.0 / 10000**exponents  # (D/2,)

    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2) outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
233
+
234
+
235
def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
    """Build a temporal sin/cos position table.

    Args:
        embed_dim: embedding dimension per position.
        t_size: number of temporal positions.
        cls_token: if True, prepend an all-zero row for a class-token slot.

    Returns:
        pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim]
        (w/ or w/o cls_token).
    """
    frames = np.arange(t_size, dtype=np.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, frames)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)
247
+
248
+
249
class Learnable2DInterpPosEmbDivided_fixed(nn.Module):
    """Divided space-time positional embedding for packed vision patches.

    Spatial positions use a learnable (height, width, dim) grid that is
    resized on the fly (via `get_rope_shape`) to each sample's grid size;
    temporal positions use a fixed, non-learnable 1-D sin/cos table added
    per frame.
    """

    def __init__(self,
                 height: int,
                 width: int,
                 num_frames: int,
                 dim: int,
                 interpolation_mode: str = 'bicubic') -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.num_frames = num_frames
        self.dim = dim
        self.interpolation_mode = interpolation_mode
        # Learnable 2-D spatial table at the native (height, width) grid.
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        # Fixed sinusoidal temporal table, shape (num_frames, 1, dim);
        # excluded from checkpoints (persistent=False).
        self.register_buffer('time_weight',
                             torch.from_numpy(
                                 get_1d_sincos_pos_embed(
                                     self.dim,
                                     self.num_frames)).float().unsqueeze(1),
                             persistent=False)

        self.reset_parameters()

    def reset_parameters(self):
        # Standard-normal init for the learnable spatial table.
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        """Add per-position embeddings to packed patch features.

        Args:
            x: (sum(t*h*w), dim) packed patch embeddings.
            grid_thws: (N, 3) rows of (t, h, w) per sample.

        Returns:
            x plus its positional embedding, same shape as x.
        """
        pos_embs = []
        for t, h, w in grid_thws.tolist():
            assert t <= self.num_frames, f't:{t} > self.num_frames:{self.num_frames}'
            if (h, w) == self.weight.shape[:-1]:
                # Native resolution: use the learnable table directly.
                pos_emb_2d = self.weight.flatten(end_dim=1)
            else:
                # Resize the spatial table to this sample's (h, w) grid.
                pos_emb_2d = get_rope_shape(
                    self.weight,
                    interpolation_mode=self.interpolation_mode,
                    shape=(h, w),
                )

            if t == 1:
                # Single frame: no temporal component needed.
                pos_emb_3d = pos_emb_2d
            else:
                # Broadcast the spatial table over frames and add the
                # per-frame temporal offsets.
                pos_emb_3d = pos_emb_2d.unsqueeze(0).repeat(
                    t, 1, 1) + self.time_weight[0:t]

            pos_embs.append(pos_emb_3d.reshape(-1, pos_emb_3d.shape[-1]))

        out = x + torch.cat(pos_embs)
        return out
300
+
301
+
302
class MoonVision3dPatchEmbed(nn.Module):
    """Patch embedding: a Conv2d projection per patch followed by divided
    space-time positional embeddings."""

    def __init__(self,
                 out_dim: int,
                 in_dim: int = 3,
                 patch_size: int | tuple[int, int] = (14, 14),
                 pos_emb_height: int = 14,
                 pos_emb_width: int = 14,
                 pos_emb_time: int = 4,
                 pos_emb_type: str = 'divided_fixed'):
        super().__init__()
        # Normalize patch_size to an (h, w) tuple.
        assert isinstance(
            patch_size,
            int | Sequence), f'Invalid patch_size type: {type(patch_size)}'
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        assert (len(patch_size) == 2
                ), f'Expected patch_size to be a tuple of 2, got {patch_size}'
        self.patch_size = patch_size

        # kernel == stride == patch_size, so each patch yields exactly one
        # output position.
        self.proj = nn.Conv2d(in_dim,
                              out_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

        if pos_emb_type == 'divided_fixed':
            self.pos_emb = Learnable2DInterpPosEmbDivided_fixed(
                height=pos_emb_height,
                width=pos_emb_width,
                num_frames=pos_emb_time,
                dim=out_dim)
        else:
            raise NotImplementedError(
                f'Not support pos_emb_type: {pos_emb_type}')

    def forward(self, x: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (L, Channels): input tensor
                NOTE(review): the Conv2d projection followed by
                `.view(x.size(0), -1)` suggests x is actually
                (L, in_dim, patch_h, patch_w) pixel patches — confirm
                against the caller.
            grid_thws (N, 3): temporal, height and width

        Returns:
            (L, Cout) tensor
        """
        x = self.proj(x).view(x.size(0), -1)
        # apply positional embedding
        x = self.pos_emb(x, grid_thws)
        return x
351
+
352
+
353
class Rope2DPosEmbRepeated(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    This class is intended to be used in the following way:
    1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
    2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
    3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
    The rope is shared across all attention layers and all heads.

    Refs:
        - RoFormer: https://arxiv.org/abs/2104.09864
        - VisionLLaMA: https://arxiv.org/abs/2403.00522
        - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
        max_height (int): the maximum height of the 2D grid
        max_width (int): the maximum width of the 2D grid
        theta_base (float): the base of the theta
        device (str): the device to store the precomputed cis
    """

    def __init__(self,
                 dim: int,
                 max_height: int,
                 max_width: int,
                 theta_base=10000):
        super().__init__()
        self.dim = dim
        assert self.dim % 4 == 0, 'dim must be divisible by 4'
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base

    def extra_repr(self):
        # Shown in repr(module) for debugging.
        return f'dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}'

    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
        """Calculate the cis(freqs) for each position in the 2D grid.

        Return: complex tensor of shape (max_height, max_width, dim//2) and value:
            height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
            weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
        note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
        """
        N = self.max_height * self.max_width
        # Flattened grid index -> (x, y) coordinates.
        flat_pos = torch.arange(0, N).float().to(device)
        x_pos = flat_pos % self.max_width
        y_pos = flat_pos // self.max_width
        # One frequency per pair of (height, width) channels.
        dim_range = (torch.arange(0, self.dim,
                                  4)[:(self.dim // 4)].float().to(device)
                     )  # C/4
        freqs = 1.0 / (self.theta_base**(dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        # Unit-magnitude complex rotations for each axis.
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # Interleave x/y rotations: N, C/4, 2
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1),
             y_cis.unsqueeze(dim=-1)], dim=-1)
        # max_height, max_width, C/2
        freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
        return freqs_cis

    def get_freqs_cis(self, grid_thws: torch.Tensor,
                      device: torch.device) -> torch.Tensor:
        """
        Args:
            grid_thws (torch.Tensor): grid time, height and width

        Returns:
            freqs_cis: tensor of shape (sum(t * height * width), dim//2)
        """
        # Lazily precompute and cache the full-grid table on first use;
        # registered as a non-persistent buffer so it follows .to()/device
        # moves but is not saved in checkpoints.
        if not hasattr(self, 'freqs_cis'):
            self.register_buffer('freqs_cis',
                                 self._precompute_freqs_cis(device),
                                 persistent=False)

        shapes = grid_thws.tolist()
        assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width
                   for t, h, w in shapes), (
                       shapes,
                       self.max_height,
                       self.max_width,
                   )
        # Crop the table to each sample's (h, w) grid and repeat it across
        # that sample's t frames, then concatenate all samples.
        freqs_cis = torch.cat(
            [
                self.freqs_cis[:h, :w].reshape(-1, self.dim // 2).repeat(t, 1)
                for t, h, w in shapes
            ],
            dim=0,
        )
        return freqs_cis
447
+
448
+
449
class MLP2(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear.

    Args:
        dims: [in_dim, hidden_dim, out_dim]
        activation: elementwise activation callable applied between layers.
        bias: whether to use bias in linear layer.
    """

    def __init__(self, dims: list[int], activation, bias=True):
        super().__init__()
        assert len(dims) == 3
        in_dim, hidden_dim, out_dim = dims
        self.fc0 = nn.Linear(in_dim, hidden_dim, bias=bias)
        self.fc1 = nn.Linear(hidden_dim, out_dim, bias=bias)
        self.activation = activation
        # Truncated-normal weights scaled by fan-in; zero-initialized biases.
        for layer in (self.fc0, self.fc1):
            nn.init.trunc_normal_(layer.weight,
                                  std=math.sqrt(2 / layer.in_features))
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project through fc0, apply the activation, project through fc1."""
        return self.fc1(self.activation(self.fc0(x)))
471
+
472
+
473
class MoonViTEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer for the vision tower.

    Structure: LayerNorm -> packed-QKV attention (+residual) ->
    LayerNorm -> MLP (+residual). Attention is bidirectional within each
    packed sequence and uses 2-D rotary position embeddings.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        *,
        attn_implementation: str = 'flash_attention_2',
        activation=F.gelu,
        attn_bias: bool = False,
        use_deterministic_attn: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        # Per-head dimension; hidden_dim is assumed divisible by num_heads.
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        # Key into VL_VISION_ATTENTION_FUNCTIONS ('flash_attention_2'/'eager').
        self.attn_implementation = attn_implementation
        self.use_deterministic_attn = use_deterministic_attn

        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
        # Fused QKV projection and output projection.
        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)

    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor,
        rope_freqs_cis: torch.Tensor | None = None,
    ):
        """Self-attention over packed sequences with a fused QKV projection.

        Args:
            x (torch.Tensor): (tot_seqlens, hidden_dim) packed hidden states.
            cu_seqlens (torch.Tensor): cumulative sequence lengths delimiting
                the packed sequences (int32, starts at 0).
            max_seqlen: length of the longest packed sequence.
            rope_freqs_cis: precomputed rotary factors for every position.

        Returns:
            (tot_seqlens, hidden_dim) attention output.
        """
        xqkv = self.wqkv(x)

        qkv_shape = xqkv.size()[:-1] + (
            3,
            self.num_heads,
            self.hidden_size_per_attention_head,
        )
        # xqkv: (batch_size, seqlen, 3, nheads, headdim)
        xqkv = xqkv.view(*qkv_shape)
        xq, xk, xv = torch.unbind(xqkv, dim=-3)

        # Rotate q/k by their 2-D rotary position factors.
        xq, xk = apply_rope(xq, xk, rope_freqs_cis)

        # Dispatch to the configured attention backend.
        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
        attn_out = attn_func(xq,
                             xk,
                             xv,
                             q_cu_seqlens=cu_seqlens,
                             k_cu_seqlens=cu_seqlens,
                             max_seqlen_k=max_seqlen,
                             max_seqlen_q=max_seqlen,
                             deterministic=self.use_deterministic_attn)

        attn_out = self.wo(attn_out)
        return attn_out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        rope_freqs_cis: torch.Tensor | None = None,
    ):
        # Attention sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.norm0(hidden_states)

        hidden_states = self.attention_qkvpacked(hidden_states, cu_seqlens,
                                                 max_seqlen, rope_freqs_cis)
        hidden_states = residual + hidden_states

        # MLP sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
557
+
558
+
559
class MoonViT3dEncoder(nn.Module):
    """Stack of MoonViTEncoderLayer blocks over packed 3-D (t, h, w) patches.

    A single Rope2DPosEmbRepeated instance (max grid 512x512) supplies the
    rotary factors shared by every layer; a final LayerNorm closes the stack.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_layers: int,
                 block_cfg: dict,
                 video_attn_type: str = 'spatial_temporal') -> None:
        super().__init__()

        # Only the joint spatial-temporal attention layout is implemented.
        assert video_attn_type == 'spatial_temporal', f'video_attn_type must be "spatial_temporal", got {video_attn_type}'
        self.video_attn_type = video_attn_type
        # Rotary table sized per attention head, shared by all blocks.
        self.rope_2d = Rope2DPosEmbRepeated(
            block_cfg['hidden_dim'] // block_cfg['num_heads'], 512, 512)
        self.blocks = nn.ModuleList([
            MoonViTEncoderLayer(
                **block_cfg,
            )
            for _ in range(num_layers)
        ])
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        """Run the encoder over packed patch embeddings.

        Args:
            hidden_states: (sum(t*h*w), hidden_dim) packed patch features.
            grid_thws: (N, 3) rows of (t, h, w) per sample.

        Returns:
            (sum(t*h*w), hidden_dim) encoded features.
        """
        rope_freqs_cis = self.rope_2d.get_freqs_cis(
            grid_thws=grid_thws, device=hidden_states.device)

        # Per-sample token counts, with a leading 0 so cumsum yields the
        # cu_seqlens boundaries expected by the varlen attention kernels.
        lengths = torch.cat((
            torch.zeros(1, dtype=grid_thws.dtype, device=grid_thws.device),
            grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2],
        ))

        max_seqlen = lengths.max()
        cu_seqlens = lengths.to(hidden_states.device).cumsum(dim=0,
                                                             dtype=torch.int32)
        for block in self.blocks:
            hidden_states = block(hidden_states,
                                  cu_seqlens,
                                  max_seqlen,
                                  rope_freqs_cis=rope_freqs_cis)

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states
604
+
605
+
606
def tpool_patch_merger(
    x: torch.Tensor,
    grid_thws: torch.Tensor,
    merge_kernel_size: tuple[int, int] = (2, 2),
) -> list[torch.Tensor]:
    """Mean-pool frames over time and group spatial patches into merge windows.

    Args:
        x: Packed patch features of shape (sum(t*h*w), d_model).
        grid_thws: Per-item (t, h, w) grids, shape (num_items, 3).
        merge_kernel_size: Spatial (kernel_h, kernel_w) window to merge.

    Returns:
        One tensor per item, shaped
        (h // kernel_h * w // kernel_w, kernel_h * kernel_w, d_model);
        the temporal axis is reduced by averaging across all t frames.
    """
    d_model = x.size(-1)
    kernel_h, kernel_w = merge_kernel_size

    merged: list[torch.Tensor] = []
    offset = 0
    for t, h, w in grid_thws.tolist():
        num_tokens = t * h * w
        item = x[offset:offset + num_tokens]
        offset += num_tokens

        out_h, out_w = h // kernel_h, w // kernel_w
        # Expose the merge windows as separate axes...
        windows = item.view(t, out_h, kernel_h, out_w, kernel_w, d_model)
        # ...bring each window's cells together, then average the time axis.
        windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().mean(dim=0)
        merged.append(windows.view(out_h * out_w, kernel_h * kernel_w, -1))

    return merged
632
+
633
+
634
class MoonViT3dPretrainedModel(PreTrainedModel):
    """Vision tower: 3D patch embedding + packed ViT encoder + patch merging."""

    config_class = None
    model_type = 'moonvit3d'
    _no_split_modules = ['PackingTransformer']
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # Work on a private copy so nothing below can mutate the caller's config.
        config = deepcopy(config)
        self.merge_kernel_size = config.merge_kernel_size
        self.patch_size = config.patch_size
        self.merge_type = config.merge_type

        self.patch_embed = MoonVision3dPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
            pos_emb_time=config.init_pos_emb_time,
            pos_emb_type=config.pos_emb_type,
        )

        encoder_block_cfg = {
            'num_heads': config.num_attention_heads,
            'hidden_dim': config.hidden_size,
            'mlp_dim': config.intermediate_size,
            'activation': PytorchGELUTanh(),
            'attn_bias': True,
            'attn_implementation': config._attn_implementation,
        }
        self.encoder = MoonViT3dEncoder(
            hidden_dim=config.hidden_size,
            num_layers=config.num_hidden_layers,
            block_cfg=encoder_block_cfg,
            video_attn_type=config.video_attn_type,
        )

    def forward(self, pixel_values: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        """Embed pixels, run the encoder, then merge patches.

        Args:
            pixel_values (torch.Tensor): The input pixel values.
            grid_thws (torch.Tensor): Per-item (t, h, w) grids, shape (n, 3).

        Returns:
            Merged output tokens (a list of per-item tensors from the merger).
        """
        assert grid_thws.ndim == 2, f'grid_thws should be 2D, got {grid_thws.ndim}'
        assert grid_thws.size(1) == 3, f'No support for thw: {grid_thws}'

        tokens = self.patch_embed(pixel_values, grid_thws)
        tokens = self.encoder(tokens, grid_thws)

        if self.merge_type != 'sd2_tpool':
            raise NotImplementedError(f'Not support {self.merge_type}')
        # 'sd2_tpool': spatial downsampling 2x + temporal pooling over all frames.
        return tpool_patch_merger(tokens,
                                  grid_thws,
                                  merge_kernel_size=self.merge_kernel_size)
699
+
700
+
701
+ # ============================================================================
702
+ # MM Projector Helper Classes (from mm_projector/modeling_mm_projectors.py)
703
+ # ============================================================================
704
+
705
+
706
class IdentityMap(nn.Module):
    """No-op projector: returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        # Pass-through; extra positional/keyword arguments are accepted
        # (for interface compatibility with real projectors) and ignored.
        return x
713
+
714
+
715
+ class MLP(nn.Module):
716
+
717
+ def __init__(self, config):
718
+ super().__init__()
719
+ # TODO, use faster LayerNorm
720
+ self.pre_norm = nn.LayerNorm(config.mm_hidden_size)
721
+ self.proj = nn.Sequential(
722
+ nn.Linear(config.mm_hidden_size, config.hidden_size), nn.GELU(),
723
+ nn.Linear(config.hidden_size, config.hidden_size))
724
+
725
+ def forward(self, x, *args, **kwargs):
726
+ assert isinstance(x,
727
+ list | tuple), f'x is not a list or tuple: {type(x)}'
728
+ lengths = [item.shape[0] for item in x]
729
+ x = torch.cat(x, dim=0)
730
+ x = self.pre_norm(x)
731
+ x = self.proj(x)
732
+ x = torch.split(x, lengths, dim=0)
733
+
734
+ return x
735
+
736
+
737
class PatchMergerMLP(nn.Module):
    """Projector that flattens each spatial merge window before an MLP."""

    def __init__(self, config):
        super().__init__()
        eps = config.projector_ln_eps
        # Each merged token concatenates kernel_h * kernel_w vision features.
        kernel_h = config.merge_kernel_size[0]
        kernel_w = config.merge_kernel_size[1]
        self.hidden_size = config.mm_hidden_size * (kernel_h * kernel_w)
        self.pre_norm = nn.LayerNorm(config.mm_hidden_size, eps=eps)
        self.proj = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, config.hidden_size),
        )

    def forward(self, x, *args, **kwargs):
        """Normalize, flatten the window axis into the channel axis, project.

        Accepts either a list/tuple of (n_i, n_k, c) tensors (returns a list)
        or a single batched tensor (returns a tensor).
        """
        if isinstance(x, (list, tuple)):
            return [
                self.proj(self.pre_norm(item).view(item.shape[0], -1))
                for item in x
            ]
        # Batched path; presumably x is (B, N, N_k, C) -- inferred from the
        # view into (B, -1, self.hidden_size) below.
        batch = x.shape[0]
        return self.proj(self.pre_norm(x).view(batch, -1, self.hidden_size))
762
+
763
+
764
class KimiK25PreTrainedModel(PreTrainedModel):
    """Base class wiring the Kimi K2.5 config and weight initialization."""

    config_class = KimiK25Config
    base_model_prefix = "model"
    _no_split_modules = [
        "MoonViT3dPretrainedModel",
        "MoonViTEncoderLayer",
        "DeepseekDecoderLayer",
        "PatchMergerMLP",
    ]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = False

    def _init_weights(self, module):
        # important: this ported version of Llava isn't meant for training from
        # scratch - only inference and fine-tuning - so the proper init weights
        # code has been removed; the original codebase
        # https://github.com/haotian-liu/LLaVA/tree/main/llava serves that purpose.
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            # Composite configs keep initializer_range on the text sub-config.
            std = self.config.text_config.initializer_range

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Padding embeddings stay zero so they contribute nothing.
                module.weight.data[module.padding_idx].zero_()
796
+
797
+
798
class VisionTowerConfig(PretrainedConfig):
    """Standalone config for the MoonViT3d vision tower, extracted from the
    composite KimiK25Config (``vt_*`` fields are renamed to plain names)."""

    model_type = 'moonvit3d'

    def __init__(self, config: KimiK25Config, **kwargs):
        super().__init__(**kwargs)
        # Fields copied verbatim from the parent config.
        for name in ('patch_size', 'init_pos_emb_height',
                     'init_pos_emb_width', 'init_pos_emb_time',
                     'pos_emb_type', 'merge_kernel_size', 'video_attn_type',
                     'merge_type', '_attn_implementation'):
            setattr(self, name, getattr(config, name))
        # Fields whose vt_ prefix is dropped for the vision tower.
        self.num_attention_heads = config.vt_num_attention_heads
        self.num_hidden_layers = config.vt_num_hidden_layers
        self.hidden_size = config.vt_hidden_size
        self.intermediate_size = config.vt_intermediate_size
816
+
817
+
818
class ProjectorConfig:
    """Plain-attribute view of the projector-related fields of KimiK25Config."""

    def __init__(self, config: KimiK25Config):
        # Fields copied verbatim from the parent config.
        for name in ('mm_projector_type', 'mm_hidden_size',
                     'merge_kernel_size', 'projector_hidden_act',
                     'projector_ln_eps'):
            setattr(self, name, getattr(config, name))
        # The projector output lives in the text embedding space.
        self.hidden_size = config.text_hidden_size
827
+
828
+
829
# ref https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/llava/modeling_llava.py#L240
class KimiK25ForConditionalGeneration(KimiK25PreTrainedModel):
    """Llava-style multimodal causal LM: a MoonViT3d vision tower, an
    mm_projector mapping vision features into the text embedding space, and a
    DeepseekV3 causal language model. Most LM-facing methods delegate to
    ``self.language_model``."""

    def __init__(self, config: KimiK25Config):
        super().__init__(config)

        # Vision tower built from the vision sub-config.
        vt_config = VisionTowerConfig(config.vision_config)
        self.vision_tower = MoonViT3dPretrainedModel(vt_config)

        # Projector selection is driven by mm_projector_type.
        proj_config = ProjectorConfig(config.vision_config)
        if proj_config.mm_projector_type == 'identity':
            self.mm_projector = IdentityMap()
        elif proj_config.mm_projector_type == 'mlp':
            self.mm_projector = MLP(proj_config)
        elif proj_config.mm_projector_type == 'patchmerger':
            self.mm_projector = PatchMergerMLP(proj_config)
        else:
            raise ValueError(
                f"Unsupported mm_projector_type: {proj_config.mm_projector_type}"
            )

        self.language_model = DeepseekV3ForCausalLM(config.text_config)
        self.post_init()

        # Keep vision tower and projector in the same dtype as the LM so the
        # merged embeddings are dtype-consistent.
        if hasattr(self.language_model, 'dtype'):
            target_dtype = self.language_model.dtype
            self.vision_tower = self.vision_tower.to(dtype=target_dtype)
            self.mm_projector = self.mm_projector.to(dtype=target_dtype)

    # --- Thin delegation wrappers to the underlying language model. ---

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(self,
                                new_num_tokens: int | None = None,
                                pad_to_multiple_of=None) -> nn.Embedding:
        """Resize the LM token embeddings and keep vocab sizes in sync."""
        model_embeds = self.language_model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _merge_input_ids_with_image_features(
        self,
        image_features: list[torch.Tensor],
        inputs_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
    ):
        """Expand each media placeholder token into its image-feature span.

        Builds a new (batch, max_embed_dim, embed_dim) embedding tensor where
        every placeholder token in ``input_ids`` is replaced by the
        corresponding item of ``image_features`` (one item per placeholder,
        matched in flattened order), and text embeddings are scattered to
        their shifted positions.

        Args:
            image_features: One (num_tokens_i, embed_dim) tensor per
                placeholder occurrence.
            inputs_embeds: (batch_size, sequence_length, embed_dim) text
                embeddings.
            input_ids: (batch_size, sequence_length) token ids.
            attention_mask: (batch_size, sequence_length) mask.
            labels: Optional (batch_size, sequence_length) labels; image and
                padding positions are filled with ``ignore_index``.

        Returns:
            Tuple of (final_embedding, final_attention_mask, final_labels,
            position_ids), all expanded to max_embed_dim.
        """
        _, embed_dim = image_features[0].shape
        feature_lengths = [x.shape[0] for x in image_features]
        image_features = torch.cat(image_features, dim=0)

        image_token_index: int = self.config.media_placeholder_token_id
        pad_token_id: int = self.config.pad_token_id
        ignore_index: int = self.config.ignore_index

        batch_size, sequence_length = input_ids.shape
        # Left padding if no sequence ends with a pad token.
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(pad_token_id))

        # 1. Create a mask to know where special image tokens are.
        # Each position occupies 1 slot, except placeholders which occupy
        # as many slots as their image-feature length.
        _token_occupation_table = torch.ones_like(input_ids.flatten())
        _token_occupation_table[input_ids.flatten() ==
                                image_token_index] = torch.tensor(
                                    feature_lengths,
                                    dtype=torch.long,
                                    device=input_ids.device)
        _token_occupation_table = _token_occupation_table.reshape(
            input_ids.shape)

        # Expanded length = largest total occupation across the batch.
        max_embed_dim = _token_occupation_table.sum(-1).max().item()
        assert (
            max_embed_dim >= sequence_length
        ), f"The maximum embedding dimension ({max_embed_dim}) is less than the sequence length ({sequence_length})"
        batch_indices, non_image_indices = torch.where(
            input_ids != image_token_index)

        # 2. Compute the positions where text should be written.
        # Calculate new positions for text tokens in merged image-text sequence.
        new_token_positions = torch.cumsum(_token_occupation_table, -1) - 1
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:,
                                                None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices,
                                                non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position.
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(batch_size,
                                           max_embed_dim,
                                           dtype=attention_mask.dtype,
                                           device=inputs_embeds.device)
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim),
                ignore_index,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # In case the Vision model or the Language model has been offloaded to
        # CPU, we need to manually set the corresponding tensors into their
        # correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask.
        final_embedding[batch_indices,
                        text_to_overwrite] = inputs_embeds[batch_indices,
                                                           non_image_indices]
        final_attention_mask[batch_indices,
                             text_to_overwrite] = attention_mask[
                                 batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices,
                         text_to_overwrite] = labels[batch_indices,
                                                     non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is
        # not `text_positions` needs filling (#29835).
        image_to_overwrite = torch.full((batch_size, max_embed_dim),
                                        True,
                                        dtype=torch.bool,
                                        device=inputs_embeds.device)
        image_to_overwrite[batch_indices, text_to_overwrite] = False
        # Exclude the left-padding slots from the image positions.
        image_to_overwrite &= image_to_overwrite.cumsum(
            -1) - 1 >= nb_image_pad[:, None].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model are wrong. The number of image tokens is {image_to_overwrite.sum()} while"
                f" the number of image features given to the model is {image_features.shape[:-1].numel()}. "
                "This prevents correct indexing and breaks batch generation.")

        final_embedding[image_to_overwrite] = (
            image_features.contiguous().reshape(-1,
                                                embed_dim).to(target_device))
        final_attention_mask |= image_to_overwrite
        # Non-attended (masked) positions get position id 1.
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
            (final_attention_mask == 0), 1)

        # 6. Mask out the embedding at padding positions, as we later use the
        # past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels, position_ids

    def _extract_image_features(self, pixel_values: torch.Tensor,
                                grid_thws: torch.Tensor) -> list[torch.Tensor]:
        """Run the vision tower over the input pixels.

        Args:
            pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
                The pixel values of the images processed by image processor.
            grid_thws (:obj:`torch.Tensor` of shape :obj:`(batch_size, 3)`):
                The temporal, height and width patch grid of the images.

        Returns:
            The image features to use as input to the projector head
            (per-item tensors of shape (num_image_tokens, embed_dim)).
        """
        # Match the vision tower's parameter dtype (it may differ from input).
        target_dtype = self.vision_tower.patch_embed.proj.weight.dtype
        pixel_values = pixel_values.to(target_dtype)

        image_features = self.vision_tower(pixel_values, grid_thws)
        return image_features

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | list[torch.FloatTensor]
        | None = None,
        grid_thws: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | LlavaCausalLMOutputWithPast:
        r"""Multimodal forward pass: embed text, splice in projected image
        features at placeholder positions, then run the language model.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        assert self.vision_tower is not None, "vision_tower is not loaded"
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            # 1. Extract the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # 2. Merge text and images (prefill step: full prompt present).
            if pixel_values is not None and len(
                    pixel_values) > 0 and input_ids.shape[1] != 1:
                image_features = self._extract_image_features(
                    pixel_values, grid_thws)
                if self.mm_projector:
                    image_features = self.mm_projector(image_features)

                inputs_embeds = inputs_embeds.to(
                    image_features[0].dtype)  # num_tokens, embed_dim
                inputs_embeds, attention_mask, labels, position_ids = (
                    self._merge_input_ids_with_image_features(
                        image_features,
                        inputs_embeds,
                        input_ids,
                        attention_mask,
                        labels,
                    ))

            # In case input_ids.shape[1] == 1 & pixel_values==None &
            # past_key_values != None, we are in the case of generation with
            # cache (decode step: rebuild the attention mask over the cache).
            elif (past_key_values is not None and pixel_values is not None
                  and input_ids.shape[1] == 1):
                # Retrieve the first layer to inspect the logits and mask out
                # the hidden states that are set to 0.
                # NOTE(review): assumes legacy tuple cache layout
                # past_key_values[layer][key] -- confirm against cache impl.
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors
                # such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(
                    first_layer_past_key_value.float().sum(-2) == 0)

                # Get the target length
                target_length = input_ids.shape[1]
                past_length = first_layer_past_key_value.shape[-1]

                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

                # Filter out only the tokens that can be un-attended, this can
                # happen if one uses Llava + Fused modules where the cache on
                # the first iteration is already big enough, or if one passes
                # custom cache.
                valid_indices = non_attended_tokens < extended_attention_mask.size(
                    -1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index,
                                        new_non_attended_tokens] = 0

                attention_mask = torch.cat(
                    (extended_attention_mask, attention_mask[:,
                                                             -target_length:]),
                    dim=1)
                position_ids = torch.sum(attention_mask,
                                         dim=1).unsqueeze(-1) - 1

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # Only keep positions that are attended to.
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(
                    logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(
                    labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1).to(shift_logits.device),
            )

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        grid_thws=None,
        attention_mask=None,
        **kwargs,
    ):
        """Trim inputs to the unprocessed suffix and assemble generation kwargs
        (standard HF generation hook, Llava-style)."""
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = getattr(past_key_values, 'seen_tokens',
                                      cache_length)
            else:
                # Legacy tuple cache: length is the key tensor's seq dim.
                cache_length = past_length = past_key_values[0][0].shape[2]

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of
            # input_ids, then we are in a setting where some of the inputs are
            # exclusively passed as part of the cache (e.g. when passing
            # input_embeds as input).
            if attention_mask is not None and attention_mask.shape[
                    1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] -
                                           past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids
            # holds all input tokens. We can discard input_ids based on the
            # past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume
            # input_ids only has unprocessed tokens.
            elif self.config.media_placeholder_token_id in input_ids:
                input_ids = input_ids[:, input_ids.shape[1] - 1:]
            # If the cache has seen more tokens than it can hold, then the
            # cache has a size limit. Let's discard the older attention values,
            # as their corresponding values are not part of the input.
            if cache_length < past_length and attention_mask is not None:
                attention_mask = attention_mask[:, -(cache_length +
                                                     input_ids.shape[1]):]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st
        # generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "grid_thws": grid_thws,
        })
        return model_inputs

    def _reorder_cache(self, *args, **kwargs):
        # Beam-search cache reordering is delegated to the language model.
        return self.language_model._reorder_cache(*args, **kwargs)
preprocessor_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "kimi_k25_processor.KimiK25Processor",
4
+ "AutoImageProcessor": "kimi_k25_vision_processing.KimiK25VisionProcessor"
5
+ },
6
+ "media_proc_cfg": {
7
+ "in_patch_limit": 16384,
8
+ "patch_size": 14,
9
+ "image_mean": [
10
+ 0.5,
11
+ 0.5,
12
+ 0.5
13
+ ],
14
+ "image_std": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "merge_kernel_size": 2,
20
+ "fixed_output_tokens": null,
21
+ "patch_limit_on_one_side": 512,
22
+ "in_patch_limit_each_frame": 4096,
23
+ "in_patch_limit_video": null,
24
+ "sample_fps": 2.0,
25
+ "max_num_frames_each_video": null,
26
+ "temporal_merge_kernel_size": 4,
27
+ "timestamp_mode": "hh:mm:ss.fff",
28
+ "config_type": "media_proc.processors.moonvit.MoonViTMediaProcessorConfig"
29
+ }
30
+ }
tiktoken.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6c497a7469b33ced9c38afb1ad6e47f03f5e5dc05f15930799210ec050c5103
3
+ size 2795286
tokenization_kimi.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import OrderedDict
3
+ from logging import getLogger
4
+ from pathlib import Path
5
+ from shutil import copyfile
6
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
7
+
8
+ import tiktoken
9
+ from tiktoken.load import load_tiktoken_bpe
10
+ from tokenizers import AddedToken
11
+ from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
12
+ from transformers.tokenization_utils import PreTrainedTokenizer
13
+
14
+ from .tool_declaration_ts import encode_tools_to_typescript_style
15
+
16
+ logger = getLogger(__name__)
17
+ VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
18
+
19
+
20
class TikTokenTokenizer(PreTrainedTokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            The path to the Tiktoken model file.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[BOS]"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[EOS]"`):
            The end of sequence token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead. Must resolve to one of the reserved special tokens.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
            The token used for padding, for example when batching sequences of different lengths.
        additional_special_tokens (list of `str`, *optional*):
            A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
            skipped when decoding if `skip_special_tokens` is set to `True`.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    model_input_names = ["input_ids", "attention_mask"]

    # Filled in __init__: special-token string -> token id.
    special_tokens: Dict[str, int]

    # Number of ids reserved after the base BPE vocabulary for special tokens.
    num_reserved_special_tokens = 256

    # Pre-tokenization regex. NOTE(review): it uses character-class
    # intersection (`&&`), which tiktoken's regex engine accepts but Python's
    # `re` module does not — keep this string byte-identical.
    pat_str = "|".join([
        r"""[\p{Han}]+""",
        r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
        r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
        r"""\p{N}{1,3}""",
        r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
        r"""\s*[\r\n]+""",
        r"""\s+(?!\S)""",
        r"""\s+""",
    ])

    def __init__(
        self,
        vocab_file,
        bos_token: Union[str, AddedToken] = "[BOS]",
        eos_token: Union[str, AddedToken] = "[EOS]",
        unk_token: Union[str, AddedToken, None] = None,
        pad_token: Union[str, AddedToken, None] = None,
        additional_special_tokens: List[str] = None,
        added_tokens_decoder: Optional[dict] = None,
        **kwargs,
    ):
        """Load the tiktoken BPE model and build the special-token tables."""
        assert os.path.isfile(vocab_file), vocab_file

        if additional_special_tokens is None:
            additional_special_tokens = [
                "<|im_end|>",
                "<|im_user|>",
                "<|im_assistant|>",
                "<|start_header_id|>",
                "<|end_header_id|>",
                "[EOT]",
                "<|im_system|>",
                "<|im_middle|>",
            ]

        # id -> token string for tokens declared in tokenizer_config.json.
        if added_tokens_decoder:
            special_tokens_mapping = {
                i: added_tokens_decoder[i].content
                for i in added_tokens_decoder
            }
        else:
            special_tokens_mapping = {}

        self.vocab_file = vocab_file
        mergeable_ranks = load_tiktoken_bpe(vocab_file)
        num_base_tokens = len(mergeable_ranks)
        # Reserve a fixed window of ids after the BPE vocab; ids without an
        # explicit name become "<|reserved_token_<id>|>".
        self.special_tokens = {
            special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
            for i in range(num_base_tokens, num_base_tokens +
                           self.num_reserved_special_tokens)
        }

        self.model = tiktoken.Encoding(
            name=Path(vocab_file).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {vocab_file}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens[str(bos_token)]
        self.eos_id: int = self.special_tokens[str(eos_token)]
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

        self.pad_id: int = self.special_tokens[str(pad_token)]
        self.unk_id: int = self.special_tokens[str(unk_token)]

        # Byte <-> printable-unicode maps (GPT-2 byte-level convention).
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        # id -> byte-level token string for every token in the vocab.
        self.decoder = {}
        for i in range(self.n_words):
            # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
            decoding = ''.join([
                self.byte_encoder[ord(char)] for char in
                self.model.decode_single_token_bytes(i).decode('latin-1')
            ])
            self.decoder[i] = decoding

        # Inverse mapping: byte-level token string -> id.
        self.encoder = {}
        for i in range(self.n_words):
            if i in self.decoder:
                self.encoder[self.decoder[i]] = i

        # LRU-style cache bookkeeping; NOTE(review): not used in the code
        # visible in this file — confirm usage elsewhere before removing.
        self._token_config_cache = OrderedDict()
        self._cache_max_size = 128

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            added_tokens_decoder=added_tokens_decoder,
            **kwargs,
        )
        self.all_special_ids_set = set(self.all_special_ids)
154
+
155
+ def encode(self,
156
+ text: str,
157
+ allow_special_tokens: bool = True,
158
+ **kwargs) -> List[int]:
159
+ """
160
+ Encodes a string into a list of token IDs.
161
+
162
+ Args:
163
+ text (str): The input string to be encoded.
164
+
165
+ Returns:
166
+ list[int]: A list of token IDs.
167
+ """
168
+ # If there are other args, we should call super().encode because there are a lot of code
169
+ # to handle those args. supper().encode finally will call _tokenize and _convert_token_to_id.
170
+ # NOTE: our encode method is not compatible with the super().encode method,
171
+ # e.g. split_special_tokens' default is True in our encode method.
172
+ if len(kwargs) > 0:
173
+ logger.warning(f"Calling super().encode with {kwargs}")
174
+ return super().encode(text, **kwargs)
175
+
176
+ assert type(text) is str
177
+
178
+ # The tiktoken tokenizer can handle <=400k chars without
179
+ # pyo3_runtime.PanicException.
180
+ TIKTOKEN_MAX_ENCODE_CHARS = 400_000
181
+
182
+ # https://github.com/openai/tiktoken/issues/195
183
+ # Here we iterate over subsequences and split if we exceed the limit
184
+ # of max consecutive non-whitespace or whitespace characters.
185
+ MAX_NO_WHITESPACES_CHARS = 25_000
186
+
187
+ texts = self.pre_tokenizer_process(text)
188
+
189
+ all_substrs = []
190
+ for text in texts:
191
+ substrs = (
192
+ substr for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
193
+ for substr in self._split_whitespaces_or_nonwhitespaces(
194
+ text[i:i +
195
+ TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS))
196
+ all_substrs.extend(substrs)
197
+
198
+ t: List[int] = []
199
+ for substr in all_substrs:
200
+ if allow_special_tokens:
201
+ t.extend(
202
+ # we should consider special token as a common token
203
+ self.model.encode(
204
+ substr,
205
+ allowed_special="all",
206
+ ))
207
+ else:
208
+ t.extend(
209
+ # we should consider special token as a common token
210
+ self.model.encode(
211
+ substr,
212
+ disallowed_special=(),
213
+ ))
214
+
215
+ return t
216
+
217
+ def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
218
+ """
219
+ Decodes a list of token IDs into a string.
220
+
221
+ Args:
222
+ token_ids (List[int]): The list of token IDs to be decoded.
223
+
224
+ Returns:
225
+ str: The decoded string.
226
+ """
227
+ # If there are other args, we should call super().decode because there are a lot of code
228
+ # to handle those args. supper().encode finally will call convert_tokens_to_string and _convert_id_to_token.
229
+ if len(kwargs) > 0:
230
+ return super().decode(token_ids, **kwargs)
231
+
232
+ if type(token_ids) is int:
233
+ token_ids = [token_ids]
234
+
235
+ return self.model.decode(cast(List[int], token_ids))
236
+
237
+ @staticmethod
238
+ def _split_whitespaces_or_nonwhitespaces(
239
+ s: str, max_consecutive_slice_len: int) -> Iterator[str]:
240
+ """
241
+ Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
242
+ consecutive whitespaces or consecutive non-whitespaces.
243
+ """
244
+ current_slice_len = 0
245
+ current_slice_is_space = s[0].isspace() if len(s) > 0 else False
246
+ slice_start = 0
247
+
248
+ for i in range(len(s)):
249
+ is_now_space = s[i].isspace()
250
+
251
+ if current_slice_is_space ^ is_now_space:
252
+ current_slice_len = 1
253
+ current_slice_is_space = is_now_space
254
+ else:
255
+ current_slice_len += 1
256
+ if current_slice_len > max_consecutive_slice_len:
257
+ yield s[slice_start:i]
258
+ slice_start = i
259
+ current_slice_len = 1
260
+ yield s[slice_start:]
261
+
262
+ def pre_tokenizer_process(self, text: str) -> List[str]:
263
+ """
264
+ pre-tokenizes the input text into a list of tokens.
265
+ This method is used to split the input text into smaller chunks for internal processing.
266
+ """
267
+ return [text]
268
+
269
+ """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
270
+
271
+ @property
272
+ def vocab_size(self) -> int:
273
+ return self.n_words
274
+
275
+ def get_vocab(self) -> Dict[str, int]:
276
+ return self.encoder
277
+
278
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
279
+ return [self.decoder[t] for t in self.encode(text)]
280
+
281
+ def _convert_token_to_id(self, token: str) -> int:
282
+ return self.encoder.get(token, self.unk_id)
283
+
284
+ def _convert_id_to_token(self, index: int) -> str:
285
+ return self.decoder.get(index)
286
+
287
+ @staticmethod
288
+ def clean_up_tokenization(out_string: str) -> str:
289
+ return out_string
290
+
291
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
292
+ text = ''.join(tokens)
293
+ text = bytearray([self.byte_decoder[c]
294
+ for c in text]).decode('utf-8', 'replace')
295
+ return text
296
+
297
+ def save_vocabulary(self,
298
+ save_directory: str,
299
+ filename_prefix: Optional[str] = None) -> Tuple[str]:
300
+ if not os.path.isdir(save_directory):
301
+ raise ValueError(
302
+ f"vocabulary path ({save_directory}) should be a directory")
303
+ out_vocab_file = os.path.join(
304
+ save_directory,
305
+ (filename_prefix + "-" if filename_prefix else "") +
306
+ VOCAB_FILES_NAMES["vocab_file"])
307
+
308
+ if os.path.abspath(self.vocab_file) != os.path.abspath(
309
+ out_vocab_file) and os.path.isfile(self.vocab_file):
310
+ copyfile(self.vocab_file, out_vocab_file)
311
+
312
+ return (out_vocab_file, )
313
+
314
+ def apply_chat_template(self,
315
+ conversation,
316
+ tools: Optional[list[dict]] = None,
317
+ tokenize: bool = False,
318
+ add_generation_prompt: bool = True,
319
+ thinking: bool = True,
320
+ **kwargs):
321
+
322
+ tools = deep_sort_dict(tools)
323
+
324
+ # Convert tools to TypeScript style string if tools are provided
325
+ tools_ts_str = None
326
+ if tools:
327
+ try:
328
+ tools_ts_str = encode_tools_to_typescript_style(tools)
329
+
330
+ except Exception as e:
331
+ print(f"Failed to convert tools to TypeScript style: {e}")
332
+ tools_ts_str = None
333
+
334
+ # Store the TypeScript string in kwargs so it can be accessed by the template
335
+ if tools_ts_str is not None:
336
+ kwargs['tools_ts_str'] = tools_ts_str
337
+ return super().apply_chat_template(
338
+ conversation,
339
+ tools=tools,
340
+ tokenize=tokenize,
341
+ add_generation_prompt=add_generation_prompt,
342
+ thinking=thinking,
343
+ **kwargs)
344
+
345
+
346
def deep_sort_dict(obj: Any) -> Any:
    """Recursively sort all dict keys in `obj`; other values pass through."""
    if isinstance(obj, dict):
        return {key: deep_sort_dict(obj[key]) for key in sorted(obj)}
    if isinstance(obj, list):
        return [deep_sort_dict(element) for element in obj]
    return obj
tokenizer_config.json ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "163584": {
4
+ "content": "[BOS]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "163585": {
12
+ "content": "[EOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "163586": {
20
+ "content": "<|im_end|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "163587": {
28
+ "content": "<|im_user|>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "163588": {
36
+ "content": "<|im_assistant|>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "163590": {
44
+ "content": "<|start_header_id|>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "163591": {
52
+ "content": "<|end_header_id|>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "163593": {
60
+ "content": "[EOT]",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "163594": {
68
+ "content": "<|im_system|>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "163595": {
76
+ "content": "<|tool_calls_section_begin|>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "163596": {
84
+ "content": "<|tool_calls_section_end|>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "163597": {
92
+ "content": "<|tool_call_begin|>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "163598": {
100
+ "content": "<|tool_call_argument_begin|>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "163599": {
108
+ "content": "<|tool_call_end|>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "163601": {
116
+ "content": "<|im_middle|>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "163602": {
124
+ "content": "<|media_begin|>",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "163603": {
132
+ "content": "<|media_content|>",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ },
139
+ "163604": {
140
+ "content": "<|media_end|>",
141
+ "lstrip": false,
142
+ "normalized": false,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": true
146
+ },
147
+ "163605": {
148
+ "content": "<|media_pad|>",
149
+ "lstrip": false,
150
+ "normalized": false,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": true
154
+ },
155
+ "163606": {
156
+ "content": "<think>",
157
+ "lstrip": false,
158
+ "normalized": false,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": false
162
+ },
163
+ "163607": {
164
+ "content": "</think>",
165
+ "lstrip": false,
166
+ "normalized": false,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": false
170
+ },
171
+ "163838": {
172
+ "content": "[UNK]",
173
+ "lstrip": false,
174
+ "normalized": false,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": true
178
+ },
179
+ "163839": {
180
+ "content": "[PAD]",
181
+ "lstrip": false,
182
+ "normalized": false,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": true
186
+ }
187
+ },
188
+ "additional_special_tokens": [
189
+ "<|im_end|>",
190
+ "<|im_user|>",
191
+ "<|im_assistant|>",
192
+ "<|start_header_id|>",
193
+ "<|end_header_id|>",
194
+ "[EOT]",
195
+ "<|im_system|>",
196
+ "<|im_middle|>",
197
+ "<|media_begin|>",
198
+ "<|media_content|>",
199
+ "<|media_end|>",
200
+ "<|media_pad|>"
201
+ ],
202
+ "bos_token": "[BOS]",
203
+ "clean_up_tokenization_spaces": false,
204
+ "eos_token": "[EOS]",
205
+ "extra_special_tokens": {},
206
+ "model_max_length": 1000000000000000019884624838656,
207
+ "pad_token": "[PAD]",
208
+ "tokenizer_class": "TikTokenTokenizer",
209
+ "unk_token": "[UNK]",
210
+ "auto_map": {
211
+ "AutoTokenizer": [
212
+ "tokenization_kimi.TikTokenTokenizer",
213
+ null
214
+ ]
215
+ }
216
+ }
tool_declaration_ts.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Encode structured tool declaration to typescript style string.
3
+ """
4
+ import dataclasses
5
+ import json
6
+ import logging
7
+ from collections.abc import Sequence
8
+ from typing import Any
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ _TS_INDENT = " "
13
+ _TS_FIELD_DELIMITER = ",\n"
14
+
15
+
16
class _SchemaRegistry:
    """Registry of named schema definitions used to resolve JSON-Schema $refs."""

    def __init__(self):
        # def-name -> schema fragment, populated from "$defs" sections.
        self.definitions = {}
        # Set to True once a root self-reference ("#") has been resolved.
        self.has_self_ref = False

    def register_definitions(self, defs: dict[str, Any]):
        """Record every definition from a "$defs" mapping (no-op for empty)."""
        if not defs:
            return
        self.definitions.update(defs)

    def resolve_ref(self, ref: str) -> dict[str, Any]:
        """Resolve `ref` to its schema; raise ValueError for unknown refs."""
        if ref == "#":
            # Root self-reference: flag it and return a sentinel the caller
            # recognizes.
            self.has_self_ref = True
            return {"$self_ref": True}
        if ref.startswith("#/$defs/"):
            name = ref.rsplit("/", 1)[-1]
            if name in self.definitions:
                return self.definitions[name]
            raise ValueError(f"Reference not found: {ref}")
        raise ValueError(f"Unsupported reference format: {ref}")
42
+
43
+
44
def _format_description(description: str, indent: str = "") -> str:
    """Render a description as `// `-prefixed lines (blank lines stay blank)."""
    rendered = []
    for line in description.split("\n"):
        rendered.append(f"{indent}// {line}" if line else "")
    return "\n".join(rendered)


class _BaseType:
    """Common base for parsed schema types: holds description + constraints."""

    description: str
    constraints: dict[str, Any]

    def __init__(
        self,
        extra_props: dict[str, Any],
        *,
        allowed_constraint_keys: Sequence[str] = (),
    ):
        self.description = extra_props.get("description", "")
        # Keep only the constraint keys meaningful for this concrete type.
        self.constraints = {
            key: extra_props[key]
            for key in extra_props if key in allowed_constraint_keys
        }

    def to_typescript_style(self, indent: str = "") -> str:
        """Render this type as TypeScript source; implemented by subclasses."""
        raise NotImplementedError

    def format_docstring(self, indent: str) -> str:
        """Render the description and constraints as `//` comment lines."""
        parts = []
        if self.description:
            parts.append(_format_description(self.description, indent))
        if self.constraints:
            ordered = sorted(self.constraints.items(), key=lambda kv: kv[0])
            rendered = ", ".join(f"{k}: {v}" for k, v in ordered)
            parts.append(f"{indent}// {rendered}")
        return "".join(part + "\n" for part in parts)
80
+
81
+
82
class _ParameterTypeScalar(_BaseType):
    """A scalar schema type: string/number/integer/boolean/null/any."""

    type: str

    def __init__(self, type: str, extra_props: dict[str, Any] | None = None):
        self.type = type

        # Only constraints relevant to the scalar kind are retained.
        if type == "string":
            keys = ["maxLength", "minLength", "pattern"]
        elif type in ("number", "integer"):
            keys = ["maximum", "minimum"]
        else:
            keys = []

        super().__init__(extra_props or {}, allowed_constraint_keys=keys)

    def to_typescript_style(self, indent: str = "") -> str:
        # TypeScript has no dedicated integer type.
        return "number" if self.type == "integer" else self.type
102
+
103
+
104
class _ParameterTypeObject(_BaseType):
    # Parsed "object" schema: named properties plus optional
    # additionalProperties (True/False or a parsed schema type).
    properties: list["_Parameter"]
    additional_properties: Any | None = None

    def __init__(self,
                 json_schema_object: dict[str, Any],
                 registry: _SchemaRegistry | None = None):
        """Parse an object schema; registers any $defs with the registry."""
        super().__init__(json_schema_object)

        self.properties = []
        self.additional_properties = None

        if not json_schema_object:
            return

        # Definitions must be registered before properties are parsed so
        # that $refs inside the properties can be resolved.
        if "$defs" in json_schema_object and registry:
            registry.register_definitions(json_schema_object["$defs"])

        self.additional_properties = json_schema_object.get(
            "additionalProperties")
        if isinstance(self.additional_properties, dict):
            self.additional_properties = _parse_parameter_type(
                self.additional_properties, registry)

        if "properties" not in json_schema_object:
            return

        # Properties not listed under "required" are optional.
        required_parameters = json_schema_object.get("required", [])
        optional_parameters = set(
            json_schema_object["properties"].keys()) - set(required_parameters)

        self.properties = [
            _Parameter(
                name=name,
                type=_parse_parameter_type(prop, registry),
                optional=name in optional_parameters,
                default=prop.get("default")
                if isinstance(prop, dict) else None,
            ) for name, prop in json_schema_object["properties"].items()
        ]

    def to_typescript_style(self, indent: str = "") -> str:
        """Render as a TypeScript object literal type."""
        # sort by optional, make the required parameters first
        parameters = [p for p in self.properties if not p.optional]
        opt_params = [p for p in self.properties if p.optional]

        parameters = sorted(parameters, key=lambda p: p.name)
        parameters.extend(sorted(opt_params, key=lambda p: p.name))

        param_strs = []
        for p in parameters:
            one = p.to_typescript_style(indent=indent + _TS_INDENT)
            param_strs.append(one)

        # additionalProperties maps to a TypeScript index signature:
        #   True -> any, False -> never, schema -> its rendered type.
        if self.additional_properties is not None:
            ap_type_str = "any"
            if self.additional_properties is True:
                ap_type_str = "any"
            elif self.additional_properties is False:
                ap_type_str = "never"
            elif isinstance(self.additional_properties, _ParameterType):
                ap_type_str = self.additional_properties.to_typescript_style(
                    indent=indent + _TS_INDENT)
            else:
                raise ValueError(
                    f"Unknown additionalProperties: {self.additional_properties}"
                )
            param_strs.append(
                f"{indent + _TS_INDENT}[k: string]: {ap_type_str}")

        if not param_strs:
            return "{}"

        params_str = _TS_FIELD_DELIMITER.join(param_strs)
        if params_str:
            # add new line before and after
            params_str = f"\n{params_str}\n"
        # always wrap with object
        return f"{{{params_str}{indent}}}"
183
+
184
+
185
class _ParameterTypeArray(_BaseType):
    """An "array" schema; `item` is the element type (any when unspecified)."""

    item: "_ParameterType"

    def __init__(self,
                 json_schema_object: dict[str, Any],
                 registry: _SchemaRegistry | None = None):
        super().__init__(json_schema_object,
                         allowed_constraint_keys=("minItems", "maxItems"))
        items_schema = json_schema_object.get("items")
        if items_schema:
            self.item = _parse_parameter_type(items_schema, registry)
        else:
            self.item = _ParameterTypeScalar(type="any")

    def to_typescript_style(self, indent: str = "") -> str:
        inner_indent = indent + _TS_INDENT
        item_docstring = self.item.format_docstring(inner_indent)
        if not item_docstring:
            # Compact single-line form when the element type has no comments.
            return f"Array<{self.item.to_typescript_style(indent=indent)}>"
        rendered_item = self.item.to_typescript_style(indent=inner_indent)
        return ("Array<\n" + item_docstring + inner_indent + rendered_item +
                "\n" + indent + ">")
207
+
208
+
209
class _ParameterTypeEnum(_BaseType):
    # support scalar types only
    enum: list[str | int | float | bool | None]

    def __init__(self, json_schema_object: dict[str, Any]):
        """Parse an "enum" schema, validating values against any declared type."""
        super().__init__(json_schema_object)
        self.enum = json_schema_object["enum"]

        # Validate enum values against declared type if present
        if "type" in json_schema_object:
            typ = json_schema_object["type"]
            if isinstance(typ, list):
                # Only `[T]` or `[T, "null"]` type unions are accepted.
                if len(typ) == 1:
                    typ = typ[0]
                elif len(typ) == 2:
                    if "null" not in typ:
                        raise ValueError(f"Enum type {typ} is not supported")
                    else:
                        typ = typ[0] if typ[0] != "null" else typ[1]
                else:
                    raise ValueError(f"Enum type {typ} is not supported")
            for val in self.enum:
                if val is None:
                    continue
                if typ == "string" and not isinstance(val, str):
                    raise ValueError(f"Enum value {val} is not a string")
                elif typ == "number" and not isinstance(val, (int, float)):
                    raise ValueError(f"Enum value {val} is not a number")
                elif typ == "integer" and not isinstance(val, int):
                    raise ValueError(f"Enum value {val} is not an integer")
                elif typ == "boolean" and not isinstance(val, bool):
                    raise ValueError(f"Enum value {val} is not a boolean")

    def to_typescript_style(self, indent: str = "") -> str:
        # Strings are quoted; other values use Python str().
        # NOTE(review): booleans therefore render as `True`/`False`, not the
        # TypeScript `true`/`false` — confirm this is intended.
        return " | ".join(
            [f'"{e}"' if isinstance(e, str) else str(e) for e in self.enum])
245
+
246
+
247
class _ParameterTypeAnyOf(_BaseType):
    """An "anyOf" schema, rendered as a TypeScript union."""

    types: list["_ParameterType"]

    def __init__(
        self,
        json_schema_object: dict[str, Any],
        registry: _SchemaRegistry | None = None,
    ):
        super().__init__(json_schema_object)
        self.types = []
        for variant in json_schema_object["anyOf"]:
            self.types.append(_parse_parameter_type(variant, registry))

    def to_typescript_style(self, indent: str = "") -> str:
        rendered = [t.to_typescript_style(indent=indent) for t in self.types]
        return " | ".join(rendered)
264
+
265
+
266
class _ParameterTypeUnion(_BaseType):
    """A schema whose "type" is a list, rendered as a TypeScript union."""

    types: list[str]

    # JSON-Schema type name -> TypeScript spelling.
    _TS_NAMES = {
        "string": "string",
        "number": "number",
        "integer": "number",
        "boolean": "boolean",
        "null": "null",
        "object": "{}",
        "array": "Array<any>",
    }

    def __init__(self, json_schema_object: dict[str, Any]):
        super().__init__(json_schema_object)
        self.types = [self._TS_NAMES[t] for t in json_schema_object["type"]]

    def to_typescript_style(self, indent: str = "") -> str:
        return " | ".join(self.types)
285
+
286
+
287
class _ParameterTypeRef(_BaseType):
    """A "$ref" schema; renders as the referenced interface name."""

    ref_name: str
    is_self_ref: bool = False

    def __init__(self, json_schema_object: dict[str, Any],
                 registry: _SchemaRegistry):
        super().__init__(json_schema_object)

        ref = json_schema_object["$ref"]
        resolved = registry.resolve_ref(ref)

        if resolved.get("$self_ref", False):
            # "#" points back at the root parameters object.
            self.ref_name = "parameters"
            self.is_self_ref = True
        else:
            self.ref_name = ref.rsplit("/", 1)[-1]

    def to_typescript_style(self, indent: str = "") -> str:
        return self.ref_name
306
+
307
+
308
# Closed union of all parsed schema-type representations. Being a PEP 604
# union of classes, it is also usable directly with isinstance().
_ParameterType = (_ParameterTypeScalar
                  | _ParameterTypeObject
                  | _ParameterTypeArray
                  | _ParameterTypeEnum
                  | _ParameterTypeAnyOf
                  | _ParameterTypeUnion
                  | _ParameterTypeRef)
315
+
316
+
317
@dataclasses.dataclass
class _Parameter:
    """
    A parameter in a function, or a field in an object: a parsed type plus
    its name, optionality and default value.
    """

    type: _ParameterType
    name: str = "_"
    optional: bool = True
    default: Any | None = None

    @classmethod
    def parse_extended(cls, attributes: dict[str, Any]) -> "_Parameter":
        """Build a parameter from an attribute dict carrying name/optional/default inline."""
        if not attributes:
            raise ValueError("attributes is empty")

        return cls(
            name=attributes.get("name", "_"),
            type=_parse_parameter_type(attributes),
            optional=attributes.get("optional", False),
            default=attributes.get("default"),
        )

    def to_typescript_style(self, indent: str = "") -> str:
        """Render as `name?: type`, preceded by its `//` comment lines."""
        comments = self.type.format_docstring(indent)

        if self.default is not None:
            # ints/floats/bools keep their Python repr; everything else is
            # JSON-encoded so that e.g. strings come out quoted.
            if isinstance(self.default, (int, float, bool)):
                default_repr = repr(self.default)
            else:
                default_repr = json.dumps(self.default, ensure_ascii=False)
            comments += f"{indent}// Default: {default_repr}\n"

        marker = "?" if self.optional else ""
        rendered_type = self.type.to_typescript_style(indent=indent)
        return comments + f"{indent}{self.name}{marker}: {rendered_type}"
354
+
355
+
356
def _parse_parameter_type(
        json_schema_object: dict[str, Any] | bool,
        registry: _SchemaRegistry | None = None) -> _ParameterType:
    """
    Parse a JSON-Schema fragment (or boolean schema) into a _ParameterType.

    Dispatch order is significant: $ref first, then anyOf, enum, an explicit
    "type", and finally the empty schema (which means "any").
    """
    # Boolean schemas: `true` accepts anything; `false` accepts nothing and
    # is approximated here as null (with a warning).
    if isinstance(json_schema_object, bool):
        if json_schema_object:
            return _ParameterTypeScalar(type="any")
        else:
            logger.warning(
                f"Warning: Boolean value {json_schema_object} is not supported, use null instead."
            )
            return _ParameterTypeScalar(type="null")

    if "$ref" in json_schema_object and registry:
        return _ParameterTypeRef(json_schema_object, registry)

    if "anyOf" in json_schema_object:
        return _ParameterTypeAnyOf(json_schema_object, registry)
    elif "enum" in json_schema_object:
        return _ParameterTypeEnum(json_schema_object)
    elif "type" in json_schema_object:
        typ = json_schema_object["type"]
        if isinstance(typ, list):
            return _ParameterTypeUnion(json_schema_object)
        elif typ == "object":
            return _ParameterTypeObject(json_schema_object, registry)
        elif typ == "array":
            return _ParameterTypeArray(json_schema_object, registry)
        else:
            return _ParameterTypeScalar(typ, json_schema_object)
    elif json_schema_object == {}:
        return _ParameterTypeScalar(type="any")
    else:
        raise ValueError(f"Invalid JSON Schema object: {json_schema_object}")
389
+
390
+
391
def _openai_function_to_typescript_style(function: dict[str, Any], ) -> str:
    """Convert an OpenAI function definition (dict) to a TypeScript-style string.

    Emits any shared interfaces first (including a named root interface
    when the schema references itself), then the function's description,
    then a ``type name = (_: params) => any;`` signature.
    """
    registry = _SchemaRegistry()
    schema = function.get("parameters") or {}
    root_type = _ParameterTypeObject(schema, registry)

    interface_defs = []
    root_name = None

    # A self-referential schema needs a named root interface so the
    # recursion has something to point back to.
    if registry.has_self_ref:
        root_name = "parameters"
        fields = _TS_FIELD_DELIMITER.join(
            prop.to_typescript_style(indent=_TS_INDENT)
            for prop in root_type.properties)
        body = f"\n{fields}\n" if fields else ""
        interface_defs.append(f"interface {root_name} {{{body}}}")

    # Snapshot first: rendering a definition may register further ones.
    for def_name, def_schema in dict(registry.definitions).items():
        rendered = _parse_parameter_type(def_schema, registry).to_typescript_style()
        doc = def_schema.get("description", "")
        prefix = _format_description(doc) + "\n" if doc else ""
        interface_defs.append(f"{prefix}interface {def_name} {rendered}")

    func_name = function.get("name", "function")
    if root_name:
        signature = f"type {func_name} = (_: {root_name}) => any;"
    else:
        signature = f"type {func_name} = (_: {root_type.to_typescript_style()}) => any;"

    func_doc = function.get("description")
    pieces = [
        "\n".join(interface_defs),
        _format_description(func_doc) if func_doc else "",
        signature,
    ]
    # Drop empty segments so blank sections leave no stray newlines.
    return "\n".join(piece for piece in pieces if piece)
def encode_tools_to_typescript_style(tools: list[dict[str, Any]], ) -> str:
    """
    Convert tools (list of dict) to TypeScript style string.

    Supports OpenAI format: {"type": "function", "function": {...}}

    Args:
        tools: List of tool definitions in dict format

    Returns:
        TypeScript style string representation of the tools
    """
    if not tools:
        return ""

    # Only "function"-typed tools with a non-empty definition are
    # rendered; anything else (e.g. a "_plugin" entry) is skipped.
    rendered = [
        _openai_function_to_typescript_style(tool["function"])
        for tool in tools
        if tool.get("type") == "function" and tool.get("function")
    ]

    if not rendered:
        return ""

    body = "\n".join(rendered)
    output = "# Tools\n\n"

    if body:
        output += "## functions\nnamespace functions {\n"
        output += body + "\n"
        output += "}\n"

    return output