ichbinblau commited on
Commit
c94d0b3
·
verified ·
1 Parent(s): bf78a3d

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. chat_template.jinja +101 -0
  2. config.json +894 -0
  3. configuration_deepseek.py +199 -0
  4. generation_config.json +9 -0
  5. model-00001-of-00078.safetensors +3 -0
  6. model-00002-of-00078.safetensors +3 -0
  7. model-00004-of-00078.safetensors +3 -0
  8. model-00005-of-00078.safetensors +3 -0
  9. model-00007-of-00078.safetensors +3 -0
  10. model-00009-of-00078.safetensors +3 -0
  11. model-00010-of-00078.safetensors +3 -0
  12. model-00016-of-00078.safetensors +3 -0
  13. model-00017-of-00078.safetensors +3 -0
  14. model-00019-of-00078.safetensors +3 -0
  15. model-00021-of-00078.safetensors +3 -0
  16. model-00023-of-00078.safetensors +3 -0
  17. model-00024-of-00078.safetensors +3 -0
  18. model-00025-of-00078.safetensors +3 -0
  19. model-00026-of-00078.safetensors +3 -0
  20. model-00029-of-00078.safetensors +3 -0
  21. model-00030-of-00078.safetensors +3 -0
  22. model-00033-of-00078.safetensors +3 -0
  23. model-00034-of-00078.safetensors +3 -0
  24. model-00036-of-00078.safetensors +3 -0
  25. model-00037-of-00078.safetensors +3 -0
  26. model-00040-of-00078.safetensors +3 -0
  27. model-00042-of-00078.safetensors +3 -0
  28. model-00044-of-00078.safetensors +3 -0
  29. model-00046-of-00078.safetensors +3 -0
  30. model-00048-of-00078.safetensors +3 -0
  31. model-00051-of-00078.safetensors +3 -0
  32. model-00052-of-00078.safetensors +3 -0
  33. model-00053-of-00078.safetensors +3 -0
  34. model-00054-of-00078.safetensors +3 -0
  35. model-00055-of-00078.safetensors +3 -0
  36. model-00056-of-00078.safetensors +3 -0
  37. model-00057-of-00078.safetensors +3 -0
  38. model-00058-of-00078.safetensors +3 -0
  39. model-00060-of-00078.safetensors +3 -0
  40. model-00062-of-00078.safetensors +3 -0
  41. model-00067-of-00078.safetensors +3 -0
  42. model-00069-of-00078.safetensors +3 -0
  43. model-00070-of-00078.safetensors +3 -0
  44. model-00072-of-00078.safetensors +3 -0
  45. model-00074-of-00078.safetensors +3 -0
  46. model.safetensors.index.json +0 -0
  47. modeling_deepseek.py +2176 -0
  48. special_tokens_map.json +17 -0
  49. tokenizer.json +0 -0
  50. tokenizer_config.json +0 -0
