WinstonDeng commited on
Commit
8fb8cbc
·
verified ·
1 Parent(s): 3eb118b

upload step3p5_flash_release_mtp3_bf16

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .done +0 -0
  2. chat_template.jinja +80 -0
  3. config.json +313 -0
  4. configuration_step3p5.py +59 -0
  5. model-00001.safetensors +3 -0
  6. model-00002.safetensors +3 -0
  7. model-00003.safetensors +3 -0
  8. model-00004.safetensors +3 -0
  9. model-00005.safetensors +3 -0
  10. model-00006.safetensors +3 -0
  11. model-00007.safetensors +3 -0
  12. model-00008.safetensors +3 -0
  13. model-00009.safetensors +3 -0
  14. model-00010.safetensors +3 -0
  15. model-00011.safetensors +3 -0
  16. model-00012.safetensors +3 -0
  17. model-00013.safetensors +3 -0
  18. model-00014.safetensors +3 -0
  19. model-00015.safetensors +3 -0
  20. model-00016.safetensors +3 -0
  21. model-00017.safetensors +3 -0
  22. model-00018.safetensors +3 -0
  23. model-00019.safetensors +3 -0
  24. model-00020.safetensors +3 -0
  25. model-00021.safetensors +3 -0
  26. model-00022.safetensors +3 -0
  27. model-00023.safetensors +3 -0
  28. model-00024.safetensors +3 -0
  29. model-00025.safetensors +3 -0
  30. model-00026.safetensors +3 -0
  31. model-00027.safetensors +3 -0
  32. model-00028.safetensors +3 -0
  33. model-00029.safetensors +3 -0
  34. model-00030.safetensors +3 -0
  35. model-00031.safetensors +3 -0
  36. model-00032.safetensors +3 -0
  37. model-00033.safetensors +3 -0
  38. model-00034.safetensors +3 -0
  39. model-00035.safetensors +3 -0
  40. model-00036.safetensors +3 -0
  41. model-00037.safetensors +3 -0
  42. model-00038.safetensors +3 -0
  43. model-00039.safetensors +3 -0
  44. model-00040.safetensors +3 -0
  45. model-00041.safetensors +3 -0
  46. model-00042.safetensors +3 -0
  47. model-00043.safetensors +3 -0
  48. model-00044.safetensors +3 -0
  49. model.safetensors.index.json +811 -0
  50. modeling_step3p5.py +900 -0
.done ADDED
File without changes
chat_template.jinja ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% macro render_content(content) %}{% if content is none %}{{- '' }}{% elif content is string %}{{- content }}{% elif content is mapping %}{{- content['value'] if 'value' in content else content['text'] }}{% elif content is iterable %}{% for item in content %}{% if item.type == 'text' %}{{- item['value'] if 'value' in item else item['text'] }}{% elif item.type == 'image' %}<im_patch>{% endif %}{% endfor %}{% endif %}{% endmacro %}
2
+ {{bos_token}}{%- if tools %}
3
+ {{- '<|im_start|>system\n' }}
4
+ {%- if messages[0].role == 'system' %}
5
+ {{- render_content(messages[0].content) + '\n\n' }}
6
+ {%- endif %}
7
+ {{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n<tools>" }}
8
+ {%- for tool in tools %}
9
+ {{- "\n" }}
10
+ {{- tool | tojson(ensure_ascii=False) }}
11
+ {%- endfor %}
12
+ {{- "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...>\n...\n</function> block must be nested within <tool_call>\n...\n</tool_call> XML tags\n- Required parameters MUST be specified\n</IMPORTANT><|im_end|>\n" }}
13
+ {%- else %}
14
+ {%- if messages[0].role == 'system' %}
15
+ {{- '<|im_start|>system\n' + render_content(messages[0].content) + '<|im_end|>\n' }}
16
+ {%- endif %}
17
+ {%- endif %}
18
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
19
+ {%- for message in messages[::-1] %}
20
+ {%- set index = (messages|length - 1) - loop.index0 %}
21
+ {%- if ns.multi_step_tool and message.role == "user" and render_content(message.content) is string and not(render_content(message.content).startswith('<tool_response>') and render_content(message.content).endswith('</tool_response>')) %}
22
+ {%- set ns.multi_step_tool = false %}
23
+ {%- set ns.last_query_index = index %}
24
+ {%- endif %}
25
+ {%- endfor %}
26
+ {%- for message in messages %}
27
+ {%- set content = render_content(message.content) %}
28
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
29
+ {%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
30
+ {{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
31
+ {%- elif message.role == "assistant" %}
32
+ {%- if message.reasoning_content is string %}
33
+ {%- set reasoning_content = render_content(message.reasoning_content) %}
34
+ {%- else %}
35
+ {%- if '</think>' in content %}
36
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
37
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
38
+ {%- else %}
39
+ {%- set reasoning_content = '' %}
40
+ {%- endif %}
41
+ {%- endif %}
42
+ {%- if loop.index0 > ns.last_query_index %}
43
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
44
+ {%- else %}
45
+ {{- '<|im_start|>' + message.role + '\n' + content }}
46
+ {%- endif %}
47
+ {%- if message.tool_calls %}
48
+ {%- for tool_call in message.tool_calls %}
49
+ {%- if tool_call.function is defined %}
50
+ {%- set tool_call = tool_call.function %}
51
+ {%- endif %}
52
+ {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
53
+ {%- if tool_call.arguments is defined %}
54
+ {%- set arguments = tool_call.arguments %}
55
+ {%- for args_name, args_value in arguments|items %}
56
+ {{- '<parameter=' + args_name + '>\n' }}
57
+ {%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
58
+ {{- args_value }}
59
+ {{- '\n</parameter>\n' }}
60
+ {%- endfor %}
61
+ {%- endif %}
62
+ {{- '</function>\n</tool_call>' }}
63
+ {%- endfor %}
64
+ {%- endif %}
65
+ {{- '<|im_end|>\n' }}
66
+ {%- elif message.role == "tool" %}
67
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
68
+ {{- '<|im_start|>tool_response\n' }}
69
+ {%- endif %}
70
+ {{- '<tool_response>' }}
71
+ {{- content }}
72
+ {{- '</tool_response>' }}
73
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
74
+ {{- '<|im_end|>\n' }}
75
+ {%- endif %}
76
+ {%- endif %}
77
+ {%- endfor %}
78
+ {%- if add_generation_prompt %}
79
+ {{- '<|im_start|>assistant\n<think>\n' }}
80
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Step3p5ForCausalLM"
4
+ ],
5
+ "model_type": "step3p5",
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_step3p5.Step3p5Config",
8
+ "AutoModelForCausalLM": "modeling_step3p5.Step3p5ForCausalLM"
9
+ },
10
+ "rope_scaling": {
11
+ "rope_type": "llama3",
12
+ "factor": 2.0,
13
+ "original_max_position_embeddings": 131072,
14
+ "low_freq_factor": 1.0,
15
+ "high_freq_factor": 32.0
16
+ },
17
+ "yarn_only_types": ["full_attention"],
18
+ "hidden_size": 4096,
19
+ "intermediate_size": 11264,
20
+ "num_hidden_layers": 45,
21
+ "max_seq_len": 262144,
22
+ "vocab_size": 128896,
23
+ "torch_dtype": "bfloat16",
24
+ "use_qk_norm": false,
25
+ "moe_layers_enum": "3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44",
26
+ "use_mfa": false,
27
+ "num_attention_heads": 64,
28
+ "num_attention_groups": 8,
29
+ "head_dim": 128,
30
+ "use_moe": true,
31
+ "moe_num_experts": 288,
32
+ "moe_top_k": 8,
33
+ "moe_intermediate_size": 1280,
34
+ "share_expert_dim": 1280,
35
+ "moe_layer_offset": 0,
36
+ "moe_every_n_layer": 1,
37
+ "norm_expert_weight": true,
38
+ "moe_router_activation": "sigmoid",
39
+ "moe_router_scaling_factor": 3.0,
40
+ "att_impl_type": "GQA",
41
+ "rope_theta": [
42
+ 5000000.0,
43
+ 10000.0,
44
+ 10000.0,
45
+ 10000.0,
46
+ 5000000.0,
47
+ 10000.0,
48
+ 10000.0,
49
+ 10000.0,
50
+ 5000000.0,
51
+ 10000.0,
52
+ 10000.0,
53
+ 10000.0,
54
+ 5000000.0,
55
+ 10000.0,
56
+ 10000.0,
57
+ 10000.0,
58
+ 5000000.0,
59
+ 10000.0,
60
+ 10000.0,
61
+ 10000.0,
62
+ 5000000.0,
63
+ 10000.0,
64
+ 10000.0,
65
+ 10000.0,
66
+ 5000000.0,
67
+ 10000.0,
68
+ 10000.0,
69
+ 10000.0,
70
+ 5000000.0,
71
+ 10000.0,
72
+ 10000.0,
73
+ 10000.0,
74
+ 5000000.0,
75
+ 10000.0,
76
+ 10000.0,
77
+ 10000.0,
78
+ 5000000.0,
79
+ 10000.0,
80
+ 10000.0,
81
+ 10000.0,
82
+ 5000000.0,
83
+ 10000.0,
84
+ 10000.0,
85
+ 10000.0,
86
+ 5000000.0,
87
+ 10000.0,
88
+ 10000.0,
89
+ 10000.0
90
+ ],
91
+ "use_head_wise_attn_gate": true,
92
+ "sliding_window": 512,
93
+ "use_moe_router_bias": true,
94
+ "need_fp32_gate": true,
95
+ "sink": false,
96
+ "layer_types": [
97
+ "full_attention",
98
+ "sliding_attention",
99
+ "sliding_attention",
100
+ "sliding_attention",
101
+ "full_attention",
102
+ "sliding_attention",
103
+ "sliding_attention",
104
+ "sliding_attention",
105
+ "full_attention",
106
+ "sliding_attention",
107
+ "sliding_attention",
108
+ "sliding_attention",
109
+ "full_attention",
110
+ "sliding_attention",
111
+ "sliding_attention",
112
+ "sliding_attention",
113
+ "full_attention",
114
+ "sliding_attention",
115
+ "sliding_attention",
116
+ "sliding_attention",
117
+ "full_attention",
118
+ "sliding_attention",
119
+ "sliding_attention",
120
+ "sliding_attention",
121
+ "full_attention",
122
+ "sliding_attention",
123
+ "sliding_attention",
124
+ "sliding_attention",
125
+ "full_attention",
126
+ "sliding_attention",
127
+ "sliding_attention",
128
+ "sliding_attention",
129
+ "full_attention",
130
+ "sliding_attention",
131
+ "sliding_attention",
132
+ "sliding_attention",
133
+ "full_attention",
134
+ "sliding_attention",
135
+ "sliding_attention",
136
+ "sliding_attention",
137
+ "full_attention",
138
+ "sliding_attention",
139
+ "sliding_attention",
140
+ "sliding_attention",
141
+ "full_attention",
142
+ "sliding_attention",
143
+ "sliding_attention",
144
+ "sliding_attention"
145
+ ],
146
+ "use_rope_layers": [],
147
+ "num_nextn_predict_layers": 3,
148
+ "partial_rotary_factors": [
149
+ 0.5,
150
+ 1.0,
151
+ 1.0,
152
+ 1.0,
153
+ 0.5,
154
+ 1.0,
155
+ 1.0,
156
+ 1.0,
157
+ 0.5,
158
+ 1.0,
159
+ 1.0,
160
+ 1.0,
161
+ 0.5,
162
+ 1.0,
163
+ 1.0,
164
+ 1.0,
165
+ 0.5,
166
+ 1.0,
167
+ 1.0,
168
+ 1.0,
169
+ 0.5,
170
+ 1.0,
171
+ 1.0,
172
+ 1.0,
173
+ 0.5,
174
+ 1.0,
175
+ 1.0,
176
+ 1.0,
177
+ 0.5,
178
+ 1.0,
179
+ 1.0,
180
+ 1.0,
181
+ 0.5,
182
+ 1.0,
183
+ 1.0,
184
+ 1.0,
185
+ 0.5,
186
+ 1.0,
187
+ 1.0,
188
+ 1.0,
189
+ 0.5,
190
+ 1.0,
191
+ 1.0,
192
+ 1.0,
193
+ 0.5,
194
+ 1.0,
195
+ 1.0,
196
+ 1.0
197
+ ],
198
+ "eos_token_id": [
199
+ 1,
200
+ 2,
201
+ 128007
202
+ ],
203
+ "bos_token_id": 0,
204
+ "attention_other_setting": {
205
+ "attention_type": "sliding_attention",
206
+ "num_attention_heads": 96,
207
+ "num_attention_groups": 8,
208
+ "head_dim": 128,
209
+ "true_head_dim": 128
210
+ },
211
+ "swiglu_limits": [
212
+ 0.0,
213
+ 0.0,
214
+ 0.0,
215
+ 0.0,
216
+ 0.0,
217
+ 0.0,
218
+ 0.0,
219
+ 0.0,
220
+ 0.0,
221
+ 0.0,
222
+ 0.0,
223
+ 0.0,
224
+ 0.0,
225
+ 0.0,
226
+ 0.0,
227
+ 0.0,
228
+ 0.0,
229
+ 0.0,
230
+ 0.0,
231
+ 0.0,
232
+ 0.0,
233
+ 0.0,
234
+ 0.0,
235
+ 0.0,
236
+ 0.0,
237
+ 0.0,
238
+ 0.0,
239
+ 0.0,
240
+ 0.0,
241
+ 0.0,
242
+ 0.0,
243
+ 0.0,
244
+ 0.0,
245
+ 0.0,
246
+ 0.0,
247
+ 0.0,
248
+ 0.0,
249
+ 0.0,
250
+ 0.0,
251
+ 0.0,
252
+ 0.0,
253
+ 0.0,
254
+ 0.0,
255
+ 7,
256
+ 7,
257
+ 0.0,
258
+ 0.0,
259
+ 0.0
260
+ ],
261
+ "swiglu_limits_shared": [
262
+ 0.0,
263
+ 0.0,
264
+ 0.0,
265
+ 0.0,
266
+ 0.0,
267
+ 0.0,
268
+ 0.0,
269
+ 0.0,
270
+ 0.0,
271
+ 0.0,
272
+ 0.0,
273
+ 0.0,
274
+ 0.0,
275
+ 0.0,
276
+ 0.0,
277
+ 0.0,
278
+ 0.0,
279
+ 0.0,
280
+ 0.0,
281
+ 0.0,
282
+ 0.0,
283
+ 0.0,
284
+ 0.0,
285
+ 0.0,
286
+ 0.0,
287
+ 0.0,
288
+ 0.0,
289
+ 0.0,
290
+ 0.0,
291
+ 0.0,
292
+ 0.0,
293
+ 0.0,
294
+ 0.0,
295
+ 0.0,
296
+ 0.0,
297
+ 0.0,
298
+ 0.0,
299
+ 0.0,
300
+ 0.0,
301
+ 0.0,
302
+ 0.0,
303
+ 0.0,
304
+ 0.0,
305
+ 0.0,
306
+ 16,
307
+ 0.0,
308
+ 0.0,
309
+ 0.0
310
+ ],
311
+ "zero_centered": true,
312
+ "max_position_embeddings": 262144
313
+ }
configuration_step3p5.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Union
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+
6
+
7
+ class Step3p5Config(PretrainedConfig):
8
+ model_type = "step3p5"
9
+ architectures = ["Step3p5ForCausalLM"]
10
+
11
+ def __init__(
12
+ self,
13
+ hidden_size: int = 4096,
14
+ intermediate_size: int = 11264,
15
+ num_attention_heads: int = 64,
16
+ num_attention_groups: int = 8,
17
+ num_hidden_layers: int = 45,
18
+ max_seq_len: int = 128000,
19
+ vocab_size: int = 128815,
20
+ rms_norm_eps: float = 1e-5,
21
+ moe_intermediate_size: int = 1280,
22
+ moe_num_experts: int = 288,
23
+ moe_top_k: int = 8,
24
+ rope_theta: float = 10000,
25
+ rope_scaling: Optional[dict[str, Any]] = None,
26
+ max_position_embeddings: int = 128000,
27
+ share_expert_dims: int = 1280,
28
+ head_dim: int = 128,
29
+ norm_expert_weight: bool = True,
30
+ layer_types: list[str] = None,
31
+ sliding_window: Optional[int] = None,
32
+ moe_layers_enum: tuple[int] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
33
+ 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
34
+ 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
35
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
36
+ **kwargs,
37
+ ) -> None:
38
+ self.hidden_size = hidden_size
39
+ self.intermediate_size = intermediate_size
40
+ self.num_attention_heads = num_attention_heads
41
+ self.num_attention_groups = num_attention_groups
42
+ self.num_hidden_layers = num_hidden_layers
43
+ self.max_seq_len = max_seq_len
44
+ self.vocab_size = vocab_size
45
+ self.rms_norm_eps = rms_norm_eps
46
+ self.moe_intermediate_size = moe_intermediate_size
47
+ self.moe_num_experts = moe_num_experts
48
+ self.moe_top_k = moe_top_k
49
+ self.rope_theta = rope_theta
50
+ self.rope_scaling = rope_scaling
51
+ self.max_position_embeddings = max_position_embeddings
52
+ self.share_expert_dim = share_expert_dims
53
+ self.head_dim = head_dim
54
+ self.norm_expert_weight = norm_expert_weight
55
+ self.moe_layers_enum = moe_layers_enum
56
+ self.layer_types = layer_types
57
+ self.sliding_window = sliding_window
58
+ super().__init__(**kwargs)
59
+
model-00001.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f8e2662c620b7f21cb8b00961471128746421c6f61ba92ad24ff8651e986766
3
+ size 9628898912
model-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22e138ece5dce3c2611f8b6728f966e506ed139c57b72ee4139ba1deca91ead7
3
+ size 8632547344
model-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8726b3b15516baa63f029590c8b5da4459edd0f50da57e273d13fc2deff31d8
3
+ size 9059696992
model-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:198b2f0c237110ad9bbe65de107dd4dfe450a86cfc5f250c6b9cb688321e571d
3
+ size 9059696992
model-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d3721a49193beb502621b97b55f22ff9b7458a71a910ffc1fe74243abc19777
3
+ size 9059696992
model-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc035e1afcc9785ee895ee1372519d2bd327c76284629f74c4d157bdf573c53f
3
+ size 9059696992
model-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2ff1c7baff83316fae101d258bb44924032df3278849c09d236441ead771330
3
+ size 9059696992
model-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2615aeb466afb62ec1b7403b9040019c3018a44d167405547e7559b3d8e34e8
3
+ size 9059696992
model-00009.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f544f4375037da6cd02366427ff5e75e5244c2491a0090fba350867050b2bed7
3
+ size 9059696992
model-00010.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc5288b60acfc96f4be69f993c9bc603b778a853d7a4cdd24591054ade49bd42
3
+ size 9059697000
model-00011.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0fd8a16bda39fb8ae26ad0dd3a5c32e7ce915be778505c44066c74b3c17e386
3
+ size 9059697000
model-00012.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53f27314566dbf7e6c4755094515dc0b77a20a15ecd83ddaabc61ba0a6f65ecd
3
+ size 9059697000
model-00013.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a1d66062d7e22083872d894ae6f44ec9faea9bc01f0306778b709e1eca83313
3
+ size 9059697000
model-00014.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3db10310139824a2952baaec66da5f93968dd7485c513270943d2c649843b874
3
+ size 9059697000
model-00015.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52db61304d989ae57dc4b8a01c78cb0798bef2624afb1c6425bd9f9296a90494
3
+ size 9059697000
model-00016.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db4c6aef9cc7468ff5587496e6d950b3992e5de02f9b771a0fc1bae4aa873866
3
+ size 9059697000
model-00017.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e376274886e231751f41d71bcba67ad6c02a68fed8bfae6a6335bd10a37b414
3
+ size 9059697000
model-00018.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc48094eac9f450a9b773869aa85394dbec8e32d307e51de94a5bf72b575498b
3
+ size 9059697000
model-00019.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7f962f5a2ab272cc4e90ffd22eb639f4beb15eadd6eca391e80002861161b3f
3
+ size 9059697000
model-00020.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffe4b5d3083c7a874a10129c8085ab1012066981f0d700d315563f759d3dfff9
3
+ size 9059697000
model-00021.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0b8f2f30a026adc1e6936f4bbe0807d4fc4e2caf391f05c533fc32fbec7e0ab
3
+ size 9059697000
model-00022.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e70c49f26b69916ca1fc0266dc2b03f7dd47a2508ee40c68588993b31879962b
3
+ size 9059697000
model-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63fef3259ccfbd135ab4d289ca962e8ed1c2ac55216546e49e259e66ac2cc87e
3
+ size 9059697000
model-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29eeb0ed8525d3bcdbc19643ff66252cbd641f44aa54592fcb0654260e428611
3
+ size 9059697000
model-00025.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20ccb4b13ee04b6d76d8d397420c310312d514f71688d55df5529114c7adf769
3
+ size 9059697000
model-00026.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da930ba12ecb5ac0f111e5351ce4f4a691b95b99cc34d657490ed0c4ecc6ce29
3
+ size 9059697000
model-00027.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ef0f3708e484aa4be0157b0e78794280820e51604056b12a96e478d3486511c
3
+ size 9059697000
model-00028.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08feb8feb8a238efc43ea376b83139f0a0b121386048c68265d2af2f9beee712
3
+ size 9059697000
model-00029.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05e4482ec47a1e77c2a337be59a5598517fdb23f5aca1fc295e44b2d0283ba59
3
+ size 9059697000
model-00030.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9cb29b0a63748adcb055e1001e955df82278db8c42447829bc6f61cdb162305
3
+ size 9059697000
model-00031.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebe00abaa258ee7142289ef259db4ce3272a1fbe1a3f096acbedfddeb6fcf214
3
+ size 9059697000
model-00032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0227047725983fd5601c5dfda03cf65f23aec2fd528260b1c33d43a71f473460
3
+ size 9059697000
model-00033.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1673cef13088986b6fad926cc706ff3e2947824c71825a55477e7baa8b09e201
3
+ size 9059697000
model-00034.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52930ddcd3f6c8f99212883967f0bca80f1b0bd1e1269e97b4a5cce5666ddc82
3
+ size 9059697000
model-00035.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:581b9539799266540a2e78dee7a2b294c8820b412b2d148fca47c32c9347ee32
3
+ size 9059697000
model-00036.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:571ddde460371caa95e0e2a0f3c18e7974d80ada1c5cd9f00b1242ba880fda47
3
+ size 9059697000
model-00037.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abeea37a41821b07daed0e98b44349a089d7cb43030cfcc6608db8bc6fcd926c
3
+ size 9059697000
model-00038.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df998f9af1c225263239d9210fc66fcdb76b1134833aa318bc0f144e8d33426b
3
+ size 9059697000
model-00039.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d316a3f796db46b13bfd9dddc086fa4cd866e1e2338f63e68bf35ba186274a7
3
+ size 9059697000
model-00040.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e287dbcfaa2c9c8ea77334cdd6fc011b8c058b8f750870ff309faf671b38d688
3
+ size 9059697000
model-00041.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:933b06efa555d81b7525b87f3b5636ca82fc5b730a8cc0e35f7f0fedb658dbf4
3
+ size 9059697000
model-00042.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0057a63127ae7282dc9aaf28fc7948e3377d8d04bc7aec580772520264a7f450
3
+ size 9059697000
model-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29bb1885b5721799447e65b4d67054aba290cba32e38e6934c3690b41ff22c6c
3
+ size 9059697000
model-00044.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d4febe2c29ebb67ed97b99922113c353f97935a7f4243d7843c833d06471b41
3
+ size 9059697000
model.safetensors.index.json ADDED
@@ -0,0 +1,811 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 398768626944
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00002.safetensors",
7
+ "model.embed_tokens.weight": "model-00001.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001.safetensors",
13
+ "model.layers.0.self_attn.g_proj.weight": "model-00001.safetensors",
14
+ "model.layers.0.self_attn.k_norm.weight": "model-00001.safetensors",
15
+ "model.layers.0.self_attn.k_proj.weight": "model-00001.safetensors",
16
+ "model.layers.0.self_attn.o_proj.weight": "model-00001.safetensors",
17
+ "model.layers.0.self_attn.q_norm.weight": "model-00001.safetensors",
18
+ "model.layers.0.self_attn.q_proj.weight": "model-00001.safetensors",
19
+ "model.layers.0.self_attn.v_proj.weight": "model-00001.safetensors",
20
+ "model.layers.1.input_layernorm.weight": "model-00001.safetensors",
21
+ "model.layers.1.mlp.down_proj.weight": "model-00001.safetensors",
22
+ "model.layers.1.mlp.gate_proj.weight": "model-00001.safetensors",
23
+ "model.layers.1.mlp.up_proj.weight": "model-00001.safetensors",
24
+ "model.layers.1.post_attention_layernorm.weight": "model-00001.safetensors",
25
+ "model.layers.1.self_attn.g_proj.weight": "model-00001.safetensors",
26
+ "model.layers.1.self_attn.k_norm.weight": "model-00001.safetensors",
27
+ "model.layers.1.self_attn.k_proj.weight": "model-00001.safetensors",
28
+ "model.layers.1.self_attn.o_proj.weight": "model-00001.safetensors",
29
+ "model.layers.1.self_attn.q_norm.weight": "model-00001.safetensors",
30
+ "model.layers.1.self_attn.q_proj.weight": "model-00001.safetensors",
31
+ "model.layers.1.self_attn.v_proj.weight": "model-00001.safetensors",
32
+ "model.layers.10.input_layernorm.weight": "model-00001.safetensors",
33
+ "model.layers.10.moe.down_proj.weight": "model-00010.safetensors",
34
+ "model.layers.10.moe.gate.weight": "model-00001.safetensors",
35
+ "model.layers.10.moe.gate_proj.weight": "model-00010.safetensors",
36
+ "model.layers.10.moe.router_bias": "model-00001.safetensors",
37
+ "model.layers.10.moe.up_proj.weight": "model-00010.safetensors",
38
+ "model.layers.10.post_attention_layernorm.weight": "model-00001.safetensors",
39
+ "model.layers.10.self_attn.g_proj.weight": "model-00001.safetensors",
40
+ "model.layers.10.self_attn.k_norm.weight": "model-00001.safetensors",
41
+ "model.layers.10.self_attn.k_proj.weight": "model-00001.safetensors",
42
+ "model.layers.10.self_attn.o_proj.weight": "model-00001.safetensors",
43
+ "model.layers.10.self_attn.q_norm.weight": "model-00001.safetensors",
44
+ "model.layers.10.self_attn.q_proj.weight": "model-00001.safetensors",
45
+ "model.layers.10.self_attn.v_proj.weight": "model-00001.safetensors",
46
+ "model.layers.10.share_expert.down_proj.weight": "model-00001.safetensors",
47
+ "model.layers.10.share_expert.gate_proj.weight": "model-00001.safetensors",
48
+ "model.layers.10.share_expert.up_proj.weight": "model-00001.safetensors",
49
+ "model.layers.11.input_layernorm.weight": "model-00001.safetensors",
50
+ "model.layers.11.moe.down_proj.weight": "model-00011.safetensors",
51
+ "model.layers.11.moe.gate.weight": "model-00001.safetensors",
52
+ "model.layers.11.moe.gate_proj.weight": "model-00011.safetensors",
53
+ "model.layers.11.moe.router_bias": "model-00001.safetensors",
54
+ "model.layers.11.moe.up_proj.weight": "model-00011.safetensors",
55
+ "model.layers.11.post_attention_layernorm.weight": "model-00001.safetensors",
56
+ "model.layers.11.self_attn.g_proj.weight": "model-00001.safetensors",
57
+ "model.layers.11.self_attn.k_norm.weight": "model-00001.safetensors",
58
+ "model.layers.11.self_attn.k_proj.weight": "model-00001.safetensors",
59
+ "model.layers.11.self_attn.o_proj.weight": "model-00001.safetensors",
60
+ "model.layers.11.self_attn.q_norm.weight": "model-00001.safetensors",
61
+ "model.layers.11.self_attn.q_proj.weight": "model-00001.safetensors",
62
+ "model.layers.11.self_attn.v_proj.weight": "model-00001.safetensors",
63
+ "model.layers.11.share_expert.down_proj.weight": "model-00001.safetensors",
64
+ "model.layers.11.share_expert.gate_proj.weight": "model-00001.safetensors",
65
+ "model.layers.11.share_expert.up_proj.weight": "model-00001.safetensors",
66
+ "model.layers.12.input_layernorm.weight": "model-00001.safetensors",
67
+ "model.layers.12.moe.down_proj.weight": "model-00012.safetensors",
68
+ "model.layers.12.moe.gate.weight": "model-00001.safetensors",
69
+ "model.layers.12.moe.gate_proj.weight": "model-00012.safetensors",
70
+ "model.layers.12.moe.router_bias": "model-00001.safetensors",
71
+ "model.layers.12.moe.up_proj.weight": "model-00012.safetensors",
72
+ "model.layers.12.post_attention_layernorm.weight": "model-00001.safetensors",
73
+ "model.layers.12.self_attn.g_proj.weight": "model-00001.safetensors",
74
+ "model.layers.12.self_attn.k_norm.weight": "model-00001.safetensors",
75
+ "model.layers.12.self_attn.k_proj.weight": "model-00001.safetensors",
76
+ "model.layers.12.self_attn.o_proj.weight": "model-00001.safetensors",
77
+ "model.layers.12.self_attn.q_norm.weight": "model-00001.safetensors",
78
+ "model.layers.12.self_attn.q_proj.weight": "model-00001.safetensors",
79
+ "model.layers.12.self_attn.v_proj.weight": "model-00001.safetensors",
80
+ "model.layers.12.share_expert.down_proj.weight": "model-00001.safetensors",
81
+ "model.layers.12.share_expert.gate_proj.weight": "model-00001.safetensors",
82
+ "model.layers.12.share_expert.up_proj.weight": "model-00001.safetensors",
83
+ "model.layers.13.input_layernorm.weight": "model-00001.safetensors",
84
+ "model.layers.13.moe.down_proj.weight": "model-00013.safetensors",
85
+ "model.layers.13.moe.gate.weight": "model-00001.safetensors",
86
+ "model.layers.13.moe.gate_proj.weight": "model-00013.safetensors",
87
+ "model.layers.13.moe.router_bias": "model-00001.safetensors",
88
+ "model.layers.13.moe.up_proj.weight": "model-00013.safetensors",
89
+ "model.layers.13.post_attention_layernorm.weight": "model-00001.safetensors",
90
+ "model.layers.13.self_attn.g_proj.weight": "model-00001.safetensors",
91
+ "model.layers.13.self_attn.k_norm.weight": "model-00001.safetensors",
92
+ "model.layers.13.self_attn.k_proj.weight": "model-00001.safetensors",
93
+ "model.layers.13.self_attn.o_proj.weight": "model-00001.safetensors",
94
+ "model.layers.13.self_attn.q_norm.weight": "model-00001.safetensors",
95
+ "model.layers.13.self_attn.q_proj.weight": "model-00001.safetensors",
96
+ "model.layers.13.self_attn.v_proj.weight": "model-00001.safetensors",
97
+ "model.layers.13.share_expert.down_proj.weight": "model-00001.safetensors",
98
+ "model.layers.13.share_expert.gate_proj.weight": "model-00001.safetensors",
99
+ "model.layers.13.share_expert.up_proj.weight": "model-00001.safetensors",
100
+ "model.layers.14.input_layernorm.weight": "model-00001.safetensors",
101
+ "model.layers.14.moe.down_proj.weight": "model-00014.safetensors",
102
+ "model.layers.14.moe.gate.weight": "model-00001.safetensors",
103
+ "model.layers.14.moe.gate_proj.weight": "model-00014.safetensors",
104
+ "model.layers.14.moe.router_bias": "model-00001.safetensors",
105
+ "model.layers.14.moe.up_proj.weight": "model-00014.safetensors",
106
+ "model.layers.14.post_attention_layernorm.weight": "model-00001.safetensors",
107
+ "model.layers.14.self_attn.g_proj.weight": "model-00001.safetensors",
108
+ "model.layers.14.self_attn.k_norm.weight": "model-00001.safetensors",
109
+ "model.layers.14.self_attn.k_proj.weight": "model-00001.safetensors",
110
+ "model.layers.14.self_attn.o_proj.weight": "model-00001.safetensors",
111
+ "model.layers.14.self_attn.q_norm.weight": "model-00001.safetensors",
112
+ "model.layers.14.self_attn.q_proj.weight": "model-00001.safetensors",
113
+ "model.layers.14.self_attn.v_proj.weight": "model-00001.safetensors",
114
+ "model.layers.14.share_expert.down_proj.weight": "model-00001.safetensors",
115
+ "model.layers.14.share_expert.gate_proj.weight": "model-00001.safetensors",
116
+ "model.layers.14.share_expert.up_proj.weight": "model-00001.safetensors",
117
+ "model.layers.15.input_layernorm.weight": "model-00001.safetensors",
118
+ "model.layers.15.moe.down_proj.weight": "model-00015.safetensors",
119
+ "model.layers.15.moe.gate.weight": "model-00001.safetensors",
120
+ "model.layers.15.moe.gate_proj.weight": "model-00015.safetensors",
121
+ "model.layers.15.moe.router_bias": "model-00001.safetensors",
122
+ "model.layers.15.moe.up_proj.weight": "model-00015.safetensors",
123
+ "model.layers.15.post_attention_layernorm.weight": "model-00001.safetensors",
124
+ "model.layers.15.self_attn.g_proj.weight": "model-00001.safetensors",
125
+ "model.layers.15.self_attn.k_norm.weight": "model-00001.safetensors",
126
+ "model.layers.15.self_attn.k_proj.weight": "model-00001.safetensors",
127
+ "model.layers.15.self_attn.o_proj.weight": "model-00001.safetensors",
128
+ "model.layers.15.self_attn.q_norm.weight": "model-00001.safetensors",
129
+ "model.layers.15.self_attn.q_proj.weight": "model-00001.safetensors",
130
+ "model.layers.15.self_attn.v_proj.weight": "model-00001.safetensors",
131
+ "model.layers.15.share_expert.down_proj.weight": "model-00001.safetensors",
132
+ "model.layers.15.share_expert.gate_proj.weight": "model-00001.safetensors",
133
+ "model.layers.15.share_expert.up_proj.weight": "model-00001.safetensors",
134
+ "model.layers.16.input_layernorm.weight": "model-00001.safetensors",
135
+ "model.layers.16.moe.down_proj.weight": "model-00016.safetensors",
136
+ "model.layers.16.moe.gate.weight": "model-00001.safetensors",
137
+ "model.layers.16.moe.gate_proj.weight": "model-00016.safetensors",
138
+ "model.layers.16.moe.router_bias": "model-00001.safetensors",
139
+ "model.layers.16.moe.up_proj.weight": "model-00016.safetensors",
140
+ "model.layers.16.post_attention_layernorm.weight": "model-00001.safetensors",
141
+ "model.layers.16.self_attn.g_proj.weight": "model-00001.safetensors",
142
+ "model.layers.16.self_attn.k_norm.weight": "model-00001.safetensors",
143
+ "model.layers.16.self_attn.k_proj.weight": "model-00001.safetensors",
144
+ "model.layers.16.self_attn.o_proj.weight": "model-00001.safetensors",
145
+ "model.layers.16.self_attn.q_norm.weight": "model-00001.safetensors",
146
+ "model.layers.16.self_attn.q_proj.weight": "model-00001.safetensors",
147
+ "model.layers.16.self_attn.v_proj.weight": "model-00001.safetensors",
148
+ "model.layers.16.share_expert.down_proj.weight": "model-00001.safetensors",
149
+ "model.layers.16.share_expert.gate_proj.weight": "model-00001.safetensors",
150
+ "model.layers.16.share_expert.up_proj.weight": "model-00001.safetensors",
151
+ "model.layers.17.input_layernorm.weight": "model-00001.safetensors",
152
+ "model.layers.17.moe.down_proj.weight": "model-00017.safetensors",
153
+ "model.layers.17.moe.gate.weight": "model-00001.safetensors",
154
+ "model.layers.17.moe.gate_proj.weight": "model-00017.safetensors",
155
+ "model.layers.17.moe.router_bias": "model-00001.safetensors",
156
+ "model.layers.17.moe.up_proj.weight": "model-00017.safetensors",
157
+ "model.layers.17.post_attention_layernorm.weight": "model-00001.safetensors",
158
+ "model.layers.17.self_attn.g_proj.weight": "model-00001.safetensors",
159
+ "model.layers.17.self_attn.k_norm.weight": "model-00001.safetensors",
160
+ "model.layers.17.self_attn.k_proj.weight": "model-00001.safetensors",
161
+ "model.layers.17.self_attn.o_proj.weight": "model-00001.safetensors",
162
+ "model.layers.17.self_attn.q_norm.weight": "model-00001.safetensors",
163
+ "model.layers.17.self_attn.q_proj.weight": "model-00001.safetensors",
164
+ "model.layers.17.self_attn.v_proj.weight": "model-00001.safetensors",
165
+ "model.layers.17.share_expert.down_proj.weight": "model-00001.safetensors",
166
+ "model.layers.17.share_expert.gate_proj.weight": "model-00001.safetensors",
167
+ "model.layers.17.share_expert.up_proj.weight": "model-00001.safetensors",
168
+ "model.layers.18.input_layernorm.weight": "model-00001.safetensors",
169
+ "model.layers.18.moe.down_proj.weight": "model-00018.safetensors",
170
+ "model.layers.18.moe.gate.weight": "model-00001.safetensors",
171
+ "model.layers.18.moe.gate_proj.weight": "model-00018.safetensors",
172
+ "model.layers.18.moe.router_bias": "model-00001.safetensors",
173
+ "model.layers.18.moe.up_proj.weight": "model-00018.safetensors",
174
+ "model.layers.18.post_attention_layernorm.weight": "model-00001.safetensors",
175
+ "model.layers.18.self_attn.g_proj.weight": "model-00001.safetensors",
176
+ "model.layers.18.self_attn.k_norm.weight": "model-00001.safetensors",
177
+ "model.layers.18.self_attn.k_proj.weight": "model-00001.safetensors",
178
+ "model.layers.18.self_attn.o_proj.weight": "model-00001.safetensors",
179
+ "model.layers.18.self_attn.q_norm.weight": "model-00001.safetensors",
180
+ "model.layers.18.self_attn.q_proj.weight": "model-00001.safetensors",
181
+ "model.layers.18.self_attn.v_proj.weight": "model-00001.safetensors",
182
+ "model.layers.18.share_expert.down_proj.weight": "model-00001.safetensors",
183
+ "model.layers.18.share_expert.gate_proj.weight": "model-00001.safetensors",
184
+ "model.layers.18.share_expert.up_proj.weight": "model-00001.safetensors",
185
+ "model.layers.19.input_layernorm.weight": "model-00001.safetensors",
186
+ "model.layers.19.moe.down_proj.weight": "model-00019.safetensors",
187
+ "model.layers.19.moe.gate.weight": "model-00001.safetensors",
188
+ "model.layers.19.moe.gate_proj.weight": "model-00019.safetensors",
189
+ "model.layers.19.moe.router_bias": "model-00001.safetensors",
190
+ "model.layers.19.moe.up_proj.weight": "model-00019.safetensors",
191
+ "model.layers.19.post_attention_layernorm.weight": "model-00001.safetensors",
192
+ "model.layers.19.self_attn.g_proj.weight": "model-00001.safetensors",
193
+ "model.layers.19.self_attn.k_norm.weight": "model-00001.safetensors",
194
+ "model.layers.19.self_attn.k_proj.weight": "model-00001.safetensors",
195
+ "model.layers.19.self_attn.o_proj.weight": "model-00001.safetensors",
196
+ "model.layers.19.self_attn.q_norm.weight": "model-00001.safetensors",
197
+ "model.layers.19.self_attn.q_proj.weight": "model-00001.safetensors",
198
+ "model.layers.19.self_attn.v_proj.weight": "model-00001.safetensors",
199
+ "model.layers.19.share_expert.down_proj.weight": "model-00001.safetensors",
200
+ "model.layers.19.share_expert.gate_proj.weight": "model-00001.safetensors",
201
+ "model.layers.19.share_expert.up_proj.weight": "model-00001.safetensors",
202
+ "model.layers.2.input_layernorm.weight": "model-00001.safetensors",
203
+ "model.layers.2.mlp.down_proj.weight": "model-00001.safetensors",
204
+ "model.layers.2.mlp.gate_proj.weight": "model-00001.safetensors",
205
+ "model.layers.2.mlp.up_proj.weight": "model-00001.safetensors",
206
+ "model.layers.2.post_attention_layernorm.weight": "model-00001.safetensors",
207
+ "model.layers.2.self_attn.g_proj.weight": "model-00001.safetensors",
208
+ "model.layers.2.self_attn.k_norm.weight": "model-00001.safetensors",
209
+ "model.layers.2.self_attn.k_proj.weight": "model-00001.safetensors",
210
+ "model.layers.2.self_attn.o_proj.weight": "model-00001.safetensors",
211
+ "model.layers.2.self_attn.q_norm.weight": "model-00001.safetensors",
212
+ "model.layers.2.self_attn.q_proj.weight": "model-00001.safetensors",
213
+ "model.layers.2.self_attn.v_proj.weight": "model-00001.safetensors",
214
+ "model.layers.20.input_layernorm.weight": "model-00001.safetensors",
215
+ "model.layers.20.moe.down_proj.weight": "model-00020.safetensors",
216
+ "model.layers.20.moe.gate.weight": "model-00001.safetensors",
217
+ "model.layers.20.moe.gate_proj.weight": "model-00020.safetensors",
218
+ "model.layers.20.moe.router_bias": "model-00001.safetensors",
219
+ "model.layers.20.moe.up_proj.weight": "model-00020.safetensors",
220
+ "model.layers.20.post_attention_layernorm.weight": "model-00001.safetensors",
221
+ "model.layers.20.self_attn.g_proj.weight": "model-00001.safetensors",
222
+ "model.layers.20.self_attn.k_norm.weight": "model-00001.safetensors",
223
+ "model.layers.20.self_attn.k_proj.weight": "model-00001.safetensors",
224
+ "model.layers.20.self_attn.o_proj.weight": "model-00001.safetensors",
225
+ "model.layers.20.self_attn.q_norm.weight": "model-00001.safetensors",
226
+ "model.layers.20.self_attn.q_proj.weight": "model-00001.safetensors",
227
+ "model.layers.20.self_attn.v_proj.weight": "model-00001.safetensors",
228
+ "model.layers.20.share_expert.down_proj.weight": "model-00001.safetensors",
229
+ "model.layers.20.share_expert.gate_proj.weight": "model-00001.safetensors",
230
+ "model.layers.20.share_expert.up_proj.weight": "model-00001.safetensors",
231
+ "model.layers.21.input_layernorm.weight": "model-00001.safetensors",
232
+ "model.layers.21.moe.down_proj.weight": "model-00021.safetensors",
233
+ "model.layers.21.moe.gate.weight": "model-00001.safetensors",
234
+ "model.layers.21.moe.gate_proj.weight": "model-00021.safetensors",
235
+ "model.layers.21.moe.router_bias": "model-00001.safetensors",
236
+ "model.layers.21.moe.up_proj.weight": "model-00021.safetensors",
237
+ "model.layers.21.post_attention_layernorm.weight": "model-00001.safetensors",
238
+ "model.layers.21.self_attn.g_proj.weight": "model-00001.safetensors",
239
+ "model.layers.21.self_attn.k_norm.weight": "model-00001.safetensors",
240
+ "model.layers.21.self_attn.k_proj.weight": "model-00001.safetensors",
241
+ "model.layers.21.self_attn.o_proj.weight": "model-00001.safetensors",
242
+ "model.layers.21.self_attn.q_norm.weight": "model-00001.safetensors",
243
+ "model.layers.21.self_attn.q_proj.weight": "model-00001.safetensors",
244
+ "model.layers.21.self_attn.v_proj.weight": "model-00001.safetensors",
245
+ "model.layers.21.share_expert.down_proj.weight": "model-00001.safetensors",
246
+ "model.layers.21.share_expert.gate_proj.weight": "model-00001.safetensors",
247
+ "model.layers.21.share_expert.up_proj.weight": "model-00001.safetensors",
248
+ "model.layers.22.input_layernorm.weight": "model-00001.safetensors",
249
+ "model.layers.22.moe.down_proj.weight": "model-00022.safetensors",
250
+ "model.layers.22.moe.gate.weight": "model-00001.safetensors",
251
+ "model.layers.22.moe.gate_proj.weight": "model-00022.safetensors",
252
+ "model.layers.22.moe.router_bias": "model-00001.safetensors",
253
+ "model.layers.22.moe.up_proj.weight": "model-00022.safetensors",
254
+ "model.layers.22.post_attention_layernorm.weight": "model-00001.safetensors",
255
+ "model.layers.22.self_attn.g_proj.weight": "model-00001.safetensors",
256
+ "model.layers.22.self_attn.k_norm.weight": "model-00001.safetensors",
257
+ "model.layers.22.self_attn.k_proj.weight": "model-00001.safetensors",
258
+ "model.layers.22.self_attn.o_proj.weight": "model-00001.safetensors",
259
+ "model.layers.22.self_attn.q_norm.weight": "model-00001.safetensors",
260
+ "model.layers.22.self_attn.q_proj.weight": "model-00001.safetensors",
261
+ "model.layers.22.self_attn.v_proj.weight": "model-00001.safetensors",
262
+ "model.layers.22.share_expert.down_proj.weight": "model-00001.safetensors",
263
+ "model.layers.22.share_expert.gate_proj.weight": "model-00001.safetensors",
264
+ "model.layers.22.share_expert.up_proj.weight": "model-00001.safetensors",
265
+ "model.layers.23.input_layernorm.weight": "model-00001.safetensors",
266
+ "model.layers.23.moe.down_proj.weight": "model-00023.safetensors",
267
+ "model.layers.23.moe.gate.weight": "model-00001.safetensors",
268
+ "model.layers.23.moe.gate_proj.weight": "model-00023.safetensors",
269
+ "model.layers.23.moe.router_bias": "model-00001.safetensors",
270
+ "model.layers.23.moe.up_proj.weight": "model-00023.safetensors",
271
+ "model.layers.23.post_attention_layernorm.weight": "model-00001.safetensors",
272
+ "model.layers.23.self_attn.g_proj.weight": "model-00001.safetensors",
273
+ "model.layers.23.self_attn.k_norm.weight": "model-00001.safetensors",
274
+ "model.layers.23.self_attn.k_proj.weight": "model-00001.safetensors",
275
+ "model.layers.23.self_attn.o_proj.weight": "model-00001.safetensors",
276
+ "model.layers.23.self_attn.q_norm.weight": "model-00001.safetensors",
277
+ "model.layers.23.self_attn.q_proj.weight": "model-00001.safetensors",
278
+ "model.layers.23.self_attn.v_proj.weight": "model-00001.safetensors",
279
+ "model.layers.23.share_expert.down_proj.weight": "model-00001.safetensors",
280
+ "model.layers.23.share_expert.gate_proj.weight": "model-00001.safetensors",
281
+ "model.layers.23.share_expert.up_proj.weight": "model-00001.safetensors",
282
+ "model.layers.24.input_layernorm.weight": "model-00001.safetensors",
283
+ "model.layers.24.moe.down_proj.weight": "model-00024.safetensors",
284
+ "model.layers.24.moe.gate.weight": "model-00001.safetensors",
285
+ "model.layers.24.moe.gate_proj.weight": "model-00024.safetensors",
286
+ "model.layers.24.moe.router_bias": "model-00001.safetensors",
287
+ "model.layers.24.moe.up_proj.weight": "model-00024.safetensors",
288
+ "model.layers.24.post_attention_layernorm.weight": "model-00001.safetensors",
289
+ "model.layers.24.self_attn.g_proj.weight": "model-00002.safetensors",
290
+ "model.layers.24.self_attn.k_norm.weight": "model-00001.safetensors",
291
+ "model.layers.24.self_attn.k_proj.weight": "model-00002.safetensors",
292
+ "model.layers.24.self_attn.o_proj.weight": "model-00001.safetensors",
293
+ "model.layers.24.self_attn.q_norm.weight": "model-00001.safetensors",
294
+ "model.layers.24.self_attn.q_proj.weight": "model-00002.safetensors",
295
+ "model.layers.24.self_attn.v_proj.weight": "model-00002.safetensors",
296
+ "model.layers.24.share_expert.down_proj.weight": "model-00001.safetensors",
297
+ "model.layers.24.share_expert.gate_proj.weight": "model-00002.safetensors",
298
+ "model.layers.24.share_expert.up_proj.weight": "model-00002.safetensors",
299
+ "model.layers.25.input_layernorm.weight": "model-00001.safetensors",
300
+ "model.layers.25.moe.down_proj.weight": "model-00025.safetensors",
301
+ "model.layers.25.moe.gate.weight": "model-00001.safetensors",
302
+ "model.layers.25.moe.gate_proj.weight": "model-00025.safetensors",
303
+ "model.layers.25.moe.router_bias": "model-00001.safetensors",
304
+ "model.layers.25.moe.up_proj.weight": "model-00025.safetensors",
305
+ "model.layers.25.post_attention_layernorm.weight": "model-00001.safetensors",
306
+ "model.layers.25.self_attn.g_proj.weight": "model-00002.safetensors",
307
+ "model.layers.25.self_attn.k_norm.weight": "model-00001.safetensors",
308
+ "model.layers.25.self_attn.k_proj.weight": "model-00002.safetensors",
309
+ "model.layers.25.self_attn.o_proj.weight": "model-00001.safetensors",
310
+ "model.layers.25.self_attn.q_norm.weight": "model-00001.safetensors",
311
+ "model.layers.25.self_attn.q_proj.weight": "model-00002.safetensors",
312
+ "model.layers.25.self_attn.v_proj.weight": "model-00002.safetensors",
313
+ "model.layers.25.share_expert.down_proj.weight": "model-00001.safetensors",
314
+ "model.layers.25.share_expert.gate_proj.weight": "model-00002.safetensors",
315
+ "model.layers.25.share_expert.up_proj.weight": "model-00002.safetensors",
316
+ "model.layers.26.input_layernorm.weight": "model-00001.safetensors",
317
+ "model.layers.26.moe.down_proj.weight": "model-00026.safetensors",
318
+ "model.layers.26.moe.gate.weight": "model-00001.safetensors",
319
+ "model.layers.26.moe.gate_proj.weight": "model-00026.safetensors",
320
+ "model.layers.26.moe.router_bias": "model-00001.safetensors",
321
+ "model.layers.26.moe.up_proj.weight": "model-00026.safetensors",
322
+ "model.layers.26.post_attention_layernorm.weight": "model-00001.safetensors",
323
+ "model.layers.26.self_attn.g_proj.weight": "model-00002.safetensors",
324
+ "model.layers.26.self_attn.k_norm.weight": "model-00001.safetensors",
325
+ "model.layers.26.self_attn.k_proj.weight": "model-00002.safetensors",
326
+ "model.layers.26.self_attn.o_proj.weight": "model-00001.safetensors",
327
+ "model.layers.26.self_attn.q_norm.weight": "model-00001.safetensors",
328
+ "model.layers.26.self_attn.q_proj.weight": "model-00002.safetensors",
329
+ "model.layers.26.self_attn.v_proj.weight": "model-00002.safetensors",
330
+ "model.layers.26.share_expert.down_proj.weight": "model-00001.safetensors",
331
+ "model.layers.26.share_expert.gate_proj.weight": "model-00002.safetensors",
332
+ "model.layers.26.share_expert.up_proj.weight": "model-00002.safetensors",
333
+ "model.layers.27.input_layernorm.weight": "model-00001.safetensors",
334
+ "model.layers.27.moe.down_proj.weight": "model-00027.safetensors",
335
+ "model.layers.27.moe.gate.weight": "model-00001.safetensors",
336
+ "model.layers.27.moe.gate_proj.weight": "model-00027.safetensors",
337
+ "model.layers.27.moe.router_bias": "model-00001.safetensors",
338
+ "model.layers.27.moe.up_proj.weight": "model-00027.safetensors",
339
+ "model.layers.27.post_attention_layernorm.weight": "model-00001.safetensors",
340
+ "model.layers.27.self_attn.g_proj.weight": "model-00002.safetensors",
341
+ "model.layers.27.self_attn.k_norm.weight": "model-00001.safetensors",
342
+ "model.layers.27.self_attn.k_proj.weight": "model-00002.safetensors",
343
+ "model.layers.27.self_attn.o_proj.weight": "model-00001.safetensors",
344
+ "model.layers.27.self_attn.q_norm.weight": "model-00001.safetensors",
345
+ "model.layers.27.self_attn.q_proj.weight": "model-00002.safetensors",
346
+ "model.layers.27.self_attn.v_proj.weight": "model-00002.safetensors",
347
+ "model.layers.27.share_expert.down_proj.weight": "model-00001.safetensors",
348
+ "model.layers.27.share_expert.gate_proj.weight": "model-00002.safetensors",
349
+ "model.layers.27.share_expert.up_proj.weight": "model-00002.safetensors",
350
+ "model.layers.28.input_layernorm.weight": "model-00001.safetensors",
351
+ "model.layers.28.moe.down_proj.weight": "model-00028.safetensors",
352
+ "model.layers.28.moe.gate.weight": "model-00001.safetensors",
353
+ "model.layers.28.moe.gate_proj.weight": "model-00028.safetensors",
354
+ "model.layers.28.moe.router_bias": "model-00001.safetensors",
355
+ "model.layers.28.moe.up_proj.weight": "model-00028.safetensors",
356
+ "model.layers.28.post_attention_layernorm.weight": "model-00001.safetensors",
357
+ "model.layers.28.self_attn.g_proj.weight": "model-00002.safetensors",
358
+ "model.layers.28.self_attn.k_norm.weight": "model-00001.safetensors",
359
+ "model.layers.28.self_attn.k_proj.weight": "model-00002.safetensors",
360
+ "model.layers.28.self_attn.o_proj.weight": "model-00001.safetensors",
361
+ "model.layers.28.self_attn.q_norm.weight": "model-00001.safetensors",
362
+ "model.layers.28.self_attn.q_proj.weight": "model-00002.safetensors",
363
+ "model.layers.28.self_attn.v_proj.weight": "model-00002.safetensors",
364
+ "model.layers.28.share_expert.down_proj.weight": "model-00001.safetensors",
365
+ "model.layers.28.share_expert.gate_proj.weight": "model-00002.safetensors",
366
+ "model.layers.28.share_expert.up_proj.weight": "model-00002.safetensors",
367
+ "model.layers.29.input_layernorm.weight": "model-00001.safetensors",
368
+ "model.layers.29.moe.down_proj.weight": "model-00029.safetensors",
369
+ "model.layers.29.moe.gate.weight": "model-00001.safetensors",
370
+ "model.layers.29.moe.gate_proj.weight": "model-00029.safetensors",
371
+ "model.layers.29.moe.router_bias": "model-00001.safetensors",
372
+ "model.layers.29.moe.up_proj.weight": "model-00029.safetensors",
373
+ "model.layers.29.post_attention_layernorm.weight": "model-00001.safetensors",
374
+ "model.layers.29.self_attn.g_proj.weight": "model-00002.safetensors",
375
+ "model.layers.29.self_attn.k_norm.weight": "model-00001.safetensors",
376
+ "model.layers.29.self_attn.k_proj.weight": "model-00002.safetensors",
377
+ "model.layers.29.self_attn.o_proj.weight": "model-00001.safetensors",
378
+ "model.layers.29.self_attn.q_norm.weight": "model-00001.safetensors",
379
+ "model.layers.29.self_attn.q_proj.weight": "model-00002.safetensors",
380
+ "model.layers.29.self_attn.v_proj.weight": "model-00002.safetensors",
381
+ "model.layers.29.share_expert.down_proj.weight": "model-00001.safetensors",
382
+ "model.layers.29.share_expert.gate_proj.weight": "model-00002.safetensors",
383
+ "model.layers.29.share_expert.up_proj.weight": "model-00002.safetensors",
384
+ "model.layers.3.input_layernorm.weight": "model-00001.safetensors",
385
+ "model.layers.3.moe.down_proj.weight": "model-00003.safetensors",
386
+ "model.layers.3.moe.gate.weight": "model-00001.safetensors",
387
+ "model.layers.3.moe.gate_proj.weight": "model-00003.safetensors",
388
+ "model.layers.3.moe.router_bias": "model-00001.safetensors",
389
+ "model.layers.3.moe.up_proj.weight": "model-00003.safetensors",
390
+ "model.layers.3.post_attention_layernorm.weight": "model-00001.safetensors",
391
+ "model.layers.3.self_attn.g_proj.weight": "model-00001.safetensors",
392
+ "model.layers.3.self_attn.k_norm.weight": "model-00001.safetensors",
393
+ "model.layers.3.self_attn.k_proj.weight": "model-00001.safetensors",
394
+ "model.layers.3.self_attn.o_proj.weight": "model-00001.safetensors",
395
+ "model.layers.3.self_attn.q_norm.weight": "model-00001.safetensors",
396
+ "model.layers.3.self_attn.q_proj.weight": "model-00001.safetensors",
397
+ "model.layers.3.self_attn.v_proj.weight": "model-00001.safetensors",
398
+ "model.layers.3.share_expert.down_proj.weight": "model-00001.safetensors",
399
+ "model.layers.3.share_expert.gate_proj.weight": "model-00001.safetensors",
400
+ "model.layers.3.share_expert.up_proj.weight": "model-00001.safetensors",
401
+ "model.layers.30.input_layernorm.weight": "model-00001.safetensors",
402
+ "model.layers.30.moe.down_proj.weight": "model-00030.safetensors",
403
+ "model.layers.30.moe.gate.weight": "model-00001.safetensors",
404
+ "model.layers.30.moe.gate_proj.weight": "model-00030.safetensors",
405
+ "model.layers.30.moe.router_bias": "model-00001.safetensors",
406
+ "model.layers.30.moe.up_proj.weight": "model-00030.safetensors",
407
+ "model.layers.30.post_attention_layernorm.weight": "model-00001.safetensors",
408
+ "model.layers.30.self_attn.g_proj.weight": "model-00002.safetensors",
409
+ "model.layers.30.self_attn.k_norm.weight": "model-00001.safetensors",
410
+ "model.layers.30.self_attn.k_proj.weight": "model-00002.safetensors",
411
+ "model.layers.30.self_attn.o_proj.weight": "model-00001.safetensors",
412
+ "model.layers.30.self_attn.q_norm.weight": "model-00001.safetensors",
413
+ "model.layers.30.self_attn.q_proj.weight": "model-00002.safetensors",
414
+ "model.layers.30.self_attn.v_proj.weight": "model-00002.safetensors",
415
+ "model.layers.30.share_expert.down_proj.weight": "model-00001.safetensors",
416
+ "model.layers.30.share_expert.gate_proj.weight": "model-00002.safetensors",
417
+ "model.layers.30.share_expert.up_proj.weight": "model-00002.safetensors",
418
+ "model.layers.31.input_layernorm.weight": "model-00001.safetensors",
419
+ "model.layers.31.moe.down_proj.weight": "model-00031.safetensors",
420
+ "model.layers.31.moe.gate.weight": "model-00001.safetensors",
421
+ "model.layers.31.moe.gate_proj.weight": "model-00031.safetensors",
422
+ "model.layers.31.moe.router_bias": "model-00001.safetensors",
423
+ "model.layers.31.moe.up_proj.weight": "model-00031.safetensors",
424
+ "model.layers.31.post_attention_layernorm.weight": "model-00001.safetensors",
425
+ "model.layers.31.self_attn.g_proj.weight": "model-00002.safetensors",
426
+ "model.layers.31.self_attn.k_norm.weight": "model-00001.safetensors",
427
+ "model.layers.31.self_attn.k_proj.weight": "model-00002.safetensors",
428
+ "model.layers.31.self_attn.o_proj.weight": "model-00001.safetensors",
429
+ "model.layers.31.self_attn.q_norm.weight": "model-00001.safetensors",
430
+ "model.layers.31.self_attn.q_proj.weight": "model-00002.safetensors",
431
+ "model.layers.31.self_attn.v_proj.weight": "model-00002.safetensors",
432
+ "model.layers.31.share_expert.down_proj.weight": "model-00001.safetensors",
433
+ "model.layers.31.share_expert.gate_proj.weight": "model-00002.safetensors",
434
+ "model.layers.31.share_expert.up_proj.weight": "model-00002.safetensors",
435
+ "model.layers.32.input_layernorm.weight": "model-00001.safetensors",
436
+ "model.layers.32.moe.down_proj.weight": "model-00032.safetensors",
437
+ "model.layers.32.moe.gate.weight": "model-00001.safetensors",
438
+ "model.layers.32.moe.gate_proj.weight": "model-00032.safetensors",
439
+ "model.layers.32.moe.router_bias": "model-00001.safetensors",
440
+ "model.layers.32.moe.up_proj.weight": "model-00032.safetensors",
441
+ "model.layers.32.post_attention_layernorm.weight": "model-00001.safetensors",
442
+ "model.layers.32.self_attn.g_proj.weight": "model-00002.safetensors",
443
+ "model.layers.32.self_attn.k_norm.weight": "model-00001.safetensors",
444
+ "model.layers.32.self_attn.k_proj.weight": "model-00002.safetensors",
445
+ "model.layers.32.self_attn.o_proj.weight": "model-00001.safetensors",
446
+ "model.layers.32.self_attn.q_norm.weight": "model-00001.safetensors",
447
+ "model.layers.32.self_attn.q_proj.weight": "model-00002.safetensors",
448
+ "model.layers.32.self_attn.v_proj.weight": "model-00002.safetensors",
449
+ "model.layers.32.share_expert.down_proj.weight": "model-00001.safetensors",
450
+ "model.layers.32.share_expert.gate_proj.weight": "model-00002.safetensors",
451
+ "model.layers.32.share_expert.up_proj.weight": "model-00002.safetensors",
452
+ "model.layers.33.input_layernorm.weight": "model-00001.safetensors",
453
+ "model.layers.33.moe.down_proj.weight": "model-00033.safetensors",
454
+ "model.layers.33.moe.gate.weight": "model-00001.safetensors",
455
+ "model.layers.33.moe.gate_proj.weight": "model-00033.safetensors",
456
+ "model.layers.33.moe.router_bias": "model-00001.safetensors",
457
+ "model.layers.33.moe.up_proj.weight": "model-00033.safetensors",
458
+ "model.layers.33.post_attention_layernorm.weight": "model-00001.safetensors",
459
+ "model.layers.33.self_attn.g_proj.weight": "model-00002.safetensors",
460
+ "model.layers.33.self_attn.k_norm.weight": "model-00001.safetensors",
461
+ "model.layers.33.self_attn.k_proj.weight": "model-00002.safetensors",
462
+ "model.layers.33.self_attn.o_proj.weight": "model-00001.safetensors",
463
+ "model.layers.33.self_attn.q_norm.weight": "model-00001.safetensors",
464
+ "model.layers.33.self_attn.q_proj.weight": "model-00002.safetensors",
465
+ "model.layers.33.self_attn.v_proj.weight": "model-00002.safetensors",
466
+ "model.layers.33.share_expert.down_proj.weight": "model-00001.safetensors",
467
+ "model.layers.33.share_expert.gate_proj.weight": "model-00002.safetensors",
468
+ "model.layers.33.share_expert.up_proj.weight": "model-00002.safetensors",
469
+ "model.layers.34.input_layernorm.weight": "model-00001.safetensors",
470
+ "model.layers.34.moe.down_proj.weight": "model-00034.safetensors",
471
+ "model.layers.34.moe.gate.weight": "model-00001.safetensors",
472
+ "model.layers.34.moe.gate_proj.weight": "model-00034.safetensors",
473
+ "model.layers.34.moe.router_bias": "model-00001.safetensors",
474
+ "model.layers.34.moe.up_proj.weight": "model-00034.safetensors",
475
+ "model.layers.34.post_attention_layernorm.weight": "model-00001.safetensors",
476
+ "model.layers.34.self_attn.g_proj.weight": "model-00002.safetensors",
477
+ "model.layers.34.self_attn.k_norm.weight": "model-00001.safetensors",
478
+ "model.layers.34.self_attn.k_proj.weight": "model-00002.safetensors",
479
+ "model.layers.34.self_attn.o_proj.weight": "model-00001.safetensors",
480
+ "model.layers.34.self_attn.q_norm.weight": "model-00001.safetensors",
481
+ "model.layers.34.self_attn.q_proj.weight": "model-00002.safetensors",
482
+ "model.layers.34.self_attn.v_proj.weight": "model-00002.safetensors",
483
+ "model.layers.34.share_expert.down_proj.weight": "model-00001.safetensors",
484
+ "model.layers.34.share_expert.gate_proj.weight": "model-00002.safetensors",
485
+ "model.layers.34.share_expert.up_proj.weight": "model-00002.safetensors",
486
+ "model.layers.35.input_layernorm.weight": "model-00001.safetensors",
487
+ "model.layers.35.moe.down_proj.weight": "model-00035.safetensors",
488
+ "model.layers.35.moe.gate.weight": "model-00001.safetensors",
489
+ "model.layers.35.moe.gate_proj.weight": "model-00035.safetensors",
490
+ "model.layers.35.moe.router_bias": "model-00001.safetensors",
491
+ "model.layers.35.moe.up_proj.weight": "model-00035.safetensors",
492
+ "model.layers.35.post_attention_layernorm.weight": "model-00001.safetensors",
493
+ "model.layers.35.self_attn.g_proj.weight": "model-00002.safetensors",
494
+ "model.layers.35.self_attn.k_norm.weight": "model-00001.safetensors",
495
+ "model.layers.35.self_attn.k_proj.weight": "model-00002.safetensors",
496
+ "model.layers.35.self_attn.o_proj.weight": "model-00001.safetensors",
497
+ "model.layers.35.self_attn.q_norm.weight": "model-00001.safetensors",
498
+ "model.layers.35.self_attn.q_proj.weight": "model-00002.safetensors",
499
+ "model.layers.35.self_attn.v_proj.weight": "model-00002.safetensors",
500
+ "model.layers.35.share_expert.down_proj.weight": "model-00001.safetensors",
501
+ "model.layers.35.share_expert.gate_proj.weight": "model-00002.safetensors",
502
+ "model.layers.35.share_expert.up_proj.weight": "model-00002.safetensors",
503
+ "model.layers.36.input_layernorm.weight": "model-00001.safetensors",
504
+ "model.layers.36.moe.down_proj.weight": "model-00036.safetensors",
505
+ "model.layers.36.moe.gate.weight": "model-00001.safetensors",
506
+ "model.layers.36.moe.gate_proj.weight": "model-00036.safetensors",
507
+ "model.layers.36.moe.router_bias": "model-00001.safetensors",
508
+ "model.layers.36.moe.up_proj.weight": "model-00036.safetensors",
509
+ "model.layers.36.post_attention_layernorm.weight": "model-00001.safetensors",
510
+ "model.layers.36.self_attn.g_proj.weight": "model-00002.safetensors",
511
+ "model.layers.36.self_attn.k_norm.weight": "model-00001.safetensors",
512
+ "model.layers.36.self_attn.k_proj.weight": "model-00002.safetensors",
513
+ "model.layers.36.self_attn.o_proj.weight": "model-00001.safetensors",
514
+ "model.layers.36.self_attn.q_norm.weight": "model-00001.safetensors",
515
+ "model.layers.36.self_attn.q_proj.weight": "model-00002.safetensors",
516
+ "model.layers.36.self_attn.v_proj.weight": "model-00002.safetensors",
517
+ "model.layers.36.share_expert.down_proj.weight": "model-00001.safetensors",
518
+ "model.layers.36.share_expert.gate_proj.weight": "model-00002.safetensors",
519
+ "model.layers.36.share_expert.up_proj.weight": "model-00002.safetensors",
520
+ "model.layers.37.input_layernorm.weight": "model-00001.safetensors",
521
+ "model.layers.37.moe.down_proj.weight": "model-00037.safetensors",
522
+ "model.layers.37.moe.gate.weight": "model-00001.safetensors",
523
+ "model.layers.37.moe.gate_proj.weight": "model-00037.safetensors",
524
+ "model.layers.37.moe.router_bias": "model-00001.safetensors",
525
+ "model.layers.37.moe.up_proj.weight": "model-00037.safetensors",
526
+ "model.layers.37.post_attention_layernorm.weight": "model-00001.safetensors",
527
+ "model.layers.37.self_attn.g_proj.weight": "model-00002.safetensors",
528
+ "model.layers.37.self_attn.k_norm.weight": "model-00001.safetensors",
529
+ "model.layers.37.self_attn.k_proj.weight": "model-00002.safetensors",
530
+ "model.layers.37.self_attn.o_proj.weight": "model-00001.safetensors",
531
+ "model.layers.37.self_attn.q_norm.weight": "model-00001.safetensors",
532
+ "model.layers.37.self_attn.q_proj.weight": "model-00002.safetensors",
533
+ "model.layers.37.self_attn.v_proj.weight": "model-00002.safetensors",
534
+ "model.layers.37.share_expert.down_proj.weight": "model-00001.safetensors",
535
+ "model.layers.37.share_expert.gate_proj.weight": "model-00002.safetensors",
536
+ "model.layers.37.share_expert.up_proj.weight": "model-00002.safetensors",
537
+ "model.layers.38.input_layernorm.weight": "model-00001.safetensors",
538
+ "model.layers.38.moe.down_proj.weight": "model-00038.safetensors",
539
+ "model.layers.38.moe.gate.weight": "model-00001.safetensors",
540
+ "model.layers.38.moe.gate_proj.weight": "model-00038.safetensors",
541
+ "model.layers.38.moe.router_bias": "model-00001.safetensors",
542
+ "model.layers.38.moe.up_proj.weight": "model-00038.safetensors",
543
+ "model.layers.38.post_attention_layernorm.weight": "model-00001.safetensors",
544
+ "model.layers.38.self_attn.g_proj.weight": "model-00002.safetensors",
545
+ "model.layers.38.self_attn.k_norm.weight": "model-00001.safetensors",
546
+ "model.layers.38.self_attn.k_proj.weight": "model-00002.safetensors",
547
+ "model.layers.38.self_attn.o_proj.weight": "model-00001.safetensors",
548
+ "model.layers.38.self_attn.q_norm.weight": "model-00001.safetensors",
549
+ "model.layers.38.self_attn.q_proj.weight": "model-00002.safetensors",
550
+ "model.layers.38.self_attn.v_proj.weight": "model-00002.safetensors",
551
+ "model.layers.38.share_expert.down_proj.weight": "model-00001.safetensors",
552
+ "model.layers.38.share_expert.gate_proj.weight": "model-00002.safetensors",
553
+ "model.layers.38.share_expert.up_proj.weight": "model-00002.safetensors",
554
+ "model.layers.39.input_layernorm.weight": "model-00001.safetensors",
555
+ "model.layers.39.moe.down_proj.weight": "model-00039.safetensors",
556
+ "model.layers.39.moe.gate.weight": "model-00001.safetensors",
557
+ "model.layers.39.moe.gate_proj.weight": "model-00039.safetensors",
558
+ "model.layers.39.moe.router_bias": "model-00001.safetensors",
559
+ "model.layers.39.moe.up_proj.weight": "model-00039.safetensors",
560
+ "model.layers.39.post_attention_layernorm.weight": "model-00001.safetensors",
561
+ "model.layers.39.self_attn.g_proj.weight": "model-00002.safetensors",
562
+ "model.layers.39.self_attn.k_norm.weight": "model-00001.safetensors",
563
+ "model.layers.39.self_attn.k_proj.weight": "model-00002.safetensors",
564
+ "model.layers.39.self_attn.o_proj.weight": "model-00001.safetensors",
565
+ "model.layers.39.self_attn.q_norm.weight": "model-00001.safetensors",
566
+ "model.layers.39.self_attn.q_proj.weight": "model-00002.safetensors",
567
+ "model.layers.39.self_attn.v_proj.weight": "model-00002.safetensors",
568
+ "model.layers.39.share_expert.down_proj.weight": "model-00001.safetensors",
569
+ "model.layers.39.share_expert.gate_proj.weight": "model-00002.safetensors",
570
+ "model.layers.39.share_expert.up_proj.weight": "model-00002.safetensors",
571
+ "model.layers.4.input_layernorm.weight": "model-00001.safetensors",
572
+ "model.layers.4.moe.down_proj.weight": "model-00004.safetensors",
573
+ "model.layers.4.moe.gate.weight": "model-00001.safetensors",
574
+ "model.layers.4.moe.gate_proj.weight": "model-00004.safetensors",
575
+ "model.layers.4.moe.router_bias": "model-00001.safetensors",
576
+ "model.layers.4.moe.up_proj.weight": "model-00004.safetensors",
577
+ "model.layers.4.post_attention_layernorm.weight": "model-00001.safetensors",
578
+ "model.layers.4.self_attn.g_proj.weight": "model-00001.safetensors",
579
+ "model.layers.4.self_attn.k_norm.weight": "model-00001.safetensors",
580
+ "model.layers.4.self_attn.k_proj.weight": "model-00001.safetensors",
581
+ "model.layers.4.self_attn.o_proj.weight": "model-00001.safetensors",
582
+ "model.layers.4.self_attn.q_norm.weight": "model-00001.safetensors",
583
+ "model.layers.4.self_attn.q_proj.weight": "model-00001.safetensors",
584
+ "model.layers.4.self_attn.v_proj.weight": "model-00001.safetensors",
585
+ "model.layers.4.share_expert.down_proj.weight": "model-00001.safetensors",
586
+ "model.layers.4.share_expert.gate_proj.weight": "model-00001.safetensors",
587
+ "model.layers.4.share_expert.up_proj.weight": "model-00001.safetensors",
588
+ "model.layers.40.input_layernorm.weight": "model-00001.safetensors",
589
+ "model.layers.40.moe.down_proj.weight": "model-00040.safetensors",
590
+ "model.layers.40.moe.gate.weight": "model-00001.safetensors",
591
+ "model.layers.40.moe.gate_proj.weight": "model-00040.safetensors",
592
+ "model.layers.40.moe.router_bias": "model-00001.safetensors",
593
+ "model.layers.40.moe.up_proj.weight": "model-00040.safetensors",
594
+ "model.layers.40.post_attention_layernorm.weight": "model-00001.safetensors",
595
+ "model.layers.40.self_attn.g_proj.weight": "model-00002.safetensors",
596
+ "model.layers.40.self_attn.k_norm.weight": "model-00001.safetensors",
597
+ "model.layers.40.self_attn.k_proj.weight": "model-00002.safetensors",
598
+ "model.layers.40.self_attn.o_proj.weight": "model-00001.safetensors",
599
+ "model.layers.40.self_attn.q_norm.weight": "model-00001.safetensors",
600
+ "model.layers.40.self_attn.q_proj.weight": "model-00002.safetensors",
601
+ "model.layers.40.self_attn.v_proj.weight": "model-00002.safetensors",
602
+ "model.layers.40.share_expert.down_proj.weight": "model-00001.safetensors",
603
+ "model.layers.40.share_expert.gate_proj.weight": "model-00002.safetensors",
604
+ "model.layers.40.share_expert.up_proj.weight": "model-00002.safetensors",
605
+ "model.layers.41.input_layernorm.weight": "model-00001.safetensors",
606
+ "model.layers.41.moe.down_proj.weight": "model-00041.safetensors",
607
+ "model.layers.41.moe.gate.weight": "model-00001.safetensors",
608
+ "model.layers.41.moe.gate_proj.weight": "model-00041.safetensors",
609
+ "model.layers.41.moe.router_bias": "model-00001.safetensors",
610
+ "model.layers.41.moe.up_proj.weight": "model-00041.safetensors",
611
+ "model.layers.41.post_attention_layernorm.weight": "model-00001.safetensors",
612
+ "model.layers.41.self_attn.g_proj.weight": "model-00002.safetensors",
613
+ "model.layers.41.self_attn.k_norm.weight": "model-00001.safetensors",
614
+ "model.layers.41.self_attn.k_proj.weight": "model-00002.safetensors",
615
+ "model.layers.41.self_attn.o_proj.weight": "model-00001.safetensors",
616
+ "model.layers.41.self_attn.q_norm.weight": "model-00001.safetensors",
617
+ "model.layers.41.self_attn.q_proj.weight": "model-00002.safetensors",
618
+ "model.layers.41.self_attn.v_proj.weight": "model-00002.safetensors",
619
+ "model.layers.41.share_expert.down_proj.weight": "model-00001.safetensors",
620
+ "model.layers.41.share_expert.gate_proj.weight": "model-00002.safetensors",
621
+ "model.layers.41.share_expert.up_proj.weight": "model-00002.safetensors",
622
+ "model.layers.42.input_layernorm.weight": "model-00001.safetensors",
623
+ "model.layers.42.moe.down_proj.weight": "model-00042.safetensors",
624
+ "model.layers.42.moe.gate.weight": "model-00001.safetensors",
625
+ "model.layers.42.moe.gate_proj.weight": "model-00042.safetensors",
626
+ "model.layers.42.moe.router_bias": "model-00001.safetensors",
627
+ "model.layers.42.moe.up_proj.weight": "model-00042.safetensors",
628
+ "model.layers.42.post_attention_layernorm.weight": "model-00001.safetensors",
629
+ "model.layers.42.self_attn.g_proj.weight": "model-00002.safetensors",
630
+ "model.layers.42.self_attn.k_norm.weight": "model-00001.safetensors",
631
+ "model.layers.42.self_attn.k_proj.weight": "model-00002.safetensors",
632
+ "model.layers.42.self_attn.o_proj.weight": "model-00001.safetensors",
633
+ "model.layers.42.self_attn.q_norm.weight": "model-00001.safetensors",
634
+ "model.layers.42.self_attn.q_proj.weight": "model-00002.safetensors",
635
+ "model.layers.42.self_attn.v_proj.weight": "model-00002.safetensors",
636
+ "model.layers.42.share_expert.down_proj.weight": "model-00001.safetensors",
637
+ "model.layers.42.share_expert.gate_proj.weight": "model-00002.safetensors",
638
+ "model.layers.42.share_expert.up_proj.weight": "model-00002.safetensors",
639
+ "model.layers.43.input_layernorm.weight": "model-00001.safetensors",
640
+ "model.layers.43.moe.down_proj.weight": "model-00043.safetensors",
641
+ "model.layers.43.moe.gate.weight": "model-00001.safetensors",
642
+ "model.layers.43.moe.gate_proj.weight": "model-00043.safetensors",
643
+ "model.layers.43.moe.router_bias": "model-00001.safetensors",
644
+ "model.layers.43.moe.up_proj.weight": "model-00043.safetensors",
645
+ "model.layers.43.post_attention_layernorm.weight": "model-00001.safetensors",
646
+ "model.layers.43.self_attn.g_proj.weight": "model-00002.safetensors",
647
+ "model.layers.43.self_attn.k_norm.weight": "model-00001.safetensors",
648
+ "model.layers.43.self_attn.k_proj.weight": "model-00002.safetensors",
649
+ "model.layers.43.self_attn.o_proj.weight": "model-00001.safetensors",
650
+ "model.layers.43.self_attn.q_norm.weight": "model-00001.safetensors",
651
+ "model.layers.43.self_attn.q_proj.weight": "model-00002.safetensors",
652
+ "model.layers.43.self_attn.v_proj.weight": "model-00002.safetensors",
653
+ "model.layers.43.share_expert.down_proj.weight": "model-00001.safetensors",
654
+ "model.layers.43.share_expert.gate_proj.weight": "model-00002.safetensors",
655
+ "model.layers.43.share_expert.up_proj.weight": "model-00002.safetensors",
656
+ "model.layers.44.input_layernorm.weight": "model-00001.safetensors",
657
+ "model.layers.44.moe.down_proj.weight": "model-00044.safetensors",
658
+ "model.layers.44.moe.gate.weight": "model-00001.safetensors",
659
+ "model.layers.44.moe.gate_proj.weight": "model-00044.safetensors",
660
+ "model.layers.44.moe.router_bias": "model-00001.safetensors",
661
+ "model.layers.44.moe.up_proj.weight": "model-00044.safetensors",
662
+ "model.layers.44.post_attention_layernorm.weight": "model-00001.safetensors",
663
+ "model.layers.44.self_attn.g_proj.weight": "model-00002.safetensors",
664
+ "model.layers.44.self_attn.k_norm.weight": "model-00001.safetensors",
665
+ "model.layers.44.self_attn.k_proj.weight": "model-00002.safetensors",
666
+ "model.layers.44.self_attn.o_proj.weight": "model-00001.safetensors",
667
+ "model.layers.44.self_attn.q_norm.weight": "model-00001.safetensors",
668
+ "model.layers.44.self_attn.q_proj.weight": "model-00002.safetensors",
669
+ "model.layers.44.self_attn.v_proj.weight": "model-00002.safetensors",
670
+ "model.layers.44.share_expert.down_proj.weight": "model-00001.safetensors",
671
+ "model.layers.44.share_expert.gate_proj.weight": "model-00002.safetensors",
672
+ "model.layers.44.share_expert.up_proj.weight": "model-00002.safetensors",
673
+ "model.layers.45.eh_proj.weight": "model-00002.safetensors",
674
+ "model.layers.45.enorm.weight": "model-00002.safetensors",
675
+ "model.layers.45.hnorm.weight": "model-00002.safetensors",
676
+ "model.layers.45.input_layernorm.weight": "model-00002.safetensors",
677
+ "model.layers.45.mlp.down_proj.weight": "model-00002.safetensors",
678
+ "model.layers.45.mlp.gate_proj.weight": "model-00002.safetensors",
679
+ "model.layers.45.mlp.up_proj.weight": "model-00002.safetensors",
680
+ "model.layers.45.post_attention_layernorm.weight": "model-00002.safetensors",
681
+ "model.layers.45.self_attn.g_proj.weight": "model-00002.safetensors",
682
+ "model.layers.45.self_attn.k_norm.weight": "model-00002.safetensors",
683
+ "model.layers.45.self_attn.k_proj.weight": "model-00002.safetensors",
684
+ "model.layers.45.self_attn.o_proj.weight": "model-00002.safetensors",
685
+ "model.layers.45.self_attn.q_norm.weight": "model-00002.safetensors",
686
+ "model.layers.45.self_attn.q_proj.weight": "model-00002.safetensors",
687
+ "model.layers.45.self_attn.v_proj.weight": "model-00002.safetensors",
688
+ "model.layers.45.transformer.shared_head.norm.weight": "model-00001.safetensors",
689
+ "model.layers.45.transformer.shared_head.output.weight": "model-00002.safetensors",
690
+ "model.layers.46.eh_proj.weight": "model-00002.safetensors",
691
+ "model.layers.46.enorm.weight": "model-00002.safetensors",
692
+ "model.layers.46.hnorm.weight": "model-00002.safetensors",
693
+ "model.layers.46.input_layernorm.weight": "model-00002.safetensors",
694
+ "model.layers.46.mlp.down_proj.weight": "model-00002.safetensors",
695
+ "model.layers.46.mlp.gate_proj.weight": "model-00002.safetensors",
696
+ "model.layers.46.mlp.up_proj.weight": "model-00002.safetensors",
697
+ "model.layers.46.post_attention_layernorm.weight": "model-00002.safetensors",
698
+ "model.layers.46.self_attn.g_proj.weight": "model-00002.safetensors",
699
+ "model.layers.46.self_attn.k_norm.weight": "model-00002.safetensors",
700
+ "model.layers.46.self_attn.k_proj.weight": "model-00002.safetensors",
701
+ "model.layers.46.self_attn.o_proj.weight": "model-00002.safetensors",
702
+ "model.layers.46.self_attn.q_norm.weight": "model-00002.safetensors",
703
+ "model.layers.46.self_attn.q_proj.weight": "model-00002.safetensors",
704
+ "model.layers.46.self_attn.v_proj.weight": "model-00002.safetensors",
705
+ "model.layers.46.transformer.shared_head.norm.weight": "model-00002.safetensors",
706
+ "model.layers.46.transformer.shared_head.output.weight": "model-00002.safetensors",
707
+ "model.layers.47.eh_proj.weight": "model-00002.safetensors",
708
+ "model.layers.47.enorm.weight": "model-00002.safetensors",
709
+ "model.layers.47.hnorm.weight": "model-00002.safetensors",
710
+ "model.layers.47.input_layernorm.weight": "model-00002.safetensors",
711
+ "model.layers.47.mlp.down_proj.weight": "model-00002.safetensors",
712
+ "model.layers.47.mlp.gate_proj.weight": "model-00002.safetensors",
713
+ "model.layers.47.mlp.up_proj.weight": "model-00002.safetensors",
714
+ "model.layers.47.post_attention_layernorm.weight": "model-00002.safetensors",
715
+ "model.layers.47.self_attn.g_proj.weight": "model-00002.safetensors",
716
+ "model.layers.47.self_attn.k_norm.weight": "model-00002.safetensors",
717
+ "model.layers.47.self_attn.k_proj.weight": "model-00002.safetensors",
718
+ "model.layers.47.self_attn.o_proj.weight": "model-00002.safetensors",
719
+ "model.layers.47.self_attn.q_norm.weight": "model-00002.safetensors",
720
+ "model.layers.47.self_attn.q_proj.weight": "model-00002.safetensors",
721
+ "model.layers.47.self_attn.v_proj.weight": "model-00002.safetensors",
722
+ "model.layers.47.transformer.shared_head.norm.weight": "model-00002.safetensors",
723
+ "model.layers.47.transformer.shared_head.output.weight": "model-00002.safetensors",
724
+ "model.layers.5.input_layernorm.weight": "model-00001.safetensors",
725
+ "model.layers.5.moe.down_proj.weight": "model-00005.safetensors",
726
+ "model.layers.5.moe.gate.weight": "model-00001.safetensors",
727
+ "model.layers.5.moe.gate_proj.weight": "model-00005.safetensors",
728
+ "model.layers.5.moe.router_bias": "model-00001.safetensors",
729
+ "model.layers.5.moe.up_proj.weight": "model-00005.safetensors",
730
+ "model.layers.5.post_attention_layernorm.weight": "model-00001.safetensors",
731
+ "model.layers.5.self_attn.g_proj.weight": "model-00001.safetensors",
732
+ "model.layers.5.self_attn.k_norm.weight": "model-00001.safetensors",
733
+ "model.layers.5.self_attn.k_proj.weight": "model-00001.safetensors",
734
+ "model.layers.5.self_attn.o_proj.weight": "model-00001.safetensors",
735
+ "model.layers.5.self_attn.q_norm.weight": "model-00001.safetensors",
736
+ "model.layers.5.self_attn.q_proj.weight": "model-00001.safetensors",
737
+ "model.layers.5.self_attn.v_proj.weight": "model-00001.safetensors",
738
+ "model.layers.5.share_expert.down_proj.weight": "model-00001.safetensors",
739
+ "model.layers.5.share_expert.gate_proj.weight": "model-00001.safetensors",
740
+ "model.layers.5.share_expert.up_proj.weight": "model-00001.safetensors",
741
+ "model.layers.6.input_layernorm.weight": "model-00001.safetensors",
742
+ "model.layers.6.moe.down_proj.weight": "model-00006.safetensors",
743
+ "model.layers.6.moe.gate.weight": "model-00001.safetensors",
744
+ "model.layers.6.moe.gate_proj.weight": "model-00006.safetensors",
745
+ "model.layers.6.moe.router_bias": "model-00001.safetensors",
746
+ "model.layers.6.moe.up_proj.weight": "model-00006.safetensors",
747
+ "model.layers.6.post_attention_layernorm.weight": "model-00001.safetensors",
748
+ "model.layers.6.self_attn.g_proj.weight": "model-00001.safetensors",
749
+ "model.layers.6.self_attn.k_norm.weight": "model-00001.safetensors",
750
+ "model.layers.6.self_attn.k_proj.weight": "model-00001.safetensors",
751
+ "model.layers.6.self_attn.o_proj.weight": "model-00001.safetensors",
752
+ "model.layers.6.self_attn.q_norm.weight": "model-00001.safetensors",
753
+ "model.layers.6.self_attn.q_proj.weight": "model-00001.safetensors",
754
+ "model.layers.6.self_attn.v_proj.weight": "model-00001.safetensors",
755
+ "model.layers.6.share_expert.down_proj.weight": "model-00001.safetensors",
756
+ "model.layers.6.share_expert.gate_proj.weight": "model-00001.safetensors",
757
+ "model.layers.6.share_expert.up_proj.weight": "model-00001.safetensors",
758
+ "model.layers.7.input_layernorm.weight": "model-00001.safetensors",
759
+ "model.layers.7.moe.down_proj.weight": "model-00007.safetensors",
760
+ "model.layers.7.moe.gate.weight": "model-00001.safetensors",
761
+ "model.layers.7.moe.gate_proj.weight": "model-00007.safetensors",
762
+ "model.layers.7.moe.router_bias": "model-00001.safetensors",
763
+ "model.layers.7.moe.up_proj.weight": "model-00007.safetensors",
764
+ "model.layers.7.post_attention_layernorm.weight": "model-00001.safetensors",
765
+ "model.layers.7.self_attn.g_proj.weight": "model-00001.safetensors",
766
+ "model.layers.7.self_attn.k_norm.weight": "model-00001.safetensors",
767
+ "model.layers.7.self_attn.k_proj.weight": "model-00001.safetensors",
768
+ "model.layers.7.self_attn.o_proj.weight": "model-00001.safetensors",
769
+ "model.layers.7.self_attn.q_norm.weight": "model-00001.safetensors",
770
+ "model.layers.7.self_attn.q_proj.weight": "model-00001.safetensors",
771
+ "model.layers.7.self_attn.v_proj.weight": "model-00001.safetensors",
772
+ "model.layers.7.share_expert.down_proj.weight": "model-00001.safetensors",
773
+ "model.layers.7.share_expert.gate_proj.weight": "model-00001.safetensors",
774
+ "model.layers.7.share_expert.up_proj.weight": "model-00001.safetensors",
775
+ "model.layers.8.input_layernorm.weight": "model-00001.safetensors",
776
+ "model.layers.8.moe.down_proj.weight": "model-00008.safetensors",
777
+ "model.layers.8.moe.gate.weight": "model-00001.safetensors",
778
+ "model.layers.8.moe.gate_proj.weight": "model-00008.safetensors",
779
+ "model.layers.8.moe.router_bias": "model-00001.safetensors",
780
+ "model.layers.8.moe.up_proj.weight": "model-00008.safetensors",
781
+ "model.layers.8.post_attention_layernorm.weight": "model-00001.safetensors",
782
+ "model.layers.8.self_attn.g_proj.weight": "model-00001.safetensors",
783
+ "model.layers.8.self_attn.k_norm.weight": "model-00001.safetensors",
784
+ "model.layers.8.self_attn.k_proj.weight": "model-00001.safetensors",
785
+ "model.layers.8.self_attn.o_proj.weight": "model-00001.safetensors",
786
+ "model.layers.8.self_attn.q_norm.weight": "model-00001.safetensors",
787
+ "model.layers.8.self_attn.q_proj.weight": "model-00001.safetensors",
788
+ "model.layers.8.self_attn.v_proj.weight": "model-00001.safetensors",
789
+ "model.layers.8.share_expert.down_proj.weight": "model-00001.safetensors",
790
+ "model.layers.8.share_expert.gate_proj.weight": "model-00001.safetensors",
791
+ "model.layers.8.share_expert.up_proj.weight": "model-00001.safetensors",
792
+ "model.layers.9.input_layernorm.weight": "model-00001.safetensors",
793
+ "model.layers.9.moe.down_proj.weight": "model-00009.safetensors",
794
+ "model.layers.9.moe.gate.weight": "model-00001.safetensors",
795
+ "model.layers.9.moe.gate_proj.weight": "model-00009.safetensors",
796
+ "model.layers.9.moe.router_bias": "model-00001.safetensors",
797
+ "model.layers.9.moe.up_proj.weight": "model-00009.safetensors",
798
+ "model.layers.9.post_attention_layernorm.weight": "model-00001.safetensors",
799
+ "model.layers.9.self_attn.g_proj.weight": "model-00001.safetensors",
800
+ "model.layers.9.self_attn.k_norm.weight": "model-00001.safetensors",
801
+ "model.layers.9.self_attn.k_proj.weight": "model-00001.safetensors",
802
+ "model.layers.9.self_attn.o_proj.weight": "model-00001.safetensors",
803
+ "model.layers.9.self_attn.q_norm.weight": "model-00001.safetensors",
804
+ "model.layers.9.self_attn.q_proj.weight": "model-00001.safetensors",
805
+ "model.layers.9.self_attn.v_proj.weight": "model-00001.safetensors",
806
+ "model.layers.9.share_expert.down_proj.weight": "model-00001.safetensors",
807
+ "model.layers.9.share_expert.gate_proj.weight": "model-00001.safetensors",
808
+ "model.layers.9.share_expert.up_proj.weight": "model-00001.safetensors",
809
+ "model.norm.weight": "model-00002.safetensors"
810
+ }
811
+ }
modeling_step3p5.py ADDED
@@ -0,0 +1,900 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Callable, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from transformers.activations import ACT2FN
22
+ from transformers.cache_utils import Cache, DynamicCache
23
+ from transformers.generation import GenerationMixin
24
+ from transformers.masking_utils import (create_causal_mask,
25
+ create_sliding_window_causal_mask)
26
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
27
+ from transformers.modeling_layers import GradientCheckpointingLayer
28
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
29
+ from transformers.modeling_rope_utils import (ROPE_INIT_FUNCTIONS,
30
+ dynamic_rope_update)
31
+ from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
32
+ PreTrainedModel)
33
+ from transformers.processing_utils import Unpack
34
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
35
+
36
+ from .configuration_step3p5 import Step3p5Config
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ __all__ = ["Step3p5Model", "Step3p5ForCausalLM"]
41
+
42
+ class Step3p5RotaryEmbedding(nn.Module):
43
+
44
+ def __init__(self, config: Step3p5Config, device=None, layer_idx=None):
45
+ super().__init__()
46
+ # BC: "rope_type" was originally "type"
47
+ self.layer_idx = layer_idx
48
+ if config.rope_parameters is not None:
49
+ self.rope_type = config.rope_parameters.get(
50
+ "rope_type", config.rope_parameters.get("type"))
51
+ else:
52
+ self.rope_type = "default"
53
+ self.max_seq_len_cached = config.max_position_embeddings
54
+ self.original_max_seq_len = config.max_position_embeddings
55
+
56
+ partial_rotary_factors = getattr(config, "partial_rotary_factors",
57
+ None)
58
+ if partial_rotary_factors is not None:
59
+ config.partial_rotary_factor = partial_rotary_factors[
60
+ self.layer_idx]
61
+ else:
62
+ config.partial_rotary_factor = 1.0
63
+
64
+ self.rope_theta = config.rope_theta
65
+ if isinstance(config.rope_theta, list):
66
+ self.rope_theta = config.rope_theta.copy()
67
+ config.rope_theta = self.rope_theta[self.layer_idx]
68
+
69
+ self.config = config
70
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
71
+ inv_freq, self.attention_scaling = self.rope_init_fn(
72
+ self.config, device)
73
+
74
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
75
+ self.original_inv_freq = self.inv_freq
76
+ config.rope_theta = self.rope_theta
77
+
78
+ @torch.no_grad()
79
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
80
+ def forward(self, x, position_ids):
81
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
82
+ position_ids.shape[0], -1, 1).to(x.device)
83
+ position_ids_expanded = position_ids[:, None, :].float().to(x.device)
84
+
85
+ device_type = x.device.type if isinstance(
86
+ x.device.type, str) and x.device.type != "mps" else "cpu"
87
+ with torch.autocast(device_type=device_type,
88
+ enabled=False): # Force float32
89
+ freqs = (inv_freq_expanded.float()
90
+ @ position_ids_expanded.float()).transpose(1, 2)
91
+ emb = torch.cat((freqs, freqs), dim=-1)
92
+ cos = emb.cos() * self.attention_scaling
93
+ sin = emb.sin() * self.attention_scaling
94
+
95
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
96
+
97
+
98
+ def rotate_half(x):
99
+ """Rotates half the hidden dims of the input."""
100
+ x1 = x[..., :x.shape[-1] // 2]
101
+ x2 = x[..., x.shape[-1] // 2:]
102
+ return torch.cat((-x2, x1), dim=-1)
103
+
104
+
105
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
106
+ """Applies Rotary Position Embedding to the query and key tensors.
107
+
108
+ Args:
109
+ q (`torch.Tensor`): The query tensor.
110
+ k (`torch.Tensor`): The key tensor.
111
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
112
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
113
+ position_ids (`torch.Tensor`, *optional*):
114
+ Deprecated and unused.
115
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
116
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
117
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
118
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
119
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
120
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
121
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
122
+ Returns:
123
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
124
+ """
125
+ rotary_dim = cos.shape[-1]
126
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
127
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
128
+
129
+ # Apply rotary embeddings on the first half or full tensor
130
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
131
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
132
+
133
+ # Concatenate back to full shape
134
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
135
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
136
+ return q_embed, k_embed
137
+
138
+
139
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
140
+ """
141
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
142
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
143
+ """
144
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
145
+ if n_rep == 1:
146
+ return hidden_states
147
+ hidden_states = hidden_states[:, :,
148
+ None, :, :].expand(batch,
149
+ num_key_value_heads,
150
+ n_rep, slen, head_dim)
151
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
152
+ head_dim)
153
+
154
+
155
+ # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
156
+ def eager_attention_forward(
157
+ module: nn.Module,
158
+ query: torch.Tensor,
159
+ key: torch.Tensor,
160
+ value: torch.Tensor,
161
+ attention_mask: Optional[torch.Tensor],
162
+ scaling: float,
163
+ dropout: float = 0.0,
164
+ **kwargs,
165
+ ):
166
+ key_states = repeat_kv(key, module.num_key_value_groups)
167
+ value_states = repeat_kv(value, module.num_key_value_groups)
168
+ # breakpoint()
169
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
170
+ if attention_mask is not None:
171
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
172
+ attn_weights = attn_weights + causal_mask
173
+
174
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
175
+ attn_weights = nn.functional.dropout(attn_weights,
176
+ p=dropout,
177
+ training=module.training)
178
+ attn_output = torch.matmul(attn_weights, value_states)
179
+ attn_output = attn_output.transpose(1, 2).contiguous()
180
+
181
+ return attn_output, attn_weights
182
+
183
+ @dataclass
184
+ class Step3p5CausalLMOutputWithPast(ModelOutput):
185
+ r"""
186
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
187
+ Language modeling loss (for next-token prediction).
188
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
189
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
190
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
191
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
192
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
193
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
194
+ `past_key_values` input) to speed up sequential decoding.
195
+ """
196
+
197
+ loss: Optional[torch.FloatTensor] = None
198
+ last_hidden_state: Optional[torch.FloatTensor] = None
199
+ logits: torch.FloatTensor = None
200
+ past_key_values: Optional[list[torch.FloatTensor]] = None
201
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
202
+ attentions: Optional[tuple[torch.FloatTensor]] = None
203
+
204
+
205
+ class Step3p5MLP(nn.Module):
206
+
207
+ def __init__(self, config, intermediate_size=None, swiglu_limit=None):
208
+ super().__init__()
209
+ self.config = config
210
+ self.hidden_size = config.hidden_size
211
+ self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
212
+ self.gate_proj = nn.Linear(self.hidden_size,
213
+ self.intermediate_size,
214
+ bias=False)
215
+ self.up_proj = nn.Linear(self.hidden_size,
216
+ self.intermediate_size,
217
+ bias=False)
218
+ self.down_proj = nn.Linear(self.intermediate_size,
219
+ self.hidden_size,
220
+ bias=False)
221
+ self.act_fn = ACT2FN["silu"]
222
+ self.limit = swiglu_limit
223
+
224
+ def forward(self, x):
225
+ up = self.up_proj(x)
226
+ gate = self.act_fn(self.gate_proj(x))
227
+ if self.limit is not None:
228
+ gate = gate.clamp(min=None, max=self.limit)
229
+ up = up.clamp(min=-self.limit, max=self.limit)
230
+
231
+ return self.down_proj(gate * up)
232
+
233
+
234
+ def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
235
+ renormalize: bool):
236
+ gating_output = gating_output.float()
237
+ gate_prob = torch.sigmoid(gating_output)
238
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
239
+ topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
240
+ expert_topk_weight = topk_prob
241
+ if renormalize:
242
+ expert_topk_weight = expert_topk_weight / torch.sum(
243
+ expert_topk_weight, dim=-1, keepdim=True)
244
+ return expert_topk_weight, indices
245
+
246
+
247
+ def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
248
+ renormalize: bool):
249
+ gating_output = gating_output.float()
250
+ gate_prob = torch.softmax(gating_output, dim=-1)
251
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
252
+ topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
253
+ expert_topk_weight = topk_prob
254
+ if renormalize:
255
+ expert_topk_weight = expert_topk_weight / torch.sum(
256
+ expert_topk_weight, dim=-1, keepdim=True)
257
+ return expert_topk_weight, indices.to(torch.int32)
258
+
259
+
260
+ class MoELinear(nn.Module):
261
+
262
+ def __init__(self, num_experts, in_features, out_features):
263
+ super().__init__()
264
+ self.num_experts = num_experts
265
+ self.in_features = in_features
266
+ self.out_features = out_features
267
+ self.weight = nn.Parameter(
268
+ torch.empty(num_experts, out_features, in_features))
269
+
270
+ def forward(self, x, expert_id):
271
+ x = F.linear(x.float(), self.weight[expert_id].float())
272
+ return x
273
+
274
+
275
+ class Step3p5MoEMLP(nn.Module):
276
+
277
+ def __init__(self, config, swiglu_limit=None):
278
+ super().__init__()
279
+ self.num_experts = config.moe_num_experts
280
+ self.top_k = config.moe_top_k
281
+ self.hidden_size = config.hidden_size
282
+ self.moe_intermediate_size = config.moe_intermediate_size
283
+
284
+ self.use_moe_router_bias = config.use_moe_router_bias
285
+ if self.use_moe_router_bias:
286
+ self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
287
+ dtype=torch.float32),
288
+ requires_grad=False)
289
+ self.custom_routing_function = self.router_bias_func
290
+ elif config.moe_router_activation == "sigmoid":
291
+ self.custom_routing_function = sigmoid_routing_function
292
+ else:
293
+ self.custom_routing_function = None
294
+ self.need_fp32_gate = config.need_fp32_gate
295
+ self.routed_scaling_factor = getattr(config,
296
+ "moe_router_scaling_factor", 1.0)
297
+
298
+ # gating
299
+ self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
300
+
301
+ self.act_fn = ACT2FN["silu"]
302
+ self.limit = swiglu_limit
303
+
304
+ self.up_proj = MoELinear(self.num_experts, self.hidden_size,
305
+ self.moe_intermediate_size)
306
+ self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
307
+ self.moe_intermediate_size)
308
+ self.down_proj = MoELinear(self.num_experts,
309
+ self.moe_intermediate_size,
310
+ self.hidden_size)
311
+
312
+ def router_bias_func(self, gating_output: torch.Tensor, topk: int,
313
+ renormalize: bool):
314
+ gate_prob = torch.sigmoid(gating_output.float())
315
+ gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
316
+ _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
317
+ topk_prob = torch.gather(gate_prob, 1, indices)
318
+ expert_topk_weight = topk_prob
319
+ if renormalize:
320
+ expert_topk_weight = expert_topk_weight / (
321
+ torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
322
+ return expert_topk_weight, indices
323
+
324
+ def get_expert_output(self, inputs: torch.Tensor, expert_id):
325
+ #if self.limit is None:
326
+ up = self.up_proj(inputs, expert_id)
327
+ gate = self.act_fn(self.gate_proj(inputs, expert_id))
328
+ if self.limit is not None:
329
+ gate = gate.clamp(min=None, max=self.limit)
330
+ up = up.clamp(min=-self.limit, max=self.limit)
331
+
332
+ return self.down_proj(gate * up, expert_id)
333
+
334
+ def forward(self, hidden_states):
335
+ """ """
336
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
337
+ hidden_states = hidden_states.view(-1, hidden_dim)
338
+ if self.need_fp32_gate:
339
+ router_logits = torch.matmul(hidden_states.to(torch.float32), self.gate.weight.t().to(torch.float32))
340
+ else:
341
+ # router_logits: (batch * sequence_length, n_experts)
342
+ router_logits = self.gate(hidden_states)
343
+
344
+ if self.custom_routing_function:
345
+ routing_weights, selected_experts = self.custom_routing_function(
346
+ router_logits, self.top_k, renormalize=True)
347
+ else:
348
+ routing_weights = F.softmax(router_logits,
349
+ dim=1,
350
+ dtype=torch.float)
351
+ routing_weights, selected_experts = torch.topk(routing_weights,
352
+ self.top_k,
353
+ dim=-1)
354
+
355
+ routing_weights = routing_weights * self.routed_scaling_factor
356
+
357
+ final_hidden_states = torch.zeros(
358
+ (batch_size * sequence_length, hidden_dim),
359
+ dtype=hidden_states.dtype,
360
+ device=hidden_states.device)
361
+
362
+ # One hot encode the selected experts to create an expert mask
363
+ # this will be used to easily index which expert is going to be sollicitated
364
+ expert_mask = torch.nn.functional.one_hot(
365
+ selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
366
+
367
+ # Loop over all available experts in the model and perform the computation on each expert
368
+ for expert_idx in range(self.num_experts):
369
+ idx, top_x = torch.where(expert_mask[expert_idx])
370
+
371
+ # Index the correct hidden states and compute the expert hidden state for
372
+ # the current expert. We need to make sure to multiply the output hidden
373
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
374
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
375
+ current_hidden_states = (
376
+ self.get_expert_output(current_state, expert_idx) *
377
+ routing_weights[top_x, idx, None])
378
+
379
+ # However `index_add_` only support torch tensors for indexing so we'll use
380
+ # the `top_x` tensor here.
381
+ final_hidden_states.index_add_(
382
+ 0, top_x, current_hidden_states.to(hidden_states.dtype))
383
+ final_hidden_states = final_hidden_states.reshape(
384
+ batch_size, sequence_length, hidden_dim)
385
+ return final_hidden_states
386
+
387
+
388
+ class Step3p5RMSNorm(nn.Module):
389
+
390
+ def __init__(
391
+ self,
392
+ hidden_size: int,
393
+ eps: float = 1e-5,
394
+ ) -> None:
395
+ super().__init__()
396
+ self.weight = nn.Parameter(torch.ones(hidden_size))
397
+ self.variance_epsilon = eps
398
+
399
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
400
+ dtype = x.dtype
401
+ x = x.float()
402
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
403
+ normed = x * torch.rsqrt(variance + self.variance_epsilon)
404
+ normed = normed * (self.weight.float() + 1)
405
+ return normed.to(dtype)
406
+ class Step3p5Attention(nn.Module):
407
+
408
+ def __init__(self, config: Step3p5Config, layer_idx):
409
+ super().__init__()
410
+ self.config = config
411
+ self.layer_idx = layer_idx
412
+ self.num_attention_heads = config.num_attention_heads
413
+ self.num_key_value_heads = config.num_attention_groups
414
+
415
+ layer_types = getattr(config, "layer_types", [])
416
+ if layer_types:
417
+ enable_sliding_window = layer_types[
418
+ self.layer_idx] == "sliding_attention"
419
+ else:
420
+ enable_sliding_window = self.layer_idx % 2 == 0
421
+
422
+ if hasattr(config, "yarn_only_types") and layer_types[
423
+ self.layer_idx] not in config.yarn_only_types:
424
+ config.rope_parameters = None
425
+ else:
426
+ config.rope_parameters = getattr(config, "rope_scaling", None)
427
+
428
+ self.sliding_window = config.sliding_window
429
+ if enable_sliding_window:
430
+ self.num_attention_heads = config.attention_other_setting[
431
+ "num_attention_heads"]
432
+ self.num_key_value_heads = config.attention_other_setting[
433
+ "num_attention_groups"]
434
+
435
+ if self.sliding_window is not None and enable_sliding_window:
436
+ self.sliding_window = (self.sliding_window)
437
+ else:
438
+ self.sliding_window = None
439
+ self.head_dim = getattr(config, "head_dim",
440
+ config.hidden_size // self.num_attention_heads)
441
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
442
+
443
+ self.rotary_emb = Step3p5RotaryEmbedding(config, layer_idx=layer_idx)
444
+
445
+ self.q_size = self.num_attention_heads * self.head_dim
446
+ self.kv_size = self.num_key_value_heads * self.head_dim
447
+ self.scaling = self.head_dim**-0.5
448
+
449
+ self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
450
+ self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
451
+ self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
452
+ self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
453
+ self.q_norm = Step3p5RMSNorm(self.head_dim,
454
+ eps=config.rms_norm_eps)
455
+ self.k_norm = Step3p5RMSNorm(self.head_dim,
456
+ eps=config.rms_norm_eps)
457
+
458
+ self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
459
+ if self.use_head_wise_attn_gate:
460
+ self.g_proj = nn.Linear(config.hidden_size,
461
+ self.num_attention_heads,
462
+ bias=False)
463
+
464
+ self.use_rope = True
465
+ use_rope_layers = getattr(config, "use_rope_layers", None)
466
+ if use_rope_layers:
467
+ self.use_rope = use_rope_layers[self.layer_idx]
468
+
469
+ def forward(
470
+ self,
471
+ hidden_states: torch.Tensor,
472
+ attention_mask: Optional[torch.Tensor],
473
+ past_key_value: Optional[Cache] = None,
474
+ cache_position: Optional[torch.LongTensor] = None,
475
+ position_ids: Optional[torch.LongTensor] = None,
476
+ **kwargs: Unpack[FlashAttentionKwargs],
477
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
478
+ Optional[Tuple[torch.Tensor]]]:
479
+ input_shape = hidden_states.shape[:-1]
480
+ hidden_shape = (*input_shape, -1, self.head_dim)
481
+
482
+ query_states = self.q_norm(
483
+ self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
484
+ key_states = self.k_norm(
485
+ self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
486
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
487
+ 1, 2)
488
+ if self.use_head_wise_attn_gate:
489
+ gate_states = self.g_proj(hidden_states)
490
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
491
+
492
+ # cos, sin = position_embeddings
493
+ query_states, key_states = apply_rotary_pos_emb(
494
+ query_states, key_states, cos, sin)
495
+
496
+ # query_states, key_states = apply_rotary_pos_emb(query_norm_states, key_norm_states, cos, sin)
497
+ if past_key_value is not None:
498
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
499
+ cache_kwargs = {
500
+ "sin": sin,
501
+ "cos": cos,
502
+ "cache_position": cache_position
503
+ }
504
+ key_states, value_states = past_key_value.update(
505
+ key_states, value_states, self.layer_idx, cache_kwargs)
506
+
507
+ attention_interface: Callable = eager_attention_forward
508
+ # TODO: considering FP8;
509
+ # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
510
+ # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
511
+ if self.config._attn_implementation != "eager":
512
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
513
+ self.config._attn_implementation]
514
+
515
+ attn_output, attn_weights = attention_interface(
516
+ self,
517
+ query_states,
518
+ key_states,
519
+ value_states,
520
+ attention_mask,
521
+ dropout=0.0 if not self.training else self.attention_dropout,
522
+ scaling=self.scaling,
523
+ sliding_window=self.sliding_window, # main diff with Llama
524
+ **kwargs,
525
+ )
526
+ attn_output = attn_output.reshape(*input_shape, -1)
527
+ if self.use_head_wise_attn_gate:
528
+ output = attn_output.view(
529
+ *attn_output.shape[:-1], self.num_attention_heads,
530
+ self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
531
+ attn_output = output.view(*attn_output.shape)
532
+ attn_output = self.o_proj(attn_output)
533
+
534
+ return attn_output, attn_weights
535
+
536
+
537
+ class Step3p5DecoderLayer(GradientCheckpointingLayer):
538
+
539
+ def __init__(self, config, layer_idx):
540
+ super().__init__()
541
+ self.hidden_size = config.hidden_size
542
+ self.layer_idx = layer_idx
543
+ self.self_attn = Step3p5Attention(config, layer_idx)
544
+ self.attention_type = config.layer_types[layer_idx]
545
+
546
+ moe_layers_enum = getattr(config, "moe_layers_enum", None)
547
+ if moe_layers_enum is not None:
548
+ moe_layers_idx = [
549
+ int(i) for i in moe_layers_enum.strip().split(',')
550
+ ]
551
+ else:
552
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
553
+ self.is_moe_layer = layer_idx in moe_layers_idx
554
+ self.use_moe = False
555
+
556
+ if config.swiglu_limits_shared and config.swiglu_limits_shared[
557
+ layer_idx] is not None and config.swiglu_limits_shared[
558
+ layer_idx] != 0:
559
+ swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
560
+ else:
561
+ swiglu_limit_shared = None
562
+ if config.swiglu_limits and config.swiglu_limits[
563
+ layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
564
+ swiglu_limit = config.swiglu_limits[layer_idx]
565
+ else:
566
+ swiglu_limit = None
567
+ if self.is_moe_layer:
568
+ self.moe = Step3p5MoEMLP(config, swiglu_limit=swiglu_limit) #
569
+ self.share_expert = Step3p5MLP(
570
+ config,
571
+ intermediate_size=config.share_expert_dim,
572
+ swiglu_limit=swiglu_limit_shared)
573
+ self.use_moe = True
574
+ else:
575
+ self.mlp = Step3p5MLP(config,
576
+ intermediate_size=config.intermediate_size,
577
+ swiglu_limit=swiglu_limit_shared)
578
+
579
+ self.input_layernorm = Step3p5RMSNorm(
580
+ config.hidden_size,
581
+ eps=config.rms_norm_eps)
582
+ self.post_attention_layernorm = Step3p5RMSNorm(
583
+ config.hidden_size,
584
+ eps=config.rms_norm_eps)
585
+
586
+ def forward(
587
+ self,
588
+ hidden_states: torch.Tensor,
589
+ attention_mask: Optional[torch.Tensor] = None,
590
+ position_ids: Optional[torch.LongTensor] = None,
591
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
592
+ cache_position: Optional[torch.LongTensor] = None,
593
+ **kwargs: Unpack[FlashAttentionKwargs],
594
+ ) -> torch.FloatTensor:
595
+ residual = hidden_states
596
+ hidden_states = self.input_layernorm(hidden_states)
597
+ hidden_states, _ = self.self_attn(
598
+ hidden_states=hidden_states,
599
+ attention_mask=attention_mask,
600
+ position_ids=position_ids,
601
+ past_key_value=past_key_value,
602
+ cache_position=cache_position,
603
+ **kwargs,
604
+ )
605
+ hidden_states = residual + hidden_states
606
+
607
+ # Fully Connected
608
+ residual = hidden_states
609
+ hidden_states = self.post_attention_layernorm(hidden_states)
610
+ if self.use_moe:
611
+ share_output = self.share_expert(hidden_states)
612
+ moe_output = self.moe(hidden_states)
613
+ ffn_output = moe_output + share_output
614
+ else:
615
+ ffn_output = self.mlp(hidden_states)
616
+ if isinstance(ffn_output, tuple):
617
+ hidden_states, _ = ffn_output
618
+ else:
619
+ hidden_states = ffn_output
620
+
621
+ hidden_states = residual + hidden_states
622
+ return hidden_states
623
+
624
+
625
+ class Step3p5PreTrainedModel(PreTrainedModel):
626
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
627
+ # can load the config instead of failing with a NoneType error.
628
+ config_class = Step3p5Config
629
+ supports_gradient_checkpointing = True
630
+ _skip_keys_device_placement = ["past_key_values"]
631
+ _keys_to_ignore_on_load_unexpected = [
632
+ r"model\.layers\.45\.*",
633
+ r"model\.layers\.46\.*",
634
+ r"model\.layers\.47\.*"
635
+ ]
636
+ _supports_flash_attn = False
637
+ _supports_sdpa = True
638
+ _supports_flex_attn = True
639
+ _supports_static_cache = True
640
+ _supports_attention_backend = True
641
+
642
+
643
+ class Step3p5Model(Step3p5PreTrainedModel, GenerationMixin):
644
+ _no_split_modules = ["Step3p5DecoderLayer"]
645
+ base_model_prefix = "model"
646
+ _tied_weights_keys = ["lm_head.weight"]
647
+ config: Step3p5Config
648
+ def __init__(self, config: Step3p5Config):
649
+ super().__init__(config)
650
+ self.padding_idx = config.pad_token_id
651
+ self.vocab_size = config.vocab_size
652
+
653
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
654
+ self.padding_idx)
655
+ self.layers = nn.ModuleList([
656
+ Step3p5DecoderLayer(config, layer_idx)
657
+ for layer_idx in range(config.num_hidden_layers)
658
+ ])
659
+ self.norm = Step3p5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
660
+ self.gradient_checkpointing = False
661
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
662
+
663
+ # Initialize weights and apply final processing
664
+ self.post_init()
665
+
666
+ def get_input_embeddings(self, input_ids):
667
+ return self.embed_tokens(input_ids)
668
+
669
+ @can_return_tuple
670
+ def forward(
671
+ self,
672
+ input_ids: torch.LongTensor = None,
673
+ attention_mask: Optional[torch.Tensor] = None,
674
+ position_ids: Optional[torch.LongTensor] = None,
675
+ past_key_values: Optional[Cache] = None,
676
+ inputs_embeds: Optional[torch.FloatTensor] = None,
677
+ use_cache: Optional[bool] = None,
678
+ output_attentions: Optional[bool] = None,
679
+ output_hidden_states: Optional[bool] = None,
680
+ return_dict: Optional[bool] = None,
681
+ cache_position: Optional[torch.LongTensor] = None,
682
+ **kwargs: Unpack[TransformersKwargs],
683
+ ) -> Union[tuple, BaseModelOutputWithPast]:
684
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
685
+ output_hidden_states = (output_hidden_states
686
+ if output_hidden_states is not None else
687
+ self.config.output_hidden_states)
688
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
689
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
690
+ if (input_ids is None) ^ (inputs_embeds is not None):
691
+ raise ValueError(
692
+ "You must specify exactly one of input_ids or inputs_embeds")
693
+
694
+ if self.gradient_checkpointing and self.training and use_cache:
695
+ logger.warning_once(
696
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
697
+ )
698
+ use_cache = False
699
+
700
+ if inputs_embeds is None:
701
+ inputs_embeds = self.embed_tokens(
702
+ input_ids.to(self.embed_tokens.weight.device))
703
+
704
+ if use_cache and past_key_values is None:
705
+ past_key_values = DynamicCache()
706
+
707
+ if cache_position is None:
708
+ past_seen_tokens = past_key_values.get_seq_length(
709
+ ) if past_key_values is not None else 0
710
+ cache_position = torch.arange(past_seen_tokens,
711
+ past_seen_tokens +
712
+ inputs_embeds.shape[1],
713
+ device=inputs_embeds.device)
714
+
715
+ if position_ids is None:
716
+ position_ids = cache_position.unsqueeze(0)
717
+
718
+ hidden_states = inputs_embeds
719
+
720
+ # It may already have been prepared by e.g. `generate`
721
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
722
+ # Prepare mask arguments
723
+ mask_kwargs = {
724
+ "config": self.config,
725
+ "input_embeds": inputs_embeds,
726
+ "attention_mask": attention_mask,
727
+ "cache_position": cache_position,
728
+ "past_key_values": past_key_values,
729
+ "position_ids": position_ids,
730
+ }
731
+ # Create the masks
732
+ causal_mask_mapping = {
733
+ "full_attention": create_causal_mask(**mask_kwargs),
734
+ }
735
+
736
+ # The sliding window alternating layers are not always activated depending on the config
737
+ if self.has_sliding_layers:
738
+ causal_mask_mapping[
739
+ "sliding_attention"] = create_sliding_window_causal_mask(
740
+ **mask_kwargs)
741
+
742
+ # # create position embeddings to be shared across the decoder layers
743
+ # decoder layers
744
+ all_hidden_states = () if output_hidden_states else None
745
+ all_self_attns = () if output_attentions else None
746
+ for decoder_layer in self.layers[:self.config.num_hidden_layers]:
747
+ if output_hidden_states:
748
+ all_hidden_states += (hidden_states, )
749
+
750
+ layer_outputs = decoder_layer(
751
+ hidden_states,
752
+ attention_mask=causal_mask_mapping[
753
+ decoder_layer.attention_type],
754
+ position_ids=position_ids,
755
+ past_key_value=past_key_values,
756
+ output_attentions=output_attentions,
757
+ use_cache=use_cache,
758
+ cache_position=cache_position,
759
+ **kwargs,
760
+ )
761
+
762
+ hidden_states = layer_outputs
763
+
764
+ hidden_states = self.norm(hidden_states)
765
+
766
+ return BaseModelOutputWithPast(
767
+ last_hidden_state=hidden_states,
768
+ past_key_values=past_key_values if use_cache else None,
769
+ hidden_states=all_hidden_states,
770
+ attentions=all_self_attns,
771
+ )
772
+
773
+
774
+ class Step3p5ForCausalLM(Step3p5PreTrainedModel, GenerationMixin):
775
+ _tied_weights_keys = ["lm_head.weight"]
776
+ config: Step3p5Config
777
+
778
+ def __init__(self, config: Step3p5Config):
779
+ super().__init__(config)
780
+ self.model = Step3p5Model(config)
781
+ self.lm_head = nn.Linear(config.hidden_size,
782
+ config.vocab_size,
783
+ bias=False)
784
+
785
+ self.post_init()
786
+
787
+ def get_input_embeddings(self):
788
+ return self.model.get_input_embeddings()
789
+
790
+ def set_input_embeddings(self, value):
791
+ self.model.set_input_embeddings(value)
792
+
793
+ def get_output_embeddings(self):
794
+ return self.model.get_output_embeddings()
795
+
796
+ def set_output_embeddings(self, new_embeddings):
797
+ self.model.set_output_embeddings(new_embeddings)
798
+
799
+ def set_decoder(self, decoder):
800
+ self.model.set_decoder(decoder)
801
+
802
+ def get_decoder(self):
803
+ return self.model.get_decoder()
804
+
805
+ def forward(
806
+ self,
807
+ input_ids: torch.LongTensor = None,
808
+ num_patches=None,
809
+ patch_pixel_values=None,
810
+ patch_newline_mask=None,
811
+ attention_mask: Optional[torch.Tensor] = None,
812
+ position_ids: Optional[torch.LongTensor] = None,
813
+ past_key_values: Optional[Cache] = None,
814
+ inputs_embeds: Optional[torch.FloatTensor] = None,
815
+ labels: Optional[torch.LongTensor] = None,
816
+ use_cache: Optional[bool] = None,
817
+ output_attentions: Optional[bool] = None,
818
+ output_hidden_states: Optional[bool] = None,
819
+ return_dict: Optional[bool] = None,
820
+ cache_position: Optional[torch.LongTensor] = None,
821
+ **kwargs: Unpack[TransformersKwargs],
822
+ ) -> Union[tuple, Step3p5CausalLMOutputWithPast]:
823
+ r"""
824
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
825
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
826
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
827
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
828
+ Example:
829
+ ```python
830
+ >>> from transformers import AutoTokenizer, Llama4ForCausalLM
831
+ >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
832
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
833
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
834
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
835
+ >>> # Generate
836
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
837
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
838
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
839
+ ```"""
840
+
841
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
842
+ output_hidden_states = (output_hidden_states
843
+ if output_hidden_states is not None else
844
+ self.config.output_hidden_states)
845
+ # breakpoint()
846
+ outputs = self.model(
847
+ input_ids=input_ids,
848
+ num_patches=num_patches,
849
+ patch_pixel_values=patch_pixel_values,
850
+ patch_newline_mask=patch_newline_mask,
851
+ position_ids=position_ids,
852
+ attention_mask=attention_mask,
853
+ past_key_values=past_key_values,
854
+ inputs_embeds=inputs_embeds,
855
+ use_cache=use_cache,
856
+ output_attentions=output_attentions,
857
+ output_hidden_states=output_hidden_states,
858
+ return_dict=return_dict,
859
+ cache_position=cache_position,
860
+ **kwargs,
861
+ )
862
+ hidden_states = outputs.last_hidden_state
863
+ logits = self.lm_head(hidden_states)
864
+
865
+ return Step3p5CausalLMOutputWithPast(logits=logits, )
866
+
867
+ def prepare_inputs_for_generation(
868
+ self,
869
+ input_ids,
870
+ past_key_values=None,
871
+ inputs_embeds=None,
872
+ pixel_values=None,
873
+ attention_mask=None,
874
+ cache_position=None,
875
+ logits_to_keep=None,
876
+ **kwargs,
877
+ ):
878
+
879
+ model_inputs = super().prepare_inputs_for_generation(
880
+ input_ids,
881
+ past_key_values=past_key_values,
882
+ inputs_embeds=inputs_embeds,
883
+ attention_mask=attention_mask,
884
+ cache_position=cache_position,
885
+ logits_to_keep=logits_to_keep,
886
+ **kwargs,
887
+ )
888
+
889
+ if cache_position[0] == 0:
890
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
891
+ # Otherwise we need pixel values to be passed to model
892
+ model_inputs["pixel_values"] = pixel_values
893
+
894
+ return model_inputs
895
+
896
+ def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
897
+ if key.startswith("language_model."):
898
+ return key[len("language_model."):], True
899
+
900
+ return key, False