chat_template.jinja ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
{#- DeepSeek-V3 chat template with tool-calling support. -#}
{#- Pass 1: concatenate all system messages into a single system prompt. -#}
{%- if not add_generation_prompt is defined %}
{%- set add_generation_prompt = false %}
{%- endif %}
{%- set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %}
{%- for message in messages %}
{%- if message['role'] == 'system' %}
{%- if ns.is_first_sp %}
{%- set ns.system_prompt = ns.system_prompt + message['content'] %}
{%- set ns.is_first_sp = false %}
{%- else %}
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}
{%- endif %}
{%- endif %}
{%- endfor %}

{#- Adapted from https://github.com/sgl-project/sglang/blob/main/examples/chat_template/tool_chat_template_deepseekr1.jinja #}
{#- When tools are supplied, append the tool-calling instructions and each tool's JSON schema to the system prompt. -#}
{%- if tools is defined and tools is not none %}
{%- set tool_ns = namespace(text='You are a helpful assistant with tool calling capabilities. ' + 'When a tool call is needed, you MUST use the following format to issue the call:\n' + '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>FUNCTION_NAME\n' + '```json\n{"param1": "value1", "param2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n\n' + 'Make sure the JSON is valid.' + '## Tools\n\n### Function\n\nYou have the following functions available:\n\n') %}
{%- for tool in tools %}
{%- set tool_ns.text = tool_ns.text + '\n```json\n' + (tool | tojson) + '\n```\n' %}
{%- endfor %}
{%- if ns.system_prompt|length != 0 %}
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %}
{%- else %}
{%- set ns.system_prompt = tool_ns.text %}
{%- endif %}
{%- endif %}
{{- bos_token }}
{{- ns.system_prompt }}
{%- set last_index = (messages|length - 1) %}
{#- Pass 2: render the conversation turns. -#}
{%- for message in messages %}
{%- set content = message['content'] %}
{%- if message['role'] == 'user' %}
{%- set ns.is_tool = false -%}
{%- set ns.is_first = false -%}
{%- set ns.is_last_user = true -%}
{#- The final user turn omits the trailing <|Assistant|> marker; it is added by the generation-prompt branch below. -#}
{%- if loop.index0 == last_index %}
{{- '<|User|>' + content }}
{%- else %}
{{- '<|User|>' + content + '<|Assistant|>'}}
{%- endif %}
{%- endif %}
{#- Strip any reasoning prefix from assistant history: keep only text after the last </think>. -#}
{%- if message['role'] == 'assistant' %}
{%- if '</think>' in content %}
{%- set content = (content.split('</think>')|last) %}
{%- endif %}
{%- endif %}
{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}
{%- set ns.is_last_user = false -%}
{%- if ns.is_tool %}
{{- '<|tool▁outputs▁end|>'}}
{%- endif %}
{%- set ns.is_first = false %}
{%- set ns.is_tool = false -%}
{%- set ns.is_output_first = true %}
{%- for tool in message['tool_calls'] %}
{%- set arguments = tool['function']['arguments'] %}
{%- if arguments is not string %}
{%- set arguments = arguments|tojson %}
{%- endif %}
{%- if not ns.is_first %}
{%- if content is none %}
{#- FIX: removed a stray literal '}' line that followed this statement; it leaked a bare '}' into the rendered prompt whenever the assistant turn had no content. -#}
{{- '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + arguments + '\n' + '```' + '<|tool▁call▁end|>'}}
{%- else %}
{{- content + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + arguments + '\n' + '```' + '<|tool▁call▁end|>'}}
{%- endif %}
{%- set ns.is_first = true -%}
{%- else %}
{{- '\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + arguments + '\n' + '```' + '<|tool▁call▁end|>'}}
{%- endif %}
{%- endfor %}
{{- '<|tool▁calls▁end|><|end▁of▁sentence|>'}}
{%- endif %}
{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}
{%- set ns.is_last_user = false -%}
{%- if ns.is_tool %}
{{- '<|tool▁outputs▁end|>' + content + '<|end▁of▁sentence|>'}}
{%- set ns.is_tool = false -%}
{%- else %}
{{- content + '<|end▁of▁sentence|>'}}
{%- endif %}
{%- endif %}
{%- if message['role'] == 'tool' %}
{%- set ns.is_last_user = false -%}
{%- set ns.is_tool = true -%}
{%- if ns.is_output_first %}
{{- '<|tool▁outputs▁begin|><|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}
{%- set ns.is_output_first = false %}
{%- else %}
{{- '\n<|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}
{%- endif %}
{%- endif %}
{%- endfor -%}
{%- if ns.is_tool %}
{{- '<|tool▁outputs▁end|>'}}
{%- endif %}
{#- if add_generation_prompt and not ns.is_last_user and not ns.is_tool #}
{%- if add_generation_prompt and not ns.is_tool %}
{{- '<|Assistant|>'}}
{%- endif %}
config.json ADDED
@@ -0,0 +1,894 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "DeepseekV3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_deepseek.DeepseekV3Config",
9
+ "AutoModel": "modeling_deepseek.DeepseekV3Model",
10
+ "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
11
+ },
12
+ "bos_token_id": 0,
13
+ "eos_token_id": 1,
14
+ "ep_size": 1,
15
+ "first_k_dense_replace": 3,
16
+ "hidden_act": "silu",
17
+ "hidden_size": 7168,
18
+ "initializer_range": 0.02,
19
+ "intermediate_size": 18432,
20
+ "kv_lora_rank": 512,
21
+ "max_position_embeddings": 163840,
22
+ "model_type": "deepseek_v3",
23
+ "moe_intermediate_size": 2048,
24
+ "moe_layer_freq": 1,
25
+ "n_group": 8,
26
+ "n_routed_experts": 256,
27
+ "n_shared_experts": 1,
28
+ "norm_topk_prob": true,
29
+ "num_attention_heads": 128,
30
+ "num_experts_per_tok": 8,
31
+ "num_hidden_layers": 61,
32
+ "num_key_value_heads": 128,
33
+ "num_nextn_predict_layers": 1,
34
+ "pad_token_id": 2,
35
+ "q_lora_rank": 1536,
36
+ "qk_nope_head_dim": 128,
37
+ "qk_rope_head_dim": 64,
38
+ "quantization_config": {
39
+ "algo_config": null,
40
+ "exclude": [
41
+ "model.layers.61.embed_tokens",
42
+ "model.layers.61.eh_proj",
43
+ "model.layers.61.self_attn.q_a_proj",
44
+ "model.layers.61.self_attn.q_b_proj",
45
+ "model.layers.61.self_attn.kv_a_proj_with_mqa",
46
+ "model.layers.61.self_attn.kv_b_proj",
47
+ "model.layers.61.self_attn.o_proj",
48
+ "model.layers.61.mlp.experts.0.gate_proj",
49
+ "model.layers.61.mlp.experts.0.up_proj",
50
+ "model.layers.61.mlp.experts.0.down_proj",
51
+ "model.layers.61.mlp.experts.1.gate_proj",
52
+ "model.layers.61.mlp.experts.1.up_proj",
53
+ "model.layers.61.mlp.experts.1.down_proj",
54
+ "model.layers.61.mlp.experts.2.gate_proj",
55
+ "model.layers.61.mlp.experts.2.up_proj",
56
+ "model.layers.61.mlp.experts.2.down_proj",
57
+ "model.layers.61.mlp.experts.3.gate_proj",
58
+ "model.layers.61.mlp.experts.3.up_proj",
59
+ "model.layers.61.mlp.experts.3.down_proj",
60
+ "model.layers.61.mlp.experts.4.gate_proj",
61
+ "model.layers.61.mlp.experts.4.up_proj",
62
+ "model.layers.61.mlp.experts.4.down_proj",
63
+ "model.layers.61.mlp.experts.5.gate_proj",
64
+ "model.layers.61.mlp.experts.5.up_proj",
65
+ "model.layers.61.mlp.experts.5.down_proj",
66
+ "model.layers.61.mlp.experts.6.gate_proj",
67
+ "model.layers.61.mlp.experts.6.up_proj",
68
+ "model.layers.61.mlp.experts.6.down_proj",
69
+ "model.layers.61.mlp.experts.7.gate_proj",
70
+ "model.layers.61.mlp.experts.7.up_proj",
71
+ "model.layers.61.mlp.experts.7.down_proj",
72
+ "model.layers.61.mlp.experts.8.gate_proj",
73
+ "model.layers.61.mlp.experts.8.up_proj",
74
+ "model.layers.61.mlp.experts.8.down_proj",
75
+ "model.layers.61.mlp.experts.9.gate_proj",
76
+ "model.layers.61.mlp.experts.9.up_proj",
77
+ "model.layers.61.mlp.experts.9.down_proj",
78
+ "model.layers.61.mlp.experts.10.gate_proj",
79
+ "model.layers.61.mlp.experts.10.up_proj",
80
+ "model.layers.61.mlp.experts.10.down_proj",
81
+ "model.layers.61.mlp.experts.11.gate_proj",
82
+ "model.layers.61.mlp.experts.11.up_proj",
83
+ "model.layers.61.mlp.experts.11.down_proj",
84
+ "model.layers.61.mlp.experts.12.gate_proj",
85
+ "model.layers.61.mlp.experts.12.up_proj",
86
+ "model.layers.61.mlp.experts.12.down_proj",
87
+ "model.layers.61.mlp.experts.13.gate_proj",
88
+ "model.layers.61.mlp.experts.13.up_proj",
89
+ "model.layers.61.mlp.experts.13.down_proj",
90
+ "model.layers.61.mlp.experts.14.gate_proj",
91
+ "model.layers.61.mlp.experts.14.up_proj",
92
+ "model.layers.61.mlp.experts.14.down_proj",
93
+ "model.layers.61.mlp.experts.15.gate_proj",
94
+ "model.layers.61.mlp.experts.15.up_proj",
95
+ "model.layers.61.mlp.experts.15.down_proj",
96
+ "model.layers.61.mlp.experts.16.gate_proj",
97
+ "model.layers.61.mlp.experts.16.up_proj",
98
+ "model.layers.61.mlp.experts.16.down_proj",
99
+ "model.layers.61.mlp.experts.17.gate_proj",
100
+ "model.layers.61.mlp.experts.17.up_proj",
101
+ "model.layers.61.mlp.experts.17.down_proj",
102
+ "model.layers.61.mlp.experts.18.gate_proj",
103
+ "model.layers.61.mlp.experts.18.up_proj",
104
+ "model.layers.61.mlp.experts.18.down_proj",
105
+ "model.layers.61.mlp.experts.19.gate_proj",
106
+ "model.layers.61.mlp.experts.19.up_proj",
107
+ "model.layers.61.mlp.experts.19.down_proj",
108
+ "model.layers.61.mlp.experts.20.gate_proj",
109
+ "model.layers.61.mlp.experts.20.up_proj",
110
+ "model.layers.61.mlp.experts.20.down_proj",
111
+ "model.layers.61.mlp.experts.21.gate_proj",
112
+ "model.layers.61.mlp.experts.21.up_proj",
113
+ "model.layers.61.mlp.experts.21.down_proj",
114
+ "model.layers.61.mlp.experts.22.gate_proj",
115
+ "model.layers.61.mlp.experts.22.up_proj",
116
+ "model.layers.61.mlp.experts.22.down_proj",
117
+ "model.layers.61.mlp.experts.23.gate_proj",
118
+ "model.layers.61.mlp.experts.23.up_proj",
119
+ "model.layers.61.mlp.experts.23.down_proj",
120
+ "model.layers.61.mlp.experts.24.gate_proj",
121
+ "model.layers.61.mlp.experts.24.up_proj",
122
+ "model.layers.61.mlp.experts.24.down_proj",
123
+ "model.layers.61.mlp.experts.25.gate_proj",
124
+ "model.layers.61.mlp.experts.25.up_proj",
125
+ "model.layers.61.mlp.experts.25.down_proj",
126
+ "model.layers.61.mlp.experts.26.gate_proj",
127
+ "model.layers.61.mlp.experts.26.up_proj",
128
+ "model.layers.61.mlp.experts.26.down_proj",
129
+ "model.layers.61.mlp.experts.27.gate_proj",
130
+ "model.layers.61.mlp.experts.27.up_proj",
131
+ "model.layers.61.mlp.experts.27.down_proj",
132
+ "model.layers.61.mlp.experts.28.gate_proj",
133
+ "model.layers.61.mlp.experts.28.up_proj",
134
+ "model.layers.61.mlp.experts.28.down_proj",
135
+ "model.layers.61.mlp.experts.29.gate_proj",
136
+ "model.layers.61.mlp.experts.29.up_proj",
137
+ "model.layers.61.mlp.experts.29.down_proj",
138
+ "model.layers.61.mlp.experts.30.gate_proj",
139
+ "model.layers.61.mlp.experts.30.up_proj",
140
+ "model.layers.61.mlp.experts.30.down_proj",
141
+ "model.layers.61.mlp.experts.31.gate_proj",
142
+ "model.layers.61.mlp.experts.31.up_proj",
143
+ "model.layers.61.mlp.experts.31.down_proj",
144
+ "model.layers.61.mlp.experts.32.gate_proj",
145
+ "model.layers.61.mlp.experts.32.up_proj",
146
+ "model.layers.61.mlp.experts.32.down_proj",
147
+ "model.layers.61.mlp.experts.33.gate_proj",
148
+ "model.layers.61.mlp.experts.33.up_proj",
149
+ "model.layers.61.mlp.experts.33.down_proj",
150
+ "model.layers.61.mlp.experts.34.gate_proj",
151
+ "model.layers.61.mlp.experts.34.up_proj",
152
+ "model.layers.61.mlp.experts.34.down_proj",
153
+ "model.layers.61.mlp.experts.35.gate_proj",
154
+ "model.layers.61.mlp.experts.35.up_proj",
155
+ "model.layers.61.mlp.experts.35.down_proj",
156
+ "model.layers.61.mlp.experts.36.gate_proj",
157
+ "model.layers.61.mlp.experts.36.up_proj",
158
+ "model.layers.61.mlp.experts.36.down_proj",
159
+ "model.layers.61.mlp.experts.37.gate_proj",
160
+ "model.layers.61.mlp.experts.37.up_proj",
161
+ "model.layers.61.mlp.experts.37.down_proj",
162
+ "model.layers.61.mlp.experts.38.gate_proj",
163
+ "model.layers.61.mlp.experts.38.up_proj",
164
+ "model.layers.61.mlp.experts.38.down_proj",
165
+ "model.layers.61.mlp.experts.39.gate_proj",
166
+ "model.layers.61.mlp.experts.39.up_proj",
167
+ "model.layers.61.mlp.experts.39.down_proj",
168
+ "model.layers.61.mlp.experts.40.gate_proj",
169
+ "model.layers.61.mlp.experts.40.up_proj",
170
+ "model.layers.61.mlp.experts.40.down_proj",
171
+ "model.layers.61.mlp.experts.41.gate_proj",
172
+ "model.layers.61.mlp.experts.41.up_proj",
173
+ "model.layers.61.mlp.experts.41.down_proj",
174
+ "model.layers.61.mlp.experts.42.gate_proj",
175
+ "model.layers.61.mlp.experts.42.up_proj",
176
+ "model.layers.61.mlp.experts.42.down_proj",
177
+ "model.layers.61.mlp.experts.43.gate_proj",
178
+ "model.layers.61.mlp.experts.43.up_proj",
179
+ "model.layers.61.mlp.experts.43.down_proj",
180
+ "model.layers.61.mlp.experts.44.gate_proj",
181
+ "model.layers.61.mlp.experts.44.up_proj",
182
+ "model.layers.61.mlp.experts.44.down_proj",
183
+ "model.layers.61.mlp.experts.45.gate_proj",
184
+ "model.layers.61.mlp.experts.45.up_proj",
185
+ "model.layers.61.mlp.experts.45.down_proj",
186
+ "model.layers.61.mlp.experts.46.gate_proj",
187
+ "model.layers.61.mlp.experts.46.up_proj",
188
+ "model.layers.61.mlp.experts.46.down_proj",
189
+ "model.layers.61.mlp.experts.47.gate_proj",
190
+ "model.layers.61.mlp.experts.47.up_proj",
191
+ "model.layers.61.mlp.experts.47.down_proj",
192
+ "model.layers.61.mlp.experts.48.gate_proj",
193
+ "model.layers.61.mlp.experts.48.up_proj",
194
+ "model.layers.61.mlp.experts.48.down_proj",
195
+ "model.layers.61.mlp.experts.49.gate_proj",
196
+ "model.layers.61.mlp.experts.49.up_proj",
197
+ "model.layers.61.mlp.experts.49.down_proj",
198
+ "model.layers.61.mlp.experts.50.gate_proj",
199
+ "model.layers.61.mlp.experts.50.up_proj",
200
+ "model.layers.61.mlp.experts.50.down_proj",
201
+ "model.layers.61.mlp.experts.51.gate_proj",
202
+ "model.layers.61.mlp.experts.51.up_proj",
203
+ "model.layers.61.mlp.experts.51.down_proj",
204
+ "model.layers.61.mlp.experts.52.gate_proj",
205
+ "model.layers.61.mlp.experts.52.up_proj",
206
+ "model.layers.61.mlp.experts.52.down_proj",
207
+ "model.layers.61.mlp.experts.53.gate_proj",
208
+ "model.layers.61.mlp.experts.53.up_proj",
209
+ "model.layers.61.mlp.experts.53.down_proj",
210
+ "model.layers.61.mlp.experts.54.gate_proj",
211
+ "model.layers.61.mlp.experts.54.up_proj",
212
+ "model.layers.61.mlp.experts.54.down_proj",
213
+ "model.layers.61.mlp.experts.55.gate_proj",
214
+ "model.layers.61.mlp.experts.55.up_proj",
215
+ "model.layers.61.mlp.experts.55.down_proj",
216
+ "model.layers.61.mlp.experts.56.gate_proj",
217
+ "model.layers.61.mlp.experts.56.up_proj",
218
+ "model.layers.61.mlp.experts.56.down_proj",
219
+ "model.layers.61.mlp.experts.57.gate_proj",
220
+ "model.layers.61.mlp.experts.57.up_proj",
221
+ "model.layers.61.mlp.experts.57.down_proj",
222
+ "model.layers.61.mlp.experts.58.gate_proj",
223
+ "model.layers.61.mlp.experts.58.up_proj",
224
+ "model.layers.61.mlp.experts.58.down_proj",
225
+ "model.layers.61.mlp.experts.59.gate_proj",
226
+ "model.layers.61.mlp.experts.59.up_proj",
227
+ "model.layers.61.mlp.experts.59.down_proj",
228
+ "model.layers.61.mlp.experts.60.gate_proj",
229
+ "model.layers.61.mlp.experts.60.up_proj",
230
+ "model.layers.61.mlp.experts.60.down_proj",
231
+ "model.layers.61.mlp.experts.61.gate_proj",
232
+ "model.layers.61.mlp.experts.61.up_proj",
233
+ "model.layers.61.mlp.experts.61.down_proj",
234
+ "model.layers.61.mlp.experts.62.gate_proj",
235
+ "model.layers.61.mlp.experts.62.up_proj",
236
+ "model.layers.61.mlp.experts.62.down_proj",
237
+ "model.layers.61.mlp.experts.63.gate_proj",
238
+ "model.layers.61.mlp.experts.63.up_proj",
239
+ "model.layers.61.mlp.experts.63.down_proj",
240
+ "model.layers.61.mlp.experts.64.gate_proj",
241
+ "model.layers.61.mlp.experts.64.up_proj",
242
+ "model.layers.61.mlp.experts.64.down_proj",
243
+ "model.layers.61.mlp.experts.65.gate_proj",
244
+ "model.layers.61.mlp.experts.65.up_proj",
245
+ "model.layers.61.mlp.experts.65.down_proj",
246
+ "model.layers.61.mlp.experts.66.gate_proj",
247
+ "model.layers.61.mlp.experts.66.up_proj",
248
+ "model.layers.61.mlp.experts.66.down_proj",
249
+ "model.layers.61.mlp.experts.67.gate_proj",
250
+ "model.layers.61.mlp.experts.67.up_proj",
251
+ "model.layers.61.mlp.experts.67.down_proj",
252
+ "model.layers.61.mlp.experts.68.gate_proj",
253
+ "model.layers.61.mlp.experts.68.up_proj",
254
+ "model.layers.61.mlp.experts.68.down_proj",
255
+ "model.layers.61.mlp.experts.69.gate_proj",
256
+ "model.layers.61.mlp.experts.69.up_proj",
257
+ "model.layers.61.mlp.experts.69.down_proj",
258
+ "model.layers.61.mlp.experts.70.gate_proj",
259
+ "model.layers.61.mlp.experts.70.up_proj",
260
+ "model.layers.61.mlp.experts.70.down_proj",
261
+ "model.layers.61.mlp.experts.71.gate_proj",
262
+ "model.layers.61.mlp.experts.71.up_proj",
263
+ "model.layers.61.mlp.experts.71.down_proj",
264
+ "model.layers.61.mlp.experts.72.gate_proj",
265
+ "model.layers.61.mlp.experts.72.up_proj",
266
+ "model.layers.61.mlp.experts.72.down_proj",
267
+ "model.layers.61.mlp.experts.73.gate_proj",
268
+ "model.layers.61.mlp.experts.73.up_proj",
269
+ "model.layers.61.mlp.experts.73.down_proj",
270
+ "model.layers.61.mlp.experts.74.gate_proj",
271
+ "model.layers.61.mlp.experts.74.up_proj",
272
+ "model.layers.61.mlp.experts.74.down_proj",
273
+ "model.layers.61.mlp.experts.75.gate_proj",
274
+ "model.layers.61.mlp.experts.75.up_proj",
275
+ "model.layers.61.mlp.experts.75.down_proj",
276
+ "model.layers.61.mlp.experts.76.gate_proj",
277
+ "model.layers.61.mlp.experts.76.up_proj",
278
+ "model.layers.61.mlp.experts.76.down_proj",
279
+ "model.layers.61.mlp.experts.77.gate_proj",
280
+ "model.layers.61.mlp.experts.77.up_proj",
281
+ "model.layers.61.mlp.experts.77.down_proj",
282
+ "model.layers.61.mlp.experts.78.gate_proj",
283
+ "model.layers.61.mlp.experts.78.up_proj",
284
+ "model.layers.61.mlp.experts.78.down_proj",
285
+ "model.layers.61.mlp.experts.79.gate_proj",
286
+ "model.layers.61.mlp.experts.79.up_proj",
287
+ "model.layers.61.mlp.experts.79.down_proj",
288
+ "model.layers.61.mlp.experts.80.gate_proj",
289
+ "model.layers.61.mlp.experts.80.up_proj",
290
+ "model.layers.61.mlp.experts.80.down_proj",
291
+ "model.layers.61.mlp.experts.81.gate_proj",
292
+ "model.layers.61.mlp.experts.81.up_proj",
293
+ "model.layers.61.mlp.experts.81.down_proj",
294
+ "model.layers.61.mlp.experts.82.gate_proj",
295
+ "model.layers.61.mlp.experts.82.up_proj",
296
+ "model.layers.61.mlp.experts.82.down_proj",
297
+ "model.layers.61.mlp.experts.83.gate_proj",
298
+ "model.layers.61.mlp.experts.83.up_proj",
299
+ "model.layers.61.mlp.experts.83.down_proj",
300
+ "model.layers.61.mlp.experts.84.gate_proj",
301
+ "model.layers.61.mlp.experts.84.up_proj",
302
+ "model.layers.61.mlp.experts.84.down_proj",
303
+ "model.layers.61.mlp.experts.85.gate_proj",
304
+ "model.layers.61.mlp.experts.85.up_proj",
305
+ "model.layers.61.mlp.experts.85.down_proj",
306
+ "model.layers.61.mlp.experts.86.gate_proj",
307
+ "model.layers.61.mlp.experts.86.up_proj",
308
+ "model.layers.61.mlp.experts.86.down_proj",
309
+ "model.layers.61.mlp.experts.87.gate_proj",
310
+ "model.layers.61.mlp.experts.87.up_proj",
311
+ "model.layers.61.mlp.experts.87.down_proj",
312
+ "model.layers.61.mlp.experts.88.gate_proj",
313
+ "model.layers.61.mlp.experts.88.up_proj",
314
+ "model.layers.61.mlp.experts.88.down_proj",
315
+ "model.layers.61.mlp.experts.89.gate_proj",
316
+ "model.layers.61.mlp.experts.89.up_proj",
317
+ "model.layers.61.mlp.experts.89.down_proj",
318
+ "model.layers.61.mlp.experts.90.gate_proj",
319
+ "model.layers.61.mlp.experts.90.up_proj",
320
+ "model.layers.61.mlp.experts.90.down_proj",
321
+ "model.layers.61.mlp.experts.91.gate_proj",
322
+ "model.layers.61.mlp.experts.91.up_proj",
323
+ "model.layers.61.mlp.experts.91.down_proj",
324
+ "model.layers.61.mlp.experts.92.gate_proj",
325
+ "model.layers.61.mlp.experts.92.up_proj",
326
+ "model.layers.61.mlp.experts.92.down_proj",
327
+ "model.layers.61.mlp.experts.93.gate_proj",
328
+ "model.layers.61.mlp.experts.93.up_proj",
329
+ "model.layers.61.mlp.experts.93.down_proj",
330
+ "model.layers.61.mlp.experts.94.gate_proj",
331
+ "model.layers.61.mlp.experts.94.up_proj",
332
+ "model.layers.61.mlp.experts.94.down_proj",
333
+ "model.layers.61.mlp.experts.95.gate_proj",
334
+ "model.layers.61.mlp.experts.95.up_proj",
335
+ "model.layers.61.mlp.experts.95.down_proj",
336
+ "model.layers.61.mlp.experts.96.gate_proj",
337
+ "model.layers.61.mlp.experts.96.up_proj",
338
+ "model.layers.61.mlp.experts.96.down_proj",
339
+ "model.layers.61.mlp.experts.97.gate_proj",
340
+ "model.layers.61.mlp.experts.97.up_proj",
341
+ "model.layers.61.mlp.experts.97.down_proj",
342
+ "model.layers.61.mlp.experts.98.gate_proj",
343
+ "model.layers.61.mlp.experts.98.up_proj",
344
+ "model.layers.61.mlp.experts.98.down_proj",
345
+ "model.layers.61.mlp.experts.99.gate_proj",
346
+ "model.layers.61.mlp.experts.99.up_proj",
347
+ "model.layers.61.mlp.experts.99.down_proj",
348
+ "model.layers.61.mlp.experts.100.gate_proj",
349
+ "model.layers.61.mlp.experts.100.up_proj",
350
+ "model.layers.61.mlp.experts.100.down_proj",
351
+ "model.layers.61.mlp.experts.101.gate_proj",
352
+ "model.layers.61.mlp.experts.101.up_proj",
353
+ "model.layers.61.mlp.experts.101.down_proj",
354
+ "model.layers.61.mlp.experts.102.gate_proj",
355
+ "model.layers.61.mlp.experts.102.up_proj",
356
+ "model.layers.61.mlp.experts.102.down_proj",
357
+ "model.layers.61.mlp.experts.103.gate_proj",
358
+ "model.layers.61.mlp.experts.103.up_proj",
359
+ "model.layers.61.mlp.experts.103.down_proj",
360
+ "model.layers.61.mlp.experts.104.gate_proj",
361
+ "model.layers.61.mlp.experts.104.up_proj",
362
+ "model.layers.61.mlp.experts.104.down_proj",
363
+ "model.layers.61.mlp.experts.105.gate_proj",
364
+ "model.layers.61.mlp.experts.105.up_proj",
365
+ "model.layers.61.mlp.experts.105.down_proj",
366
+ "model.layers.61.mlp.experts.106.gate_proj",
367
+ "model.layers.61.mlp.experts.106.up_proj",
368
+ "model.layers.61.mlp.experts.106.down_proj",
369
+ "model.layers.61.mlp.experts.107.gate_proj",
370
+ "model.layers.61.mlp.experts.107.up_proj",
371
+ "model.layers.61.mlp.experts.107.down_proj",
372
+ "model.layers.61.mlp.experts.108.gate_proj",
373
+ "model.layers.61.mlp.experts.108.up_proj",
374
+ "model.layers.61.mlp.experts.108.down_proj",
375
+ "model.layers.61.mlp.experts.109.gate_proj",
376
+ "model.layers.61.mlp.experts.109.up_proj",
377
+ "model.layers.61.mlp.experts.109.down_proj",
378
+ "model.layers.61.mlp.experts.110.gate_proj",
379
+ "model.layers.61.mlp.experts.110.up_proj",
380
+ "model.layers.61.mlp.experts.110.down_proj",
381
+ "model.layers.61.mlp.experts.111.gate_proj",
382
+ "model.layers.61.mlp.experts.111.up_proj",
383
+ "model.layers.61.mlp.experts.111.down_proj",
384
+ "model.layers.61.mlp.experts.112.gate_proj",
385
+ "model.layers.61.mlp.experts.112.up_proj",
386
+ "model.layers.61.mlp.experts.112.down_proj",
387
+ "model.layers.61.mlp.experts.113.gate_proj",
388
+ "model.layers.61.mlp.experts.113.up_proj",
389
+ "model.layers.61.mlp.experts.113.down_proj",
390
+ "model.layers.61.mlp.experts.114.gate_proj",
391
+ "model.layers.61.mlp.experts.114.up_proj",
392
+ "model.layers.61.mlp.experts.114.down_proj",
393
+ "model.layers.61.mlp.experts.115.gate_proj",
394
+ "model.layers.61.mlp.experts.115.up_proj",
395
+ "model.layers.61.mlp.experts.115.down_proj",
396
+ "model.layers.61.mlp.experts.116.gate_proj",
397
+ "model.layers.61.mlp.experts.116.up_proj",
398
+ "model.layers.61.mlp.experts.116.down_proj",
399
+ "model.layers.61.mlp.experts.117.gate_proj",
400
+ "model.layers.61.mlp.experts.117.up_proj",
401
+ "model.layers.61.mlp.experts.117.down_proj",
402
+ "model.layers.61.mlp.experts.118.gate_proj",
403
+ "model.layers.61.mlp.experts.118.up_proj",
404
+ "model.layers.61.mlp.experts.118.down_proj",
405
+ "model.layers.61.mlp.experts.119.gate_proj",
406
+ "model.layers.61.mlp.experts.119.up_proj",
407
+ "model.layers.61.mlp.experts.119.down_proj",
408
+ "model.layers.61.mlp.experts.120.gate_proj",
409
+ "model.layers.61.mlp.experts.120.up_proj",
410
+ "model.layers.61.mlp.experts.120.down_proj",
411
+ "model.layers.61.mlp.experts.121.gate_proj",
412
+ "model.layers.61.mlp.experts.121.up_proj",
413
+ "model.layers.61.mlp.experts.121.down_proj",
414
+ "model.layers.61.mlp.experts.122.gate_proj",
415
+ "model.layers.61.mlp.experts.122.up_proj",
416
+ "model.layers.61.mlp.experts.122.down_proj",
417
+ "model.layers.61.mlp.experts.123.gate_proj",
418
+ "model.layers.61.mlp.experts.123.up_proj",
419
+ "model.layers.61.mlp.experts.123.down_proj",
420
+ "model.layers.61.mlp.experts.124.gate_proj",
421
+ "model.layers.61.mlp.experts.124.up_proj",
422
+ "model.layers.61.mlp.experts.124.down_proj",
423
+ "model.layers.61.mlp.experts.125.gate_proj",
424
+ "model.layers.61.mlp.experts.125.up_proj",
425
+ "model.layers.61.mlp.experts.125.down_proj",
426
+ "model.layers.61.mlp.experts.126.gate_proj",
427
+ "model.layers.61.mlp.experts.126.up_proj",
428
+ "model.layers.61.mlp.experts.126.down_proj",
429
+ "model.layers.61.mlp.experts.127.gate_proj",
430
+ "model.layers.61.mlp.experts.127.up_proj",
431
+ "model.layers.61.mlp.experts.127.down_proj",
432
+ "model.layers.61.mlp.experts.128.gate_proj",
433
+ "model.layers.61.mlp.experts.128.up_proj",
434
+ "model.layers.61.mlp.experts.128.down_proj",
435
+ "model.layers.61.mlp.experts.129.gate_proj",
436
+ "model.layers.61.mlp.experts.129.up_proj",
437
+ "model.layers.61.mlp.experts.129.down_proj",
438
+ "model.layers.61.mlp.experts.130.gate_proj",
439
+ "model.layers.61.mlp.experts.130.up_proj",
440
+ "model.layers.61.mlp.experts.130.down_proj",
441
+ "model.layers.61.mlp.experts.131.gate_proj",
442
+ "model.layers.61.mlp.experts.131.up_proj",
443
+ "model.layers.61.mlp.experts.131.down_proj",
444
+ "model.layers.61.mlp.experts.132.gate_proj",
445
+ "model.layers.61.mlp.experts.132.up_proj",
446
+ "model.layers.61.mlp.experts.132.down_proj",
447
+ "model.layers.61.mlp.experts.133.gate_proj",
448
+ "model.layers.61.mlp.experts.133.up_proj",
449
+ "model.layers.61.mlp.experts.133.down_proj",
450
+ "model.layers.61.mlp.experts.134.gate_proj",
451
+ "model.layers.61.mlp.experts.134.up_proj",
452
+ "model.layers.61.mlp.experts.134.down_proj",
453
+ "model.layers.61.mlp.experts.135.gate_proj",
454
+ "model.layers.61.mlp.experts.135.up_proj",
455
+ "model.layers.61.mlp.experts.135.down_proj",
456
+ "model.layers.61.mlp.experts.136.gate_proj",
457
+ "model.layers.61.mlp.experts.136.up_proj",
458
+ "model.layers.61.mlp.experts.136.down_proj",
459
+ "model.layers.61.mlp.experts.137.gate_proj",
460
+ "model.layers.61.mlp.experts.137.up_proj",
461
+ "model.layers.61.mlp.experts.137.down_proj",
462
+ "model.layers.61.mlp.experts.138.gate_proj",
463
+ "model.layers.61.mlp.experts.138.up_proj",
464
+ "model.layers.61.mlp.experts.138.down_proj",
465
+ "model.layers.61.mlp.experts.139.gate_proj",
466
+ "model.layers.61.mlp.experts.139.up_proj",
467
+ "model.layers.61.mlp.experts.139.down_proj",
468
+ "model.layers.61.mlp.experts.140.gate_proj",
469
+ "model.layers.61.mlp.experts.140.up_proj",
470
+ "model.layers.61.mlp.experts.140.down_proj",
471
+ "model.layers.61.mlp.experts.141.gate_proj",
472
+ "model.layers.61.mlp.experts.141.up_proj",
473
+ "model.layers.61.mlp.experts.141.down_proj",
474
+ "model.layers.61.mlp.experts.142.gate_proj",
475
+ "model.layers.61.mlp.experts.142.up_proj",
476
+ "model.layers.61.mlp.experts.142.down_proj",
477
+ "model.layers.61.mlp.experts.143.gate_proj",
478
+ "model.layers.61.mlp.experts.143.up_proj",
479
+ "model.layers.61.mlp.experts.143.down_proj",
480
+ "model.layers.61.mlp.experts.144.gate_proj",
481
+ "model.layers.61.mlp.experts.144.up_proj",
482
+ "model.layers.61.mlp.experts.144.down_proj",
483
+ "model.layers.61.mlp.experts.145.gate_proj",
484
+ "model.layers.61.mlp.experts.145.up_proj",
485
+ "model.layers.61.mlp.experts.145.down_proj",
486
+ "model.layers.61.mlp.experts.146.gate_proj",
487
+ "model.layers.61.mlp.experts.146.up_proj",
488
+ "model.layers.61.mlp.experts.146.down_proj",
489
+ "model.layers.61.mlp.experts.147.gate_proj",
490
+ "model.layers.61.mlp.experts.147.up_proj",
491
+ "model.layers.61.mlp.experts.147.down_proj",
492
+ "model.layers.61.mlp.experts.148.gate_proj",
493
+ "model.layers.61.mlp.experts.148.up_proj",
494
+ "model.layers.61.mlp.experts.148.down_proj",
495
+ "model.layers.61.mlp.experts.149.gate_proj",
496
+ "model.layers.61.mlp.experts.149.up_proj",
497
+ "model.layers.61.mlp.experts.149.down_proj",
498
+ "model.layers.61.mlp.experts.150.gate_proj",
499
+ "model.layers.61.mlp.experts.150.up_proj",
500
+ "model.layers.61.mlp.experts.150.down_proj",
501
+ "model.layers.61.mlp.experts.151.gate_proj",
502
+ "model.layers.61.mlp.experts.151.up_proj",
503
+ "model.layers.61.mlp.experts.151.down_proj",
504
+ "model.layers.61.mlp.experts.152.gate_proj",
505
+ "model.layers.61.mlp.experts.152.up_proj",
506
+ "model.layers.61.mlp.experts.152.down_proj",
507
+ "model.layers.61.mlp.experts.153.gate_proj",
508
+ "model.layers.61.mlp.experts.153.up_proj",
509
+ "model.layers.61.mlp.experts.153.down_proj",
510
+ "model.layers.61.mlp.experts.154.gate_proj",
511
+ "model.layers.61.mlp.experts.154.up_proj",
512
+ "model.layers.61.mlp.experts.154.down_proj",
513
+ "model.layers.61.mlp.experts.155.gate_proj",
514
+ "model.layers.61.mlp.experts.155.up_proj",
515
+ "model.layers.61.mlp.experts.155.down_proj",
516
+ "model.layers.61.mlp.experts.156.gate_proj",
517
+ "model.layers.61.mlp.experts.156.up_proj",
518
+ "model.layers.61.mlp.experts.156.down_proj",
519
+ "model.layers.61.mlp.experts.157.gate_proj",
520
+ "model.layers.61.mlp.experts.157.up_proj",
521
+ "model.layers.61.mlp.experts.157.down_proj",
522
+ "model.layers.61.mlp.experts.158.gate_proj",
523
+ "model.layers.61.mlp.experts.158.up_proj",
524
+ "model.layers.61.mlp.experts.158.down_proj",
525
+ "model.layers.61.mlp.experts.159.gate_proj",
526
+ "model.layers.61.mlp.experts.159.up_proj",
527
+ "model.layers.61.mlp.experts.159.down_proj",
528
+ "model.layers.61.mlp.experts.160.gate_proj",
529
+ "model.layers.61.mlp.experts.160.up_proj",
530
+ "model.layers.61.mlp.experts.160.down_proj",
531
+ "model.layers.61.mlp.experts.161.gate_proj",
532
+ "model.layers.61.mlp.experts.161.up_proj",
533
+ "model.layers.61.mlp.experts.161.down_proj",
534
+ "model.layers.61.mlp.experts.162.gate_proj",
535
+ "model.layers.61.mlp.experts.162.up_proj",
536
+ "model.layers.61.mlp.experts.162.down_proj",
537
+ "model.layers.61.mlp.experts.163.gate_proj",
538
+ "model.layers.61.mlp.experts.163.up_proj",
539
+ "model.layers.61.mlp.experts.163.down_proj",
540
+ "model.layers.61.mlp.experts.164.gate_proj",
541
+ "model.layers.61.mlp.experts.164.up_proj",
542
+ "model.layers.61.mlp.experts.164.down_proj",
543
+ "model.layers.61.mlp.experts.165.gate_proj",
544
+ "model.layers.61.mlp.experts.165.up_proj",
545
+ "model.layers.61.mlp.experts.165.down_proj",
546
+ "model.layers.61.mlp.experts.166.gate_proj",
547
+ "model.layers.61.mlp.experts.166.up_proj",
548
+ "model.layers.61.mlp.experts.166.down_proj",
549
+ "model.layers.61.mlp.experts.167.gate_proj",
550
+ "model.layers.61.mlp.experts.167.up_proj",
551
+ "model.layers.61.mlp.experts.167.down_proj",
552
+ "model.layers.61.mlp.experts.168.gate_proj",
553
+ "model.layers.61.mlp.experts.168.up_proj",
554
+ "model.layers.61.mlp.experts.168.down_proj",
555
+ "model.layers.61.mlp.experts.169.gate_proj",
556
+ "model.layers.61.mlp.experts.169.up_proj",
557
+ "model.layers.61.mlp.experts.169.down_proj",
558
+ "model.layers.61.mlp.experts.170.gate_proj",
559
+ "model.layers.61.mlp.experts.170.up_proj",
560
+ "model.layers.61.mlp.experts.170.down_proj",
561
+ "model.layers.61.mlp.experts.171.gate_proj",
562
+ "model.layers.61.mlp.experts.171.up_proj",
563
+ "model.layers.61.mlp.experts.171.down_proj",
564
+ "model.layers.61.mlp.experts.172.gate_proj",
565
+ "model.layers.61.mlp.experts.172.up_proj",
566
+ "model.layers.61.mlp.experts.172.down_proj",
567
+ "model.layers.61.mlp.experts.173.gate_proj",
568
+ "model.layers.61.mlp.experts.173.up_proj",
569
+ "model.layers.61.mlp.experts.173.down_proj",
570
+ "model.layers.61.mlp.experts.174.gate_proj",
571
+ "model.layers.61.mlp.experts.174.up_proj",
572
+ "model.layers.61.mlp.experts.174.down_proj",
573
+ "model.layers.61.mlp.experts.175.gate_proj",
574
+ "model.layers.61.mlp.experts.175.up_proj",
575
+ "model.layers.61.mlp.experts.175.down_proj",
576
+ "model.layers.61.mlp.experts.176.gate_proj",
577
+ "model.layers.61.mlp.experts.176.up_proj",
578
+ "model.layers.61.mlp.experts.176.down_proj",
579
+ "model.layers.61.mlp.experts.177.gate_proj",
580
+ "model.layers.61.mlp.experts.177.up_proj",
581
+ "model.layers.61.mlp.experts.177.down_proj",
582
+ "model.layers.61.mlp.experts.178.gate_proj",
583
+ "model.layers.61.mlp.experts.178.up_proj",
584
+ "model.layers.61.mlp.experts.178.down_proj",
585
+ "model.layers.61.mlp.experts.179.gate_proj",
586
+ "model.layers.61.mlp.experts.179.up_proj",
587
+ "model.layers.61.mlp.experts.179.down_proj",
588
+ "model.layers.61.mlp.experts.180.gate_proj",
589
+ "model.layers.61.mlp.experts.180.up_proj",
590
+ "model.layers.61.mlp.experts.180.down_proj",
591
+ "model.layers.61.mlp.experts.181.gate_proj",
592
+ "model.layers.61.mlp.experts.181.up_proj",
593
+ "model.layers.61.mlp.experts.181.down_proj",
594
+ "model.layers.61.mlp.experts.182.gate_proj",
595
+ "model.layers.61.mlp.experts.182.up_proj",
596
+ "model.layers.61.mlp.experts.182.down_proj",
597
+ "model.layers.61.mlp.experts.183.gate_proj",
598
+ "model.layers.61.mlp.experts.183.up_proj",
599
+ "model.layers.61.mlp.experts.183.down_proj",
600
+ "model.layers.61.mlp.experts.184.gate_proj",
601
+ "model.layers.61.mlp.experts.184.up_proj",
602
+ "model.layers.61.mlp.experts.184.down_proj",
603
+ "model.layers.61.mlp.experts.185.gate_proj",
604
+ "model.layers.61.mlp.experts.185.up_proj",
605
+ "model.layers.61.mlp.experts.185.down_proj",
606
+ "model.layers.61.mlp.experts.186.gate_proj",
607
+ "model.layers.61.mlp.experts.186.up_proj",
608
+ "model.layers.61.mlp.experts.186.down_proj",
609
+ "model.layers.61.mlp.experts.187.gate_proj",
610
+ "model.layers.61.mlp.experts.187.up_proj",
611
+ "model.layers.61.mlp.experts.187.down_proj",
612
+ "model.layers.61.mlp.experts.188.gate_proj",
613
+ "model.layers.61.mlp.experts.188.up_proj",
614
+ "model.layers.61.mlp.experts.188.down_proj",
615
+ "model.layers.61.mlp.experts.189.gate_proj",
616
+ "model.layers.61.mlp.experts.189.up_proj",
617
+ "model.layers.61.mlp.experts.189.down_proj",
618
+ "model.layers.61.mlp.experts.190.gate_proj",
619
+ "model.layers.61.mlp.experts.190.up_proj",
620
+ "model.layers.61.mlp.experts.190.down_proj",
621
+ "model.layers.61.mlp.experts.191.gate_proj",
622
+ "model.layers.61.mlp.experts.191.up_proj",
623
+ "model.layers.61.mlp.experts.191.down_proj",
624
+ "model.layers.61.mlp.experts.192.gate_proj",
625
+ "model.layers.61.mlp.experts.192.up_proj",
626
+ "model.layers.61.mlp.experts.192.down_proj",
627
+ "model.layers.61.mlp.experts.193.gate_proj",
628
+ "model.layers.61.mlp.experts.193.up_proj",
629
+ "model.layers.61.mlp.experts.193.down_proj",
630
+ "model.layers.61.mlp.experts.194.gate_proj",
631
+ "model.layers.61.mlp.experts.194.up_proj",
632
+ "model.layers.61.mlp.experts.194.down_proj",
633
+ "model.layers.61.mlp.experts.195.gate_proj",
634
+ "model.layers.61.mlp.experts.195.up_proj",
635
+ "model.layers.61.mlp.experts.195.down_proj",
636
+ "model.layers.61.mlp.experts.196.gate_proj",
637
+ "model.layers.61.mlp.experts.196.up_proj",
638
+ "model.layers.61.mlp.experts.196.down_proj",
639
+ "model.layers.61.mlp.experts.197.gate_proj",
640
+ "model.layers.61.mlp.experts.197.up_proj",
641
+ "model.layers.61.mlp.experts.197.down_proj",
642
+ "model.layers.61.mlp.experts.198.gate_proj",
643
+ "model.layers.61.mlp.experts.198.up_proj",
644
+ "model.layers.61.mlp.experts.198.down_proj",
645
+ "model.layers.61.mlp.experts.199.gate_proj",
646
+ "model.layers.61.mlp.experts.199.up_proj",
647
+ "model.layers.61.mlp.experts.199.down_proj",
648
+ "model.layers.61.mlp.experts.200.gate_proj",
649
+ "model.layers.61.mlp.experts.200.up_proj",
650
+ "model.layers.61.mlp.experts.200.down_proj",
651
+ "model.layers.61.mlp.experts.201.gate_proj",
652
+ "model.layers.61.mlp.experts.201.up_proj",
653
+ "model.layers.61.mlp.experts.201.down_proj",
654
+ "model.layers.61.mlp.experts.202.gate_proj",
655
+ "model.layers.61.mlp.experts.202.up_proj",
656
+ "model.layers.61.mlp.experts.202.down_proj",
657
+ "model.layers.61.mlp.experts.203.gate_proj",
658
+ "model.layers.61.mlp.experts.203.up_proj",
659
+ "model.layers.61.mlp.experts.203.down_proj",
660
+ "model.layers.61.mlp.experts.204.gate_proj",
661
+ "model.layers.61.mlp.experts.204.up_proj",
662
+ "model.layers.61.mlp.experts.204.down_proj",
663
+ "model.layers.61.mlp.experts.205.gate_proj",
664
+ "model.layers.61.mlp.experts.205.up_proj",
665
+ "model.layers.61.mlp.experts.205.down_proj",
666
+ "model.layers.61.mlp.experts.206.gate_proj",
667
+ "model.layers.61.mlp.experts.206.up_proj",
668
+ "model.layers.61.mlp.experts.206.down_proj",
669
+ "model.layers.61.mlp.experts.207.gate_proj",
670
+ "model.layers.61.mlp.experts.207.up_proj",
671
+ "model.layers.61.mlp.experts.207.down_proj",
672
+ "model.layers.61.mlp.experts.208.gate_proj",
673
+ "model.layers.61.mlp.experts.208.up_proj",
674
+ "model.layers.61.mlp.experts.208.down_proj",
675
+ "model.layers.61.mlp.experts.209.gate_proj",
676
+ "model.layers.61.mlp.experts.209.up_proj",
677
+ "model.layers.61.mlp.experts.209.down_proj",
678
+ "model.layers.61.mlp.experts.210.gate_proj",
679
+ "model.layers.61.mlp.experts.210.up_proj",
680
+ "model.layers.61.mlp.experts.210.down_proj",
681
+ "model.layers.61.mlp.experts.211.gate_proj",
682
+ "model.layers.61.mlp.experts.211.up_proj",
683
+ "model.layers.61.mlp.experts.211.down_proj",
684
+ "model.layers.61.mlp.experts.212.gate_proj",
685
+ "model.layers.61.mlp.experts.212.up_proj",
686
+ "model.layers.61.mlp.experts.212.down_proj",
687
+ "model.layers.61.mlp.experts.213.gate_proj",
688
+ "model.layers.61.mlp.experts.213.up_proj",
689
+ "model.layers.61.mlp.experts.213.down_proj",
690
+ "model.layers.61.mlp.experts.214.gate_proj",
691
+ "model.layers.61.mlp.experts.214.up_proj",
692
+ "model.layers.61.mlp.experts.214.down_proj",
693
+ "model.layers.61.mlp.experts.215.gate_proj",
694
+ "model.layers.61.mlp.experts.215.up_proj",
695
+ "model.layers.61.mlp.experts.215.down_proj",
696
+ "model.layers.61.mlp.experts.216.gate_proj",
697
+ "model.layers.61.mlp.experts.216.up_proj",
698
+ "model.layers.61.mlp.experts.216.down_proj",
699
+ "model.layers.61.mlp.experts.217.gate_proj",
700
+ "model.layers.61.mlp.experts.217.up_proj",
701
+ "model.layers.61.mlp.experts.217.down_proj",
702
+ "model.layers.61.mlp.experts.218.gate_proj",
703
+ "model.layers.61.mlp.experts.218.up_proj",
704
+ "model.layers.61.mlp.experts.218.down_proj",
705
+ "model.layers.61.mlp.experts.219.gate_proj",
706
+ "model.layers.61.mlp.experts.219.up_proj",
707
+ "model.layers.61.mlp.experts.219.down_proj",
708
+ "model.layers.61.mlp.experts.220.gate_proj",
709
+ "model.layers.61.mlp.experts.220.up_proj",
710
+ "model.layers.61.mlp.experts.220.down_proj",
711
+ "model.layers.61.mlp.experts.221.gate_proj",
712
+ "model.layers.61.mlp.experts.221.up_proj",
713
+ "model.layers.61.mlp.experts.221.down_proj",
714
+ "model.layers.61.mlp.experts.222.gate_proj",
715
+ "model.layers.61.mlp.experts.222.up_proj",
716
+ "model.layers.61.mlp.experts.222.down_proj",
717
+ "model.layers.61.mlp.experts.223.gate_proj",
718
+ "model.layers.61.mlp.experts.223.up_proj",
719
+ "model.layers.61.mlp.experts.223.down_proj",
720
+ "model.layers.61.mlp.experts.224.gate_proj",
721
+ "model.layers.61.mlp.experts.224.up_proj",
722
+ "model.layers.61.mlp.experts.224.down_proj",
723
+ "model.layers.61.mlp.experts.225.gate_proj",
724
+ "model.layers.61.mlp.experts.225.up_proj",
725
+ "model.layers.61.mlp.experts.225.down_proj",
726
+ "model.layers.61.mlp.experts.226.gate_proj",
727
+ "model.layers.61.mlp.experts.226.up_proj",
728
+ "model.layers.61.mlp.experts.226.down_proj",
729
+ "model.layers.61.mlp.experts.227.gate_proj",
730
+ "model.layers.61.mlp.experts.227.up_proj",
731
+ "model.layers.61.mlp.experts.227.down_proj",
732
+ "model.layers.61.mlp.experts.228.gate_proj",
733
+ "model.layers.61.mlp.experts.228.up_proj",
734
+ "model.layers.61.mlp.experts.228.down_proj",
735
+ "model.layers.61.mlp.experts.229.gate_proj",
736
+ "model.layers.61.mlp.experts.229.up_proj",
737
+ "model.layers.61.mlp.experts.229.down_proj",
738
+ "model.layers.61.mlp.experts.230.gate_proj",
739
+ "model.layers.61.mlp.experts.230.up_proj",
740
+ "model.layers.61.mlp.experts.230.down_proj",
741
+ "model.layers.61.mlp.experts.231.gate_proj",
742
+ "model.layers.61.mlp.experts.231.up_proj",
743
+ "model.layers.61.mlp.experts.231.down_proj",
744
+ "model.layers.61.mlp.experts.232.gate_proj",
745
+ "model.layers.61.mlp.experts.232.up_proj",
746
+ "model.layers.61.mlp.experts.232.down_proj",
747
+ "model.layers.61.mlp.experts.233.gate_proj",
748
+ "model.layers.61.mlp.experts.233.up_proj",
749
+ "model.layers.61.mlp.experts.233.down_proj",
750
+ "model.layers.61.mlp.experts.234.gate_proj",
751
+ "model.layers.61.mlp.experts.234.up_proj",
752
+ "model.layers.61.mlp.experts.234.down_proj",
753
+ "model.layers.61.mlp.experts.235.gate_proj",
754
+ "model.layers.61.mlp.experts.235.up_proj",
755
+ "model.layers.61.mlp.experts.235.down_proj",
756
+ "model.layers.61.mlp.experts.236.gate_proj",
757
+ "model.layers.61.mlp.experts.236.up_proj",
758
+ "model.layers.61.mlp.experts.236.down_proj",
759
+ "model.layers.61.mlp.experts.237.gate_proj",
760
+ "model.layers.61.mlp.experts.237.up_proj",
761
+ "model.layers.61.mlp.experts.237.down_proj",
762
+ "model.layers.61.mlp.experts.238.gate_proj",
763
+ "model.layers.61.mlp.experts.238.up_proj",
764
+ "model.layers.61.mlp.experts.238.down_proj",
765
+ "model.layers.61.mlp.experts.239.gate_proj",
766
+ "model.layers.61.mlp.experts.239.up_proj",
767
+ "model.layers.61.mlp.experts.239.down_proj",
768
+ "model.layers.61.mlp.experts.240.gate_proj",
769
+ "model.layers.61.mlp.experts.240.up_proj",
770
+ "model.layers.61.mlp.experts.240.down_proj",
771
+ "model.layers.61.mlp.experts.241.gate_proj",
772
+ "model.layers.61.mlp.experts.241.up_proj",
773
+ "model.layers.61.mlp.experts.241.down_proj",
774
+ "model.layers.61.mlp.experts.242.gate_proj",
775
+ "model.layers.61.mlp.experts.242.up_proj",
776
+ "model.layers.61.mlp.experts.242.down_proj",
777
+ "model.layers.61.mlp.experts.243.gate_proj",
778
+ "model.layers.61.mlp.experts.243.up_proj",
779
+ "model.layers.61.mlp.experts.243.down_proj",
780
+ "model.layers.61.mlp.experts.244.gate_proj",
781
+ "model.layers.61.mlp.experts.244.up_proj",
782
+ "model.layers.61.mlp.experts.244.down_proj",
783
+ "model.layers.61.mlp.experts.245.gate_proj",
784
+ "model.layers.61.mlp.experts.245.up_proj",
785
+ "model.layers.61.mlp.experts.245.down_proj",
786
+ "model.layers.61.mlp.experts.246.gate_proj",
787
+ "model.layers.61.mlp.experts.246.up_proj",
788
+ "model.layers.61.mlp.experts.246.down_proj",
789
+ "model.layers.61.mlp.experts.247.gate_proj",
790
+ "model.layers.61.mlp.experts.247.up_proj",
791
+ "model.layers.61.mlp.experts.247.down_proj",
792
+ "model.layers.61.mlp.experts.248.gate_proj",
793
+ "model.layers.61.mlp.experts.248.up_proj",
794
+ "model.layers.61.mlp.experts.248.down_proj",
795
+ "model.layers.61.mlp.experts.249.gate_proj",
796
+ "model.layers.61.mlp.experts.249.up_proj",
797
+ "model.layers.61.mlp.experts.249.down_proj",
798
+ "model.layers.61.mlp.experts.250.gate_proj",
799
+ "model.layers.61.mlp.experts.250.up_proj",
800
+ "model.layers.61.mlp.experts.250.down_proj",
801
+ "model.layers.61.mlp.experts.251.gate_proj",
802
+ "model.layers.61.mlp.experts.251.up_proj",
803
+ "model.layers.61.mlp.experts.251.down_proj",
804
+ "model.layers.61.mlp.experts.252.gate_proj",
805
+ "model.layers.61.mlp.experts.252.up_proj",
806
+ "model.layers.61.mlp.experts.252.down_proj",
807
+ "model.layers.61.mlp.experts.253.gate_proj",
808
+ "model.layers.61.mlp.experts.253.up_proj",
809
+ "model.layers.61.mlp.experts.253.down_proj",
810
+ "model.layers.61.mlp.experts.254.gate_proj",
811
+ "model.layers.61.mlp.experts.254.up_proj",
812
+ "model.layers.61.mlp.experts.254.down_proj",
813
+ "model.layers.61.mlp.experts.255.gate_proj",
814
+ "model.layers.61.mlp.experts.255.up_proj",
815
+ "model.layers.61.mlp.experts.255.down_proj",
816
+ "model.layers.61.mlp.shared_experts.gate_proj",
817
+ "model.layers.61.mlp.shared_experts.up_proj",
818
+ "model.layers.61.mlp.shared_experts.down_proj",
819
+ "model.layers.61.shared_head.head",
820
+ "lm_head"
821
+ ],
822
+ "export": {
823
+ "kv_cache_group": [],
824
+ "min_kv_scale": 0.0,
825
+ "pack_method": "reorder",
826
+ "weight_format": "real_quantized",
827
+ "weight_merge_groups": null
828
+ },
829
+ "global_quant_config": {
830
+ "bias": null,
831
+ "input_tensors": {
832
+ "ch_axis": -1,
833
+ "dtype": "fp4",
834
+ "group_size": 32,
835
+ "is_dynamic": true,
836
+ "is_scale_quant": false,
837
+ "mx_element_dtype": null,
838
+ "observer_cls": "PerBlockMXObserver",
839
+ "qscheme": "per_group",
840
+ "round_method": "half_even",
841
+ "scale_calculation_mode": "even",
842
+ "scale_format": "e8m0",
843
+ "scale_type": "float",
844
+ "symmetric": null
845
+ },
846
+ "output_tensors": null,
847
+ "target_device": null,
848
+ "weight": {
849
+ "ch_axis": -1,
850
+ "dtype": "fp4",
851
+ "group_size": 32,
852
+ "is_dynamic": false,
853
+ "is_scale_quant": false,
854
+ "mx_element_dtype": null,
855
+ "observer_cls": "PerBlockMXObserver",
856
+ "qscheme": "per_group",
857
+ "round_method": "half_even",
858
+ "scale_calculation_mode": "even",
859
+ "scale_format": "e8m0",
860
+ "scale_type": "float",
861
+ "symmetric": null
862
+ }
863
+ },
864
+ "kv_cache_quant_config": {},
865
+ "layer_quant_config": {},
866
+ "layer_type_quant_config": {},
867
+ "quant_method": "quark",
868
+ "quant_mode": "eager_mode",
869
+ "softmax_quant_spec": null,
870
+ "version": "0.10"
871
+ },
872
+ "rms_norm_eps": 1e-06,
873
+ "rope_scaling": {
874
+ "beta_fast": 32,
875
+ "beta_slow": 1,
876
+ "factor": 40,
877
+ "mscale": 1.0,
878
+ "mscale_all_dim": 1.0,
879
+ "original_max_position_embeddings": 4096,
880
+ "type": "yarn"
881
+ },
882
+ "rope_theta": 10000,
883
+ "routed_scaling_factor": 2.5,
884
+ "scoring_func": "sigmoid",
885
+ "tie_word_embeddings": false,
886
+ "topk_group": 4,
887
+ "topk_method": "noaux_tc",
888
+ "torch_dtype": "bfloat16",
889
+ "transformers_version": "4.53.0",
890
+ "unsloth_fixed": true,
891
+ "use_cache": true,
892
+ "v_head_dim": 128,
893
+ "vocab_size": 129280
894
+ }
configuration_deepseek.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+
4
+ logger = logging.get_logger(__name__)
5
+
6
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class DeepseekV3Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a
    DeepSeek model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of DeepSeek-V3.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 129280):
            Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DeepseekV3Model`].
        hidden_size (`int`, *optional*, defaults to 7168):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 18432):
            Dimension of the (dense) MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of each MoE expert's MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 61):
            Number of hidden layers in the Transformer decoder.
        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
            Number of next-n (multi-token) prediction layers in the DeepSeek-V3 model.
        num_attention_heads (`int`, *optional*, defaults to 128):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 128):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details checkout
            [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        n_shared_experts (`int`, *optional*, defaults to 1):
            Number of shared experts; `None` means a dense model.
        n_routed_experts (`int`, *optional*, defaults to 256):
            Number of routed experts; `None` means a dense model.
        ep_size (`int`, *optional*, defaults to 1):
            Expert-parallel world size used to shard the routed experts.
        routed_scaling_factor (`float`, *optional*, defaults to 2.5):
            Scaling factor of routed experts.
        kv_lora_rank (`int`, *optional*, defaults to 512):
            Rank of the low-rank compression of keys/values in Multi-head Latent Attention (MLA).
        q_lora_rank (`int`, *optional*, defaults to 1536):
            Rank of the low-rank compression of queries in MLA.
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            Per-head dimension of the rotary (RoPE) part of queries/keys.
        v_head_dim (`int`, *optional*, defaults to 128):
            Per-head dimension of the value projections.
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            Per-head dimension of the non-rotary part of queries/keys.
        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
            Top-k method used in the routed gate.
        n_group (`int`, *optional*, defaults to 8):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to 4):
            Number of selected groups for each token (for each token, ensuring the selected experts are only within
            `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts; `None` means a dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 3):
            Number of dense layers at the start of the model
            (embed -> dense -> dense -> ... -> dense -> moe -> moe ... -> lm_head).
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
            Method of computing expert weights.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
            is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import DeepseekV3Model, DeepseekV3Config

    >>> # Initializing a Deepseek-V3 style configuration
    >>> configuration = DeepseekV3Config()

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "deepseek_v3"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=129280,
        hidden_size=7168,
        intermediate_size=18432,
        moe_intermediate_size=2048,
        num_hidden_layers=61,
        num_nextn_predict_layers=1,
        num_attention_heads=128,
        num_key_value_heads=128,
        n_shared_experts=1,
        n_routed_experts=256,
        ep_size=1,
        routed_scaling_factor=2.5,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="noaux_tc",
        n_group=8,
        topk_group=4,
        num_experts_per_tok=8,
        moe_layer_freq=1,
        first_k_dense_replace=3,
        norm_topk_prob=True,
        scoring_func="sigmoid",
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=0,
        eos_token_id=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        # for backward compatibility: fall back to MHA head count when unset
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "do_sample": true,
5
+ "eos_token_id": 1,
6
+ "temperature": 0.6,
7
+ "top_p": 0.95,
8
+ "transformers_version": "4.53.0"
9
+ }
model-00001-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43d742dbbcc52faea5915b1d61a1f5d4da28f5cb38e38a023884a879a9f392e4
3
+ size 4996307368
model-00002-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74daff529f8dafa197660e29907c301fe658c5e0372ad3962d0438a46e6fe214
3
+ size 4993099120
model-00004-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d53a589f7ab0c9c888c90d495d6b1bf80382de58e4d8c159afc4995d12ed4487
3
+ size 4993098832
model-00005-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:055df3c822c834c9eff9c0a123744fbc0ddb69b52daa779c13348a768d6d1c04
3
+ size 4993098832
model-00007-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36568d07bf45dc5325b90d6fe2e64e668abc583ff0c48b8e067626c52ec51e31
3
+ size 4993099048
model-00009-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f1b5bab2627254866986fd93b4a264b15031dfa496fb8fbd538efed37101b4a
3
+ size 4999180616
model-00010-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d46ef6572f25cca13147682575babf337ec523bfee4798c35b224938418b58e
3
+ size 4993099920
model-00016-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32c13a6bc019e120321c01bd5689ae822c1183704dd6bf2015a4f64ab7c34e2f
3
+ size 4993100104
model-00017-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7ba7d046e4236f24aa46cddad7a4505807df17369f01be4f742aa91a10cdcca
3
+ size 4993100104
model-00019-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6dea35542eeb8d4fc0976404d289201591b8cb8b2f63ce2902f1e3de8ef52d1
3
+ size 4993100392
model-00021-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d17e5f7576a3ef2ae318f4acebdbd8b8b2e3d76e6f0c623ff0e96201d8ec584
3
+ size 4993100104
model-00023-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efbb9da3ac9097ddce616993d5027e6e795486f176bb1d07ee1c8f3bcb242cb2
3
+ size 4993100216
model-00024-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e1805b889f633732781728f389c788266b180e0e6461e8fa615e220b9a81ac3
3
+ size 4993100392
model-00025-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb812611869972a822c5df6b083da343606dcc15d15b729a3b06a20e8de6e19f
3
+ size 4999182000
model-00026-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a1fb530783eb5fd1547264dfff4e3378d5137d7a6e71c2f44e3c0c146bfb0f5
3
+ size 4993100104
model-00029-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eef2d5fe476937cf5129e9049382629f7ad6a41c7c313f8765892ff5fc01ed13
3
+ size 4993100376
model-00030-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac34a20df6ebf5b36d046c054685c6dedebe3e29061797651fd2136203fd3bee
3
+ size 4993100408
model-00033-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2723af70fa78f9e7b39eb0ea6d73ee6a37ed0bddad6fe2da6600d6185b6bc239
3
+ size 4993100104
model-00034-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6f4a09fca2598f913fc4d49d290c975df91218aa6f79653280b5a5642d8e9db
3
+ size 4993100240
model-00036-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca871e327222f21308a635a4e76a6efb04bd7da79a25cfca5c58d3c713343856
3
+ size 4999181968
model-00037-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:056d1fc2ce21aa1c2df63092f11fb163a838d86fdcad2856e6c0c870ac2693ce
3
+ size 4993100104
model-00040-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77bd287af890c35fe1fee1d7b29882fdae79be86fa68624c74b24be29d932c9f
3
+ size 4993100392
model-00042-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec72d22340b751ccead4e261cabd6a3cf8ab9ed070f99c7dde556eca790279fe
3
+ size 4999181768
model-00044-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71f747bd9bc0da5a90969a743891b81975ab2e1ed3b7945e2528c2acc89873ed
3
+ size 4993100104
model-00046-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e95f518f6cabf55a1782474eec35a3716e4f3a52fc8e20b4d68b03d2e88c616
3
+ size 4993100392
model-00048-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:477b68ce6afdc4bf7152dc466ee9486bd503d5f3ff7b22f329ce79ae92dcc265
3
+ size 4993100104
model-00051-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23bb11017f9638a56c385edcbb0464487cd81e6ff1d1c2b43607588b68e7dfee
3
+ size 4993100392
model-00052-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c88b0e1b10ae5231519dca344e69e2e4d2da2a749515d7698a4f7274044ce640
3
+ size 4993102224
model-00053-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78a8ce5f8327b8ce47221fbf217354e189aa039b8c3107a99b2c66ab7ba96548
3
+ size 4999179952
model-00054-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:775b36b9aca92bfaee49019dcc5d44e13f1af596f407e47bf634f8af436c5187
3
+ size 4993100104
model-00055-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a7ba38dd5e1a034cb063d8527394d126d5a73134cf532c69ce4969e07372162
3
+ size 4993100104
model-00056-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46fe595f3855a5530a6149902d8bda29bfd91a63cb91ea97432b7164dc06721c
3
+ size 4993100296
model-00057-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82e43ea1f64e5167663d7161b5330a125facfcc26ee8ae1248008d764ee12464
3
+ size 4993100392
model-00058-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e4a4e7c99e6789985d63f314b783e17c7edd1c979acbfd4c61d29cf0693beb4
3
+ size 4999181912
model-00060-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebd500bd7ac84fb67a5211248650b5b59cbb2e6ecdd1d9c68426d6e480eb358d
3
+ size 4993100104
model-00062-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:450949317444bd175badaee084262d7f5d26d0d1c536afe35dfc55a0623c4551
3
+ size 4993100392
model-00067-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6cf2ca7a580dbc0e9487da7474e3cf620b455e396cb06a549cba5889e891823
3
+ size 4993100328
model-00069-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2df37fbeb4115cc49bb0652bd25fa954185082e82bb14d3b9e8607920fef9097
3
+ size 4999181888
model-00070-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0449858fc5f82c979d2e31efd2fe2a3e214ae29695773cdb5a9ebc7d05eb8963
3
+ size 4993100104
model-00072-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b3801f540a48979aed3beef8f5d64e668e42200e6c14c723d49c0d9d9dc182f
3
+ size 4987246760
model-00074-of-00078.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f16d4bb4f5ca9460873341a764a941a9d87a76b7adad99cd8a0e065745de312
3
+ size 4991243056
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_deepseek.py ADDED
@@ -0,0 +1,2176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch DeepSeek model."""
21
+ import math
22
+ import warnings
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, DynamicCache
33
+ from transformers.modeling_attn_mask_utils import (
34
+ AttentionMaskConverter,
35
+ _prepare_4d_attention_mask,
36
+ _prepare_4d_causal_attention_mask,
37
+ )
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPast,
40
+ CausalLMOutputWithPast,
41
+ SequenceClassifierOutputWithPast,
42
+ )
43
+ from transformers.modeling_utils import PreTrainedModel
44
+ from transformers.pytorch_utils import (
45
+ ALL_LAYERNORM_LAYERS,
46
+ is_torch_greater_or_equal_than_1_13,
47
+ )
48
+ from transformers.utils import (
49
+ add_start_docstrings,
50
+ add_start_docstrings_to_model_forward,
51
+ is_flash_attn_2_available,
52
+ is_flash_attn_greater_or_equal_2_10,
53
+ logging,
54
+ replace_return_docstrings,
55
+ )
56
+ from transformers.utils.import_utils import is_torch_fx_available
57
+ from .configuration_deepseek import DeepseekV3Config
58
+ import torch.distributed as dist
59
+ import numpy as np
60
+
61
+ if is_flash_attn_2_available():
62
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
63
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
64
+
65
+
66
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
67
+ # It means that the function will not be traced through and simply appear as a node in the graph.
68
+ if is_torch_fx_available():
69
+ if not is_torch_greater_or_equal_than_1_13:
70
+ import torch.fx
71
+
72
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
73
+
74
+
75
+ logger = logging.get_logger(__name__)
76
+
77
+ _CONFIG_FOR_DOC = "DeepseekV3Config"
78
+
79
+
80
+ def _get_unpad_data(attention_mask):
81
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
82
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
83
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
84
+ cu_seqlens = F.pad(
85
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
86
+ )
87
+ return (
88
+ indices,
89
+ cu_seqlens,
90
+ max_seqlen_in_batch,
91
+ )
92
+
93
+
94
class DeepseekV3RMSNorm(nn.Module):
    """Root-mean-square layer norm, equivalent to T5LayerNorm.

    Normalizes by the RMS of the last dimension (no mean subtraction) and
    applies a learned per-channel scale.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Compute the statistics in float32 for numerical stability, then
        # cast back to the caller's dtype before scaling.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)
109
+
110
+
111
+ ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)
112
+
113
+
114
class DeepseekV3RotaryEmbedding(nn.Module):
    """Vanilla rotary position embedding (RoPE) with a lazily rebuilt cos/sin cache.

    Precomputes ``cos_cached``/``sin_cached`` tables of shape
    [seq_len, dim] and serves slices of them from ``forward``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # NOTE: deliberately reset to None AFTER building the cache above, so
        # the first real forward() rebuilds the tables on the actual input's
        # device and dtype (the trace-time cache may be on the wrong device).
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # (Re)build the cos/sin lookup tables for positions [0, seq_len).
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Rebuild on first call (cache was invalidated in __init__) or when the
        # requested length exceeds what was cached.
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
155
+
156
+
157
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3
158
class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # Identical to the base class except positions are compressed by the
        # scaling factor, stretching the usable context window linearly.
        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        positions = positions / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
184
+
185
+
186
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3
187
class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        # Only once the requested length exceeds the trained window do we
        # rescale the frequency base (NTK-aware interpolation); for shorter
        # sequences this behaves exactly like the unscaled base class.
        if seq_len > self.max_position_embeddings:
            ratio = (
                self.scaling_factor * seq_len / self.max_position_embeddings
            ) - (self.scaling_factor - 1)
            adjusted_base = self.base * ratio ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
            inv_freq = 1.0 / (adjusted_base ** exponents)
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
223
+
224
+
225
+ # Inverse dim formula to find dim based on number of rotations
226
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Return the (fractional) rotary dimension that completes exactly
    ``num_rotations`` rotations over ``max_position_embeddings`` positions."""
    numerator = dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    return numerator / (2 * math.log(base))
232
+
233
+
234
+ # Find dim range bounds based on rotations
235
# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return the integer [low, high] dimension bounds between which YaRN
    blends interpolated and extrapolated rotary frequencies."""
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    # Clamp values just in case
    return max(low, 0), min(high, dim - 1)
245
+
246
+
247
def yarn_get_mscale(scale=1, mscale=1):
    """Attention magnitude correction for YaRN: 1.0 when no scaling is
    applied, otherwise a slow logarithmic ramp in the scale factor."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
251
+
252
+
253
def yarn_linear_ramp_mask(min_val, max_val, dim):
    """Return a length-``dim`` float32 tensor ramping linearly from 0 to 1
    between indices ``min_val`` and ``max_val``, clamped outside that range.

    Renamed parameters (was ``min``/``max``): shadowing the builtins made the
    body unable to call them and hurt readability; all in-file callers pass
    these positionally, so the rename is call-compatible.
    """
    if min_val == max_val:
        max_val += 0.001  # Prevent singularity (division by zero below)

    linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (
        max_val - min_val
    )
    return torch.clamp(linear_func, 0, 1)
260
+
261
+
262
class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """Rotary embedding with YaRN context-window extension.

    Blends interpolated (scaled) and extrapolated (original) rotary
    frequencies per dimension via a linear ramp mask, and applies a global
    magnitude correction (mscale) to the cos/sin tables.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # YaRN hyperparameters must be set before super().__init__, which
        # eagerly calls _set_cos_sin_cache below.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        # Extrapolated (original-base) frequencies ...
        freq_extra = 1.0 / (
            self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )
        # ... and interpolated (scaled-base) frequencies.
        freq_inter = 1.0 / (
            self.scaling_factor
            * self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )

        # Dimension range over which to blend between the two regimes.
        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        # mask == 1 selects the extrapolated frequency, 0 the interpolated one.
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(seq_len, device=device, dtype=torch.float32)

        freqs = torch.outer(t, inv_freq)

        # Ratio of the two mscale corrections; collapses to 1.0 when
        # mscale == mscale_all_dim.
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
        )
328
+
329
+
330
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
331
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)
336
+
337
+
338
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
339
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which cos[position_ids] and sin[position_ids]
            are unsqueezed so they broadcast against q and k (1 for
            [batch, heads, seq, dim] layouts, 2 for [batch, seq, heads, dim]).
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _regroup(t):
        # DeepSeek stores rotary pairs interleaved along the last dim; regroup
        # them into the [first halves | second halves] layout rotate_half
        # expects. Pure permutation — no values change.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_regrouped = _regroup(q)
    k_regrouped = _regroup(k)

    q_embed = (q_regrouped * cos) + (rotate_half(q_regrouped) * sin)
    k_embed = (k_regrouped * cos) + (rotate_half(k_regrouped) * sin)
    return q_embed, k_embed
372
+
373
+
374
class DeepseekV3MLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x)).

    Sizes default to the config's dense dimensions but can be overridden,
    e.g. with the smaller MoE expert intermediate size.
    """

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        self.intermediate_size = (
            intermediate_size
            if intermediate_size is not None
            else config.intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
391
+
392
+
393
class MoEGate(nn.Module):
    """Top-k expert routing gate for the DeepSeek-V3 MoE layer.

    Scores every token against all routed experts (sigmoid scoring) and
    selects `top_k` experts via node-limited group routing, returning per-token
    expert indices and scaled mixing weights.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            # Per-expert bias used only for expert *selection*, not for the
            # final mixing weights (aux-loss-free load balancing).
            # NOTE(review): left as torch.empty and not touched by
            # reset_parameters — presumably always loaded from a checkpoint;
            # confirm before training from scratch.
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Initializes only the gating weight matrix (kaiming uniform).
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route each token in `hidden_states` ([bsz, seq_len, hidden]).

        Returns:
            topk_idx: LongTensor [bsz*seq_len, top_k] of selected expert ids.
            topk_weight: Tensor [bsz*seq_len, top_k] of scaled gate weights.
        """
        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        # Gating is computed in float32 for numerical stability regardless of
        # the model's compute dtype.
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        ### select top-k experts
        if self.topk_method == "noaux_tc":
            # Inference-only routing path (training would need the aux-loss
            # machinery, which this file does not implement).
            assert not self.training
            # Bias shifts selection only; mixing weights below use raw scores.
            scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
            # Group score = sum of the 2 best expert scores inside the group.
            group_scores = (
                scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1)
            ) # [n, n_group]
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[
                1
            ]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            # Expand the group mask to a per-expert mask, then drop experts in
            # unselected groups before the final per-token top-k.
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(
                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
                )
                .reshape(bsz * seq_len, -1)
            )  # [n, e]
            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]
            _, topk_idx = torch.topk(
                tmp_scores, k=self.top_k, dim=-1, sorted=False
            )
            # Mixing weights come from the *unbiased* scores.
            topk_weight = scores.gather(1, topk_idx)
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor

        return topk_idx, topk_weight
474
+
475
class DeepseekV3MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Holds the routed experts (optionally sharded across ranks via expert
    parallelism), the routing gate, and an always-on shared-expert MLP whose
    output is added to the routed result.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        if hasattr(config, "ep_size") and config.ep_size > 1:
            # Expert parallelism: each rank materializes only its own slice of
            # the expert list; positions owned by other ranks stay None.
            assert config.ep_size == dist.get_world_size()
            self.ep_size = config.ep_size
            self.experts_per_rank = config.n_routed_experts // config.ep_size
            self.ep_rank = dist.get_rank()
            self.experts = nn.ModuleList(
                [
                    (
                        DeepseekV3MLP(
                            config, intermediate_size=config.moe_intermediate_size
                        )
                        if i >= self.ep_rank * self.experts_per_rank
                        and i < (self.ep_rank + 1) * self.experts_per_rank
                        else None
                    )
                    for i in range(config.n_routed_experts)
                ]
            )
        else:
            # Single-rank case: every expert lives locally.
            self.ep_size = 1
            self.experts_per_rank = config.n_routed_experts
            self.ep_rank = 0
            self.experts = nn.ModuleList(
                [
                    DeepseekV3MLP(
                        config, intermediate_size=config.moe_intermediate_size
                    )
                    for i in range(config.n_routed_experts)
                ]
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV3MLP(
                config=config, intermediate_size=intermediate_size
            )

    def forward(self, hidden_states):
        # NOTE(review): only the inference path is implemented — if called
        # with self.training True, `y` is never assigned and this raises
        # UnboundLocalError. flat_topk_idx is likewise unused here; both look
        # like remnants of a stripped training path. Confirm before training.
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if not self.training:
            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            # Shared experts always process every token.
            y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Dispatch tokens to their selected experts and combine the outputs.

        Args:
            x: [num_tokens, hidden] flattened token activations.
            topk_ids: [num_tokens, top_k] selected expert ids per token.
            topk_weight: [num_tokens, top_k] mixing weights per selection.

        Returns:
            [num_tokens, hidden] weighted sum of expert outputs per token.
        """
        # One-hot count of how many tokens each expert received.
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort token copies by expert id so each expert sees one contiguous
        # slice; idxs // top_k recovers the source token of each copy.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        sorted_tokens_shape = sorted_tokens.shape
        if self.ep_size > 1:
            # Expert-parallel dispatch: exchange per-expert counts, then the
            # token payloads, so every rank ends up holding exactly the tokens
            # routed to its local experts.
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
            output_splits = (
                tokens_per_expert_group.view(self.ep_size, -1)
                .sum(1)
                .cpu()
                .numpy()
                .tolist()
            )
            gathered_tokens = sorted_tokens.new_empty(
                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
            )
            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
            dist.all_to_all(
                list(gathered_tokens.split(output_splits)),
                list(sorted_tokens.split(input_split_sizes)),
            )
            tokens_per_expert_post_gather = tokens_per_expert_group.view(
                self.ep_size, self.experts_per_rank
            ).sum(dim=0)
            # Re-sort the gathered tokens by *local* expert index.
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        # Run each local expert over its contiguous slice of tokens.
        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
        if self.ep_size > 1:
            # Undo the local re-sort, then reverse the all_to_all so each
            # rank gets its own tokens' expert outputs back.
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
            dist.all_to_all(
                list(gathered_tokens.split(input_split_sizes)),
                list(new_x.split(output_splits)),
            )
            outs = gathered_tokens

        # Scatter expert outputs back to (token, slot) order and reduce the
        # top_k slots per token with their gating weights.
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
609
+
610
+
611
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head `n_rep` times along the head axis.

    Equivalent to `torch.repeat_interleave(hidden_states, n_rep, dim=1)`:
    input `(batch, num_key_value_heads, seqlen, head_dim)` becomes
    `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`.
    """
    bsz, n_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat; return the input unchanged.
        return hidden_states
    # Insert a repeat axis, broadcast it to n_rep, then fold it into the
    # head dimension.
    expanded = hidden_states[:, :, None, :, :].expand(
        bsz, n_kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(bsz, n_kv_heads * n_rep, seq_len, head_dim)
624
+
625
+
626
# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3
class DeepseekV3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    DeepseekV3 uses multi-head latent attention (MLA): queries and keys are
    split into a non-positional ("nope") part and a rotary ("rope") part, and
    keys/values are decoded from a low-rank compressed latent
    (`kv_lora_rank`) through `kv_b_proj`.
    """

    def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            # Fixed garbled message: "will to errors" -> "will lead to errors".
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Full per-head query/key width = non-positional + rotary parts.
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        # Query projection, optionally through a low-rank (LoRA-style) path.
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # Joint projection producing the compressed KV latent plus the shared
        # single-head rotary key part (MQA-style for the rope component).
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            # YaRN-style rescaling of the softmax temperature.
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Build the rotary-embedding module per `config.rope_scaling`.

        Raises:
            ValueError: if `rope_scaling["type"]` is not one of
                "linear", "dynamic", "yarn".
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekV3RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Forward only the YaRN knobs that are actually present.
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq, heads * v_dim) -> (bsz, heads, seq, v_dim)
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager MLA attention.

        Returns:
            `(attn_output, attn_weights or None, past_key_value)`.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        # Queries (optionally via the low-rank path), split into nope/rope.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # Compressed KV latent plus the shared rotary key.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)

        # Ensure RoPE cache covers all position_ids (important for MTP with absolute positions)
        if position_ids is not None:
            max_pos = position_ids.max().item() + 1
            rope_seq_len = max(kv_seq_len, max_pos)
        else:
            rope_seq_len = kv_seq_len
        cos, sin = self.rotary_emb(value_states, seq_len=rope_seq_len)

        # Defensive fallback: re-extend the cache if it is still too short.
        # (Was a bare print(); routed through the module logger instead.)
        if position_ids is not None and position_ids.max().item() >= cos.shape[0]:
            logger.warning(
                f"position_ids max ({position_ids.max().item()}) >= cos.shape[0] ({cos.shape[0]})"
            )
            cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max().item() + 1)

        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Re-assemble full-width queries/keys as [nope | rope] per head;
        # the single-head k_pe broadcasts across heads via the assignment.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
        )

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        # NOTE: removed a stray `assert attention_mask is not None` that sat
        # here — it contradicted the guarded branch below and made mask-free
        # calls crash.
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32 for a numerically stable softmax
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
869
+
870
+
871
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3
class DeepseekV3FlashAttention2(DeepseekV3Attention):
    """
    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """MLA forward pass routed through flash-attention kernels.

        Returns:
            `(attn_output, None, past_key_value)` — flash attention cannot
            return per-head attention weights.
        """
        # DeepseekV3FlashAttention2 attention does not support output_attentions
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        # Queries (optionally via the low-rank path), split into nope/rope.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        # (Removed an accidental duplicate of this assignment.)
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)

        # Ensure RoPE cache covers all position_ids (important for MTP with absolute positions)
        if position_ids is not None:
            max_pos = position_ids.max().item() + 1
            rope_seq_len = max(kv_seq_len, max_pos)
        else:
            rope_seq_len = kv_seq_len
        cos, sin = self.rotary_emb(value_states, seq_len=rope_seq_len)

        # Defensive fallback: re-extend the cache if it is still too short.
        # (Was a bare print(); routed through the module logger instead.)
        if position_ids is not None and position_ids.max().item() >= cos.shape[0]:
            logger.warning(
                f"position_ids max ({position_ids.max().item()}) >= cos.shape[0] ({cos.shape[0]})"
            )
            cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max().item() + 1)

        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Re-assemble full-width queries/keys as [nope | rope] per head.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        # Flash-attn requires equal head dims for q/k/v, so pad values up to
        # q_head_dim and slice the padding back off after the kernel.
        if self.q_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DeepseekV3RMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = (
                    self.q_proj.weight.dtype
                    if self.q_lora_rank is None
                    else self.q_a_proj.weight.dtype
                )

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            softmax_scale=self.softmax_scale,
        )
        # Drop the value padding introduced before the kernel call.
        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(
            bsz, q_len, self.num_heads * self.v_head_dim
        ).contiguous()
        attn_output = self.o_proj(attn_output)

        # output_attentions is forced False above; flash-attn has no weights
        # to return (the previous conditional left attn_weights unbound in
        # the impossible True branch).
        attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        """Strip padding tokens ahead of a varlen flash-attention call."""
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            # Prefill: queries share the key layout, reuse the key indices.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decode: one query per batch element.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
1158
+
1159
+
1160
# Dispatch table mapping `config._attn_implementation` to the attention
# module used by DeepseekV3DecoderLayer.
# NOTE(review): only "eager" and "flash_attention_2" are provided here —
# no "sdpa" entry; confirm this is intentional for this port.
ATTENTION_CLASSES = {
    "eager": DeepseekV3Attention,
    "flash_attention_2": DeepseekV3FlashAttention2,
}
1164
+
1165
class DeepseekV3SharedHead(nn.Module):
    """Output-side norm plus a vocabulary weight table for the MTP layer.

    NOTE(review): `head` is declared as an `nn.Embedding` over the vocabulary
    but is never applied in `forward` — presumably its weight is consumed
    externally (e.g. tied or matmul'd by the caller) to produce logits;
    confirm against the calling code.
    """

    def __init__(
        self,
        config: DeepseekV3Config,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        # Final RMSNorm applied to hidden states before logits are computed.
        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Vocab-sized weight table; unused inside this module's own forward.
        self.head = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Only normalization happens here; the projection to logits is done
        # by the caller.
        return self.norm(hidden_states)
1177
+
1178
class DeepseekV3MultiTokenPredictorLayer(nn.Module):
    """Multi-Token Prediction (MTP) layer: fuses the main model's hidden
    states with shifted input embeddings to predict one token further ahead.
    """

    def __init__(self, config: DeepseekV3Config, layer_idx: int) -> None:
        super().__init__()

        self.config = config
        self.padding_idx = config.pad_token_id
        # Embedding table for the (shifted) input tokens fed to this layer.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Separate norms for the embedding stream (enorm) and the previous
        # hidden-state stream (hnorm) before they are concatenated.
        self.enorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Projects the concatenated [embeds ; prev hidden] back to hidden_size.
        self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
        # Ensure eh_proj weight stays in bfloat16, bypassing FP8 quantization
        self.eh_proj.weight.data = self.eh_proj.weight.data.to(torch.bfloat16)

        # One full transformer block (eager attention + MoE MLP).
        self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
        self.mlp = DeepseekV3MoE(config)
        self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.shared_head = DeepseekV3SharedHead(config=config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """
        Forward pass for Multi-Token Predictor Layer.

        Args:
            input_ids: Input token IDs (length N)
            positions: Position indices for tokens (length N)
            previous_hidden_states: Hidden states from the previous (main) model layer (length N-1)
                previous_hidden_states[i] corresponds to input_ids[i+1]
            inputs_embeds: Optional pre-computed input embeddings (length N)
            attention_mask: Optional attention mask
            spec_step_index: Speculative decoding step index

        Returns:
            hidden_states: Output hidden states (length N-1, corresponding to input_ids[1:])
        """
        # NOTE(review): indexing below assumes input_ids/positions are 2-D
        # (batch, N) despite the docstring saying "length N" — confirm callers.
        assert inputs_embeds is not None, "inputs_embeds must be provided for MTP"

        # Shift input_ids by 1 to align with previous_hidden_states
        # previous_hidden_states[i] corresponds to input_ids[i+1]
        # So we use input_ids[1:] and positions[1:]
        shifted_input_ids = input_ids[:, 1:]
        shifted_positions = positions[:, 1:]
        shifted_inputs_embeds = inputs_embeds[:, 1:, :]

        # Mask inputs at position 0, as they are not needed by MTP
        shifted_inputs_embeds = shifted_inputs_embeds.clone()
        shifted_inputs_embeds[shifted_positions == 0] = 0

        # Normalize input embeddings and previous hidden states
        shifted_inputs_embeds = self.enorm(shifted_inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        # Concatenate and project
        hidden_states = self.eh_proj(
            torch.cat([shifted_inputs_embeds, previous_hidden_states], dim=-1)
        )

        # If no attention mask is provided, create a causal one
        if attention_mask is None:
            batch_size, seq_len = shifted_input_ids.shape[:2]
            # Lower-triangular boolean mask, broadcast to (batch, 1, seq, seq).
            attention_mask = torch.tril(
                torch.ones((seq_len, seq_len), dtype=torch.bool, device=shifted_input_ids.device)
            ).unsqueeze(0).unsqueeze(0)
            attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
            # Convert to attention weights format (0 for attend, -inf for mask)
            attention_mask = torch.where(
                attention_mask,
                torch.zeros_like(attention_mask, dtype=hidden_states.dtype),
                torch.full_like(attention_mask, float('-inf'), dtype=hidden_states.dtype)
            )

        # Self-attention (pre-norm residual; no KV cache inside MTP)
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=shifted_positions,
            past_key_value=None,
            output_attentions=False,
            use_cache=False,
        )
        hidden_states = residual + hidden_states

        # MLP (pre-norm residual)
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
1278
+
1279
+
1280
class DeepseekV3DecoderLayer(nn.Module):
    """One transformer block: self-attention followed by a dense or MoE MLP,
    each wrapped in a pre-RMSNorm residual branch."""

    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
            config=config, layer_idx=layer_idx
        )

        # MoE replaces the dense MLP only after the first `first_k_dense_replace`
        # layers and at the configured layer frequency.
        use_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        self.mlp = DeepseekV3MoE(config) if use_moe else DeepseekV3MLP(config)

        self.input_layernorm = DeepseekV3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = DeepseekV3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Run one decoder block.

        Args:
            hidden_states (`torch.FloatTensor`): input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): `(batch_size, sequence_length)`
                if flash attention is used, else `(batch_size, 1, query_len, key_len)`.
            position_ids (`torch.LongTensor`, *optional*): token positions for RoPE.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key/value
                projection states for incremental decoding.
            output_attentions (`bool`, *optional*): also return attention weights.
            use_cache (`bool`, *optional*): also return the updated key/value cache.

        Returns:
            Tuple of `(hidden_states[, attn_weights][, present_key_value])`.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # Attention branch (pre-norm residual).
        skip = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out, attn_weights, present_key_value = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = skip + attn_out

        # MLP branch (pre-norm residual).
        skip = hidden_states
        hidden_states = skip + self.mlp(self.post_attention_layernorm(hidden_states))

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
1366
+
1367
+
1368
# Shared documentation header injected (via `add_start_docstrings`) into the
# class-level docstring of every DeepseekV3 model class defined below.
DeepseekV3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DeepseekV3Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
1383
+
1384
+
1385
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3PreTrainedModel(PreTrainedModel):
    """Base class hooking DeepseekV3 models into the `PreTrainedModel` machinery.

    Declares the config class, weight-initialization scheme, gradient-checkpointing
    support, device-placement hints and cache/flash-attention capability flags
    shared by all DeepseekV3 model heads.
    """

    config_class = DeepseekV3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize `nn.Linear`/`nn.Embedding` weights from N(0, initializer_range)."""
        std = self.config.initializer_range
        is_linear = isinstance(module, nn.Linear)
        is_embedding = isinstance(module, nn.Embedding)
        if is_linear or is_embedding:
            # Both layer kinds share the same normal weight init.
            module.weight.data.normal_(mean=0.0, std=std)
        if is_linear:
            if module.bias is not None:
                module.bias.data.zero_()
        elif is_embedding:
            if module.padding_idx is not None:
                # Padding embedding row stays zero.
                module.weight.data[module.padding_idx].zero_()
1408
+
1409
+
1410
# Shared forward-signature documentation injected (via
# `add_start_docstrings_to_model_forward`) into every DeepseekV3 `forward`.
DeepseekV3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1478
+
1479
+
1480
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]

    Besides the base decoder layers, `config.num_nextn_predict_layers` extra
    [`DeepseekV3MultiTokenPredictorLayer`] modules are appended to `self.layers`.
    `forward` only runs the first `config.num_hidden_layers` layers; the MTP
    layers are consumed separately by `DeepseekV3ForCausalLM.mtp_forward`.

    Args:
        config: DeepseekV3Config
    """

    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                DeepseekV3DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

        # Multi-token-prediction layers occupy indices
        # [num_hidden_layers, num_hidden_layers + num_nextn_predict_layers).
        for layer_idx in range(config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers):
            self.layers.append(DeepseekV3MultiTokenPredictorLayer(config, layer_idx))

        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the base decoder stack (MTP layers are intentionally skipped)."""
        # Resolve per-call flags against config defaults.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0
        if use_cache:
            # Accept either a `Cache` instance or the legacy tuple-of-tuples
            # format; legacy input is converted and converted back on return.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_seq_length()

        if position_ids is None:
            # Default positions continue from the cached prefix length.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            # (dropped entirely when there is no padding, i.e. no 0 entries).
            attention_mask = (
                attention_mask
                if (attention_mask is not None and 0 in attention_mask)
                else None
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers (only process base decoder layers, not MTP layers)
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # Only iterate through base decoder layers, not MTP layers
        num_base_layers = self.config.num_hidden_layers
        for layer_idx in range(num_base_layers):
            decoder_layer = self.layers[layer_idx]

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache slot index depends on whether attentions were returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            # Mirror the input cache format on output.
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if use_legacy_cache
                else next_decoder_cache
            )
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1658
+
1659
+
1660
class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
    """Causal LM head on top of [`DeepseekV3Model`], with optional Multi-Token
    Prediction (MTP): extra predictor layers appended to ``model.layers`` can
    speculatively predict tokens beyond the next one."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DeepseekV3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Dynamically add multi-token predictor layer embeddings to tied weights
        # so each MTP layer's embed_tokens shares weights with lm_head.
        tied_weights = ["lm_head.weight"]
        for layer_idx in range(config.num_hidden_layers,
                               config.num_hidden_layers + config.num_nextn_predict_layers):
            tied_weights.append(f"model.layers.{layer_idx}.embed_tokens.weight")
        self._tied_weights_keys = tied_weights

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the base model's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the base model's token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head projection."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head projection."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the underlying decoder model."""
        self.model = decoder

    def get_decoder(self):
        """Return the underlying decoder model."""
        return self.model

    def get_mtp_layers(self):
        """
        Get the Multi-Token Predictor layers from the model.

        Returns:
            List of MTP layers
        """
        mtp_layers = []
        num_base_layers = self.config.num_hidden_layers
        num_mtp_layers = getattr(self.config, 'num_nextn_predict_layers', 0)

        # MTP layers live after the base decoder layers; filter by type in case
        # the layer list was modified.
        for idx in range(num_base_layers, num_base_layers + num_mtp_layers):
            if idx < len(self.model.layers):
                layer = self.model.layers[idx]
                if isinstance(layer, DeepseekV3MultiTokenPredictorLayer):
                    mtp_layers.append(layer)

        return mtp_layers

    def mtp_forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """
        Forward pass for Multi-Token Predictor.

        Time-step relationship:
        - If main model processes tokens up to position t
        - Main LM head predicts token at position t+1
        - MTP receives: previous_hidden_states from position t + token_{t+1}
        - MTP processes at position t+1 to predict token at position t+2

        Args:
            input_ids: Token ID at position t+1 (predicted by main model or ground truth)
            positions: Position indices (should be t+1)
            previous_hidden_states: Hidden states from main model at position t
            spec_step_idx: Speculative decoding step index (0-based, which MTP layer to use)

        Returns:
            hidden_states: Hidden states after MTP processing (used to predict token at t+2)

        Raises:
            ValueError: If the model was built without any MTP layers.
        """
        mtp_layers = self.get_mtp_layers()
        if not mtp_layers:
            raise ValueError("No MTP layers found in the model")

        # Wrap around so spec_step_idx beyond the layer count reuses layers.
        num_mtp_layers = len(mtp_layers)
        current_step_idx = spec_step_idx % num_mtp_layers

        # Get input embeddings
        inputs_embeds = self.model.embed_tokens(input_ids)

        # Forward through the corresponding MTP layer
        hidden_states = mtp_layers[current_step_idx].forward(
            input_ids=input_ids,
            positions=positions,
            previous_hidden_states=previous_hidden_states,
            inputs_embeds=inputs_embeds,
            spec_step_index=current_step_idx,
        )

        return hidden_states

    def compute_mtp_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """
        Compute logits from MTP hidden states.

        Time-step relationship:
        - If MTP processes at position t+1 (given token_{t+1})
        - This function computes logits for token at position t+2
        - Main LM head predicts t+1, MTP head predicts t+2

        Args:
            hidden_states: Hidden states from MTP layer at position t+1
            spec_step_idx: Speculative decoding step index (which MTP layer was used)

        Returns:
            Logits for token prediction at position t+2

        Raises:
            ValueError: If the model was built without any MTP layers.
        """
        mtp_layers = self.get_mtp_layers()
        if not mtp_layers:
            raise ValueError("No MTP layers found in the model")

        num_mtp_layers = len(mtp_layers)
        current_step_idx = spec_step_idx % num_mtp_layers

        # Get the corresponding MTP layer's shared head
        mtp_layer = mtp_layers[current_step_idx]

        # Normalize hidden states
        normalized_hidden_states = mtp_layer.shared_head(hidden_states)

        # Compute logits (cast to float32 for numerical stability downstream)
        logits = self.lm_head(normalized_hidden_states)
        logits = logits.float()

        return logits

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # NOTE(review): defaults to True here but `prepare_inputs_for_generation`
        # passes False unless explicitly requested — confirm the intended default.
        use_mtp_pipeline: Optional[bool] = True
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

        >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Run main model (always needed)
        # Force output_hidden_states=True if using MTP pipeline
        _output_hidden_states = True if use_mtp_pipeline else output_hidden_states
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=_output_hidden_states,
            return_dict=True,
        )

        # Compute main model logits
        hidden_states = outputs.hidden_states[-1] if _output_hidden_states else outputs.last_hidden_state
        logits = self.lm_head(hidden_states).float()
        # Run MTP pipeline if enabled
        if use_mtp_pipeline:
            # NOTE(review): this branch indexes `input_ids.device`/`input_ids.shape`
            # and so assumes `input_ids` is not None (inputs_embeds-only calls
            # would fail here) — confirm callers always pass input_ids with MTP on.
            # Get prediction from main model (token at position t+1); greedy argmax.
            # Ensure the predicted token is on the same device as input_ids
            main_predicted_token = logits[:, -1:, :].argmax(dim=-1).to(input_ids.device)

            # Get number of MTP tokens from config
            num_mtp_tokens = getattr(self.config, 'num_nextn_predict_layers', 1)

            # MTP inference using main model results
            seq_len = input_ids.shape[1] if input_ids is not None else hidden_states.shape[1]

            # Initialize: Start with original input_ids and hidden_states
            # Accumulate generated tokens to form complete sequence for MTP
            accumulated_input_ids = input_ids  # Original input sequence
            # Ensure hidden_states is on the same device as input_ids (for multi-GPU)
            accumulated_hidden_states = hidden_states.to(input_ids.device) if hidden_states.device != input_ids.device else hidden_states
            generated_tokens = [main_predicted_token]  # Tokens predicted by main model and MTP

            # Collect MTP predictions
            mtp_logits_list = []

            for mtp_step in range(num_mtp_tokens):
                # Build complete input sequence: original + all generated tokens so far
                mtp_input_ids = torch.cat(
                    [accumulated_input_ids] + generated_tokens,
                    dim=1
                )

                # Build complete position IDs for the full sequence
                mtp_position_ids = torch.arange(
                    0, mtp_input_ids.shape[1],
                    dtype=torch.long,
                    device=mtp_input_ids.device
                ).unsqueeze(0).expand(mtp_input_ids.shape[0], -1)

                # previous_hidden_states only from main model (not including generated tokens)
                # The MTP layer will handle the alignment: previous_hidden_states[i] corresponds to input_ids[i+1]
                mtp_prev_hidden = accumulated_hidden_states

                # MTP forward with full sequence
                mtp_hidden = self.mtp_forward(
                    input_ids=mtp_input_ids,
                    positions=mtp_position_ids,
                    previous_hidden_states=mtp_prev_hidden,
                    spec_step_idx=mtp_step,
                )

                # Compute MTP logits (only for the last position)
                mtp_logits = self.compute_mtp_logits(
                    hidden_states=mtp_hidden[:, -1:, :],
                    spec_step_idx=mtp_step,
                )

                # Get next token prediction (greedy)
                # Ensure the predicted token is on the same device as input_ids
                next_token = mtp_logits.argmax(dim=-1).to(input_ids.device)

                # Collect results
                # Ensure mtp_logits is on the same device as main logits
                mtp_logits_list.append(mtp_logits.to(logits.device))

                # Update state for next MTP layer
                generated_tokens.append(next_token)

            # Combine main logits with MTP logits (num_mtp_tokens extra positions
            # appended along the sequence dimension).
            logits = torch.cat([logits] + mtp_logits_list, dim=1)

        loss = None
        if labels is not None:
            # NOTE(review): when use_mtp_pipeline is True, `logits` has
            # num_mtp_tokens more positions than `labels` — confirm the intended
            # loss alignment for training with MTP enabled.
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Trim already-cached tokens and assemble model kwargs for `generate()`."""
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                # NOTE(review): `seen_tokens` and `get_max_length` are deprecated
                # in newer transformers Cache APIs — verify against the pinned
                # transformers version.
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                # MTP is opt-in during generation (default False), unlike the
                # `forward` default of True.
                "use_mtp_pipeline": kwargs.get("use_mtp_pipeline", False),
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder each layer's legacy-format cache along the batch dim for beam search."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past
2045
+
2046
+
2047
@add_start_docstrings(
    """
    The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).

    [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeepseekV3Model(config)
        # Classification head applied to the pooled (last non-pad) hidden state.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the base model's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the base model's token embedding module."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the last non-pad token per row: position of the first
                # pad token minus one (argmax over the pad-equality mask).
                # NOTE(review): when a row contains no pad token, argmax returns 0
                # and this yields -1 (the last position) — confirm that fallback
                # is the intended behavior.
                sequence_lengths = (
                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                ).to(logits.device)
            else:
                sequence_lengths = -1

        # Pool one logit vector per batch row at its last (non-pad) position.
        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer problem type from num_labels and label dtype, mirroring
                # the standard transformers sequence-classification convention.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
special_tokens_map.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|begin▁of▁sentence|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|end▁of▁sentence|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|end▁of▁sentence|>"
17
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff