awni committed
Commit 35eb7ff · verified · 1 Parent(s): 59ee79a

Add files using upload-large-folder tool

Files changed (50)
  1. README.md +36 -0
  2. chat_template.jinja +80 -0
  3. config.json +660 -0
  4. configuration_step3p5.py +59 -0
  5. model-00001-of-00043.safetensors +3 -0
  6. model-00002-of-00043.safetensors +3 -0
  7. model-00003-of-00043.safetensors +3 -0
  8. model-00004-of-00043.safetensors +3 -0
  9. model-00005-of-00043.safetensors +3 -0
  10. model-00006-of-00043.safetensors +3 -0
  11. model-00007-of-00043.safetensors +3 -0
  12. model-00008-of-00043.safetensors +3 -0
  13. model-00009-of-00043.safetensors +3 -0
  14. model-00010-of-00043.safetensors +3 -0
  15. model-00011-of-00043.safetensors +3 -0
  16. model-00012-of-00043.safetensors +3 -0
  17. model-00013-of-00043.safetensors +3 -0
  18. model-00014-of-00043.safetensors +3 -0
  19. model-00016-of-00043.safetensors +3 -0
  20. model-00017-of-00043.safetensors +3 -0
  21. model-00018-of-00043.safetensors +3 -0
  22. model-00019-of-00043.safetensors +3 -0
  23. model-00020-of-00043.safetensors +3 -0
  24. model-00021-of-00043.safetensors +3 -0
  25. model-00022-of-00043.safetensors +3 -0
  26. model-00023-of-00043.safetensors +3 -0
  27. model-00024-of-00043.safetensors +3 -0
  28. model-00025-of-00043.safetensors +3 -0
  29. model-00026-of-00043.safetensors +3 -0
  30. model-00027-of-00043.safetensors +3 -0
  31. model-00028-of-00043.safetensors +3 -0
  32. model-00029-of-00043.safetensors +3 -0
  33. model-00030-of-00043.safetensors +3 -0
  34. model-00031-of-00043.safetensors +3 -0
  35. model-00032-of-00043.safetensors +3 -0
  36. model-00033-of-00043.safetensors +3 -0
  37. model-00034-of-00043.safetensors +3 -0
  38. model-00035-of-00043.safetensors +3 -0
  39. model-00036-of-00043.safetensors +3 -0
  40. model-00037-of-00043.safetensors +3 -0
  41. model-00038-of-00043.safetensors +3 -0
  42. model-00039-of-00043.safetensors +3 -0
  43. model-00040-of-00043.safetensors +3 -0
  44. model-00041-of-00043.safetensors +3 -0
  45. model-00042-of-00043.safetensors +3 -0
  46. model-00043-of-00043.safetensors +3 -0
  47. model.safetensors.index.json +0 -0
  48. modeling_step3p5.py +900 -0
  49. tokenizer.json +0 -0
  50. tokenizer_config.json +17 -0
README.md ADDED
@@ -0,0 +1,36 @@
+ ---
+ license: apache-2.0
+ base_model: stepfun-ai/Step-3.5-Flash
+ library_name: mlx
+ pipeline_tag: text-generation
+ tags:
+ - mlx
+ ---
+
+ # mlx-community/Step-3.5-Flash-8bit
+
+ This model [mlx-community/Step-3.5-Flash-8bit](https://huggingface.co/mlx-community/Step-3.5-Flash-8bit) was
+ converted to MLX format from [stepfun-ai/Step-3.5-Flash](https://huggingface.co/stepfun-ai/Step-3.5-Flash)
+ using mlx-lm version **0.30.6**.
+
+ ## Use with mlx
+
+ ```bash
+ pip install mlx-lm
+ ```
+
+ ```python
+ from mlx_lm import load, generate
+
+ model, tokenizer = load("mlx-community/Step-3.5-Flash-8bit")
+
+ prompt = "hello"
+
+ if tokenizer.chat_template is not None:
+     messages = [{"role": "user", "content": prompt}]
+     prompt = tokenizer.apply_chat_template(
+         messages, add_generation_prompt=True, return_dict=False,
+     )
+
+ response = generate(model, tokenizer, prompt=prompt, verbose=True)
+ ```
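For interactive use, mlx-lm also provides a streaming generator. A minimal sketch along the same lines as the README example (the `max_tokens` value is arbitrary):

```python
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Step-3.5-Flash-8bit")

messages = [{"role": "user", "content": "hello"}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# Chunks arrive as they are decoded; each carries the newly generated text.
for chunk in stream_generate(model, tokenizer, prompt, max_tokens=256):
    print(chunk.text, end="", flush=True)
print()
```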
chat_template.jinja ADDED
@@ -0,0 +1,80 @@
+ {% macro render_content(content) %}{% if content is none %}{{- '' }}{% elif content is string %}{{- content }}{% elif content is mapping %}{{- content['value'] if 'value' in content else content['text'] }}{% elif content is iterable %}{% for item in content %}{% if item.type == 'text' %}{{- item['value'] if 'value' in item else item['text'] }}{% elif item.type == 'image' %}<im_patch>{% endif %}{% endfor %}{% endif %}{% endmacro %}
+ {{bos_token}}{%- if tools %}
+     {{- '<|im_start|>system\n' }}
+     {%- if messages[0].role == 'system' %}
+         {{- render_content(messages[0].content) + '\n\n' }}
+     {%- endif %}
+     {{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n<tools>" }}
+     {%- for tool in tools %}
+         {{- "\n" }}
+         {{- tool | tojson(ensure_ascii=False) }}
+     {%- endfor %}
+     {{- "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...>\n...\n</function> block must be nested within <tool_call>\n...\n</tool_call> XML tags\n- Required parameters MUST be specified\n</IMPORTANT><|im_end|>\n" }}
+ {%- else %}
+     {%- if messages[0].role == 'system' %}
+         {{- '<|im_start|>system\n' + render_content(messages[0].content) + '<|im_end|>\n' }}
+     {%- endif %}
+ {%- endif %}
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+ {%- for message in messages[::-1] %}
+     {%- set index = (messages|length - 1) - loop.index0 %}
+     {%- if ns.multi_step_tool and message.role == "user" and render_content(message.content) is string and not(render_content(message.content).startswith('<tool_response>') and render_content(message.content).endswith('</tool_response>')) %}
+         {%- set ns.multi_step_tool = false %}
+         {%- set ns.last_query_index = index %}
+     {%- endif %}
+ {%- endfor %}
+ {%- for message in messages %}
+     {%- set content = render_content(message.content) %}
+     {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+         {%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
+         {{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
+     {%- elif message.role == "assistant" %}
+         {%- if message.reasoning_content is string %}
+             {%- set reasoning_content = render_content(message.reasoning_content) %}
+         {%- else %}
+             {%- if '</think>' in content %}
+                 {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+                 {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+             {%- else %}
+                 {%- set reasoning_content = '' %}
+             {%- endif %}
+         {%- endif %}
+         {%- if loop.index0 > ns.last_query_index %}
+             {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
+         {%- else %}
+             {{- '<|im_start|>' + message.role + '\n' + content }}
+         {%- endif %}
+         {%- if message.tool_calls %}
+             {%- for tool_call in message.tool_calls %}
+                 {%- if tool_call.function is defined %}
+                     {%- set tool_call = tool_call.function %}
+                 {%- endif %}
+                 {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
+                 {%- if tool_call.arguments is defined %}
+                     {%- set arguments = tool_call.arguments %}
+                     {%- for args_name, args_value in arguments|items %}
+                         {{- '<parameter=' + args_name + '>\n' }}
+                         {%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
+                         {{- args_value }}
+                         {{- '\n</parameter>\n' }}
+                     {%- endfor %}
+                 {%- endif %}
+                 {{- '</function>\n</tool_call>' }}
+             {%- endfor %}
+         {%- endif %}
+         {{- '<|im_end|>\n' }}
+     {%- elif message.role == "tool" %}
+         {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+             {{- '<|im_start|>tool_response\n' }}
+         {%- endif %}
+         {{- '<tool_response>' }}
+         {{- content }}
+         {{- '</tool_response>' }}
+         {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+             {{- '<|im_end|>\n' }}
+         {%- endif %}
+     {%- endif %}
+ {%- endfor %}
+ {%- if add_generation_prompt %}
+     {{- '<|im_start|>assistant\n<think>\n' }}
+ {%- endif %}
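The template above wraps tool schemas in a `<tools>` block and expects calls in `<tool_call>`/`<function=...>` form. A sketch of how a rendered prompt can be inspected (the `get_weather` schema is a made-up example):

```python
from mlx_lm import load

model, tokenizer = load("mlx-community/Step-3.5-Flash-8bit")

get_weather = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

messages = [{"role": "user", "content": "What's the weather in Paris?"}]
prompt = tokenizer.apply_chat_template(
    messages, tools=[get_weather], add_generation_prompt=True, tokenize=False
)
print(prompt)  # system block embeds the schema between <tools> ... </tools>
```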
config.json ADDED
@@ -0,0 +1,660 @@
+ {
+   "architectures": [
+     "Step3p5ForCausalLM"
+   ],
+   "att_impl_type": "GQA",
+   "attention_other_setting": {
+     "attention_type": "sliding_attention",
+     "num_attention_heads": 96,
+     "num_attention_groups": 8,
+     "head_dim": 128,
+     "true_head_dim": 128
+   },
+   "auto_map": {
+     "AutoConfig": "configuration_step3p5.Step3p5Config",
+     "AutoModelForCausalLM": "modeling_step3p5.Step3p5ForCausalLM"
+   },
+   "bos_token_id": 0,
+   "eos_token_id": [
+     1,
+     2,
+     128007
+   ],
+   "head_dim": 128,
+   "hidden_size": 4096,
+   "intermediate_size": 11264,
+   "layer_types": [
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "full_attention",
+     "sliding_attention",
+     "sliding_attention",
+     "sliding_attention"
+   ],
+   "max_position_embeddings": 262144,
+   "max_seq_len": 262144,
+   "model_type": "step3p5",
+   "moe_every_n_layer": 1,
+   "moe_intermediate_size": 1280,
+   "moe_layer_offset": 0,
+   "moe_layers_enum": "3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44",
+   "moe_num_experts": 288,
+   "moe_router_activation": "sigmoid",
+   "moe_router_scaling_factor": 3.0,
+   "moe_top_k": 8,
+   "need_fp32_gate": true,
+   "norm_expert_weight": true,
+   "num_attention_groups": 8,
+   "num_attention_heads": 64,
+   "num_hidden_layers": 45,
+   "num_nextn_predict_layers": 3,
+   "partial_rotary_factors": [
+     0.5,
+     1.0,
+     1.0,
+     1.0,
+     0.5,
+     1.0,
+     1.0,
+     1.0,
+     0.5,
+     1.0,
+     1.0,
+     1.0,
+     0.5,
+     1.0,
+     1.0,
+     1.0,
+     0.5,
+     1.0,
+     1.0,
+     1.0,
+     0.5,
+     1.0,
+     1.0,
+     1.0,
+     0.5,
+     1.0,
+     1.0,
+     1.0,
+     0.5,
+     1.0,
+     1.0,
+     1.0,
+     0.5,
+     1.0,
+     1.0,
+     1.0,
+     0.5,
+     1.0,
+     1.0,
+     1.0,
+     0.5,
+     1.0,
+     1.0,
+     1.0,
+     0.5,
+     1.0,
+     1.0,
+     1.0
+   ],
+   "quantization": {
+     "group_size": 64,
+     "bits": 8,
+     "mode": "affine",
+     "model.layers.3.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.4.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.5.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.6.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.7.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.8.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.9.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.10.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.11.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.12.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.13.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.14.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.15.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.16.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.17.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.18.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.19.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.20.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.21.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.22.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.23.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.24.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.25.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.26.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.27.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.28.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.29.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.30.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.31.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.32.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.33.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.34.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.35.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.36.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.37.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.38.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.39.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.40.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.41.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.42.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.43.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.44.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     }
+   },
+   "quantization_config": {
+     "group_size": 64,
+     "bits": 8,
+     "mode": "affine",
+     "model.layers.3.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.4.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.5.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.6.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.7.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.8.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.9.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.10.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.11.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.12.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.13.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.14.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.15.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.16.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.17.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.18.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.19.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.20.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.21.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.22.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.23.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.24.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.25.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.26.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.27.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.28.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.29.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.30.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.31.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.32.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.33.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.34.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.35.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.36.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.37.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.38.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.39.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.40.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.41.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.42.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.43.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     },
+     "model.layers.44.mlp.gate.gate": {
+       "group_size": 64,
+       "bits": 8
+     }
+   },
+   "rope_scaling": {
+     "rope_type": "llama3",
+     "factor": 2.0,
+     "original_max_position_embeddings": 131072,
+     "low_freq_factor": 1.0,
+     "high_freq_factor": 32.0
+   },
+   "rope_theta": [
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0,
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0,
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0,
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0,
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0,
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0,
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0,
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0,
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0,
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0,
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0,
+     5000000.0,
+     10000.0,
+     10000.0,
+     10000.0
+   ],
+   "share_expert_dim": 1280,
+   "sink": false,
+   "sliding_window": 512,
+   "swiglu_limits": [
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     7,
+     7,
+     0.0,
+     0.0,
+     0.0
+   ],
+   "swiglu_limits_shared": [
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     0.0,
+     16,
+     0.0,
+     0.0,
+     0.0
+   ],
+   "torch_dtype": "bfloat16",
+   "use_head_wise_attn_gate": true,
+   "use_moe": true,
+   "use_moe_router_bias": true,
+   "use_qk_norm": true,
+   "use_rope_layers": [],
+   "vocab_size": 128896,
+   "yarn_only_types": [
+     "full_attention"
+   ],
+   "zero_centered": true
+ }
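A few of these fields interact: `layer_types`, `partial_rotary_factors`, and `rope_theta` are per-layer lists (every fourth layer uses full attention with theta 5e6 and a 0.5 rotary factor), while `moe_layers_enum` marks layers 3-44 as MoE. A small sketch that cross-checks them from a local copy of this file (the path is illustrative):

```python
import json

with open("config.json") as f:
    cfg = json.load(f)

moe_layers = {int(i) for i in cfg["moe_layers_enum"].split(",")}
full_attn = [i for i, t in enumerate(cfg["layer_types"]) if t == "full_attention"]

print(sorted(moe_layers))  # layers 3..44 use the 288-expert MoE block
print(full_attn)           # every 4th layer: 0, 4, 8, ...
print(cfg["moe_top_k"], "experts active per token out of", cfg["moe_num_experts"])
```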
configuration_step3p5.py ADDED
@@ -0,0 +1,59 @@
+ from typing import Any, Optional, Union
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+
+ class Step3p5Config(PretrainedConfig):
+     model_type = "step3p5"
+     architectures = ["Step3p5ForCausalLM"]
+
+     def __init__(
+         self,
+         hidden_size: int = 4096,
+         intermediate_size: int = 11264,
+         num_attention_heads: int = 64,
+         num_attention_groups: int = 8,
+         num_hidden_layers: int = 45,
+         max_seq_len: int = 128000,
+         vocab_size: int = 128815,
+         rms_norm_eps: float = 1e-5,
+         moe_intermediate_size: int = 1280,
+         moe_num_experts: int = 288,
+         moe_top_k: int = 8,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[dict[str, Any]] = None,
+         max_position_embeddings: int = 128000,
+         share_expert_dims: int = 1280,
+         head_dim: int = 128,
+         norm_expert_weight: bool = True,
+         layer_types: list[str] = None,
+         sliding_window: Optional[int] = None,
+         moe_layers_enum: tuple[int] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+                                        15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
+                                        25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
+                                        35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
+         **kwargs,
+     ) -> None:
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_attention_heads = num_attention_heads
+         self.num_attention_groups = num_attention_groups
+         self.num_hidden_layers = num_hidden_layers
+         self.max_seq_len = max_seq_len
+         self.vocab_size = vocab_size
+         self.rms_norm_eps = rms_norm_eps
+         self.moe_intermediate_size = moe_intermediate_size
+         self.moe_num_experts = moe_num_experts
+         self.moe_top_k = moe_top_k
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.max_position_embeddings = max_position_embeddings
+         self.share_expert_dim = share_expert_dims
+         self.head_dim = head_dim
+         self.norm_expert_weight = norm_expert_weight
+         self.moe_layers_enum = moe_layers_enum
+         self.layer_types = layer_types
+         self.sliding_window = sliding_window
+         super().__init__(**kwargs)
+
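A usage sketch: the defaults above mirror the checkpoint (aside from fields such as `max_seq_len` and `vocab_size`, which config.json overrides at load time), so the class can be instantiated directly for experimentation; unknown keyword arguments flow through to `PretrainedConfig`:

```python
config = Step3p5Config(sliding_window=512)
print(config.model_type)         # "step3p5"
print(config.num_hidden_layers)  # 45
print(len(config.moe_layers_enum), "MoE layers by default")  # 42
```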
model-00001-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d770a5a89e4252738c9a68a83ddcb5dbc45f300bc5a27913f60d3681ff968278
+ size 4641456013
model-00002-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3d1c775394a619df550237c4dceb7ccc3c923b9c38717699aa17ec62bb6ec9e4
+ size 4911446438
model-00003-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:53d74b67292244b783bd19aa6ef7bd5bc61ce74e7343de9634569cb6f89f93d4
+ size 4947237235
model-00004-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af37c008dfab4502b7f68da1dd9351e8fa7ec4748e646177587b9d12b1927a3b
+ size 4947237299
model-00005-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca9a70715b4ad1bbac647e2b8af738258d800bf93bd71253a5ba475a4644136c
+ size 4947237291
model-00006-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5c26080c6faea68cb5a0013d68992962ddc61fe77bbf91cc81d880521f504bbb
+ size 4911446440
model-00007-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1974db0f22bf4b592e8edcc0b97451b8f63bea560f19232e755c8f28b7a354b6
+ size 4947237261
model-00008-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92e14eddd13236ee3faad12682fb700113e76bf21af0126191893b2b10d6f498
+ size 4947237326
model-00009-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c8c0f41ca546e31f6b2ecad78b5bbc4f0aa083f6a16b59459969aa56923c0511
+ size 4947237340
model-00010-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c52a5f792ca1e290c640b4800dce9885c3a9f6daf5b82879e021fd7c1ad5ca77
+ size 4911446425
model-00011-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a29716e395e0e3436a6c51fb611774785a8baf8cf4310b8ace005867ae5c75cc
+ size 4947237330
model-00012-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aee33f764a43e86884e8f21b7384b3fca093a33aa8d16a74da00f06162b825cd
+ size 4947237332
model-00013-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:43ec976e7b816d4bb625aa67a3ce1523516ff1c15c6c4b4f5ff682b09e887960
+ size 4947237330
model-00014-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a561d60a00b07cb265c864b03955924ccb0406a55b11e9287c53a413f6cf8888
+ size 4911446485
model-00016-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e0c411b06cb2ae7cba66102a7b4d48cc1e869f1a78ced95406102c4088ddee2b
+ size 4947237332
model-00017-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd044a39f364fd4a025153b981d5c30487119dd6737252d754020e683a02a6ed
+ size 4947237336
model-00018-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea9fc4ca2218c701f4ebc04c0a025306b9dd4482aee62f82705f7d3f0267a125
+ size 4911446481
model-00019-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2cbf3755cf573c890b398251d92d5811126862ee75bd1cff7cadfe8f579d424b
+ size 4947237340
model-00020-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea3a4bf6c8e22b8e0858c0a92e5bc263f9d50a8b455c65a61a78a29f728895fa
+ size 4947237272
model-00021-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:346aee1d5859712459bbe728d4b6c7c9b1ee7668306e329a9ac6352d03aba1b2
+ size 4947237330
model-00022-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ae449e3ba5d679fef7d88f8eca3098a8e068c071a3b9b63002b7abaf6bbb15c6
+ size 4911446485
model-00023-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fdca71bb56e0f173c385933c503b2c72648bf1fbf21cedfb5f2f9d98c4ce9989
+ size 4947237330
model-00024-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d3a5738637981043b5f5c3a89c963f5a9734df1d684f6c1302a108ce983908cc
+ size 4947237298
model-00025-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:173805baec916ef8dccfeeab689e3c8840ffc6100085ddeb260f5f236d307449
+ size 4947237336
model-00026-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:404d713c6d3836005d5c0d857adbc3fe47873000557472b2cdab0e96aed813d4
+ size 4911446479
model-00027-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:42159e0c99d0cbac835b41c47ba82c681946731e95e0347cd4f4d45950bfde75
+ size 4947237260
model-00028-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a5c3ce7e953c6b7bd4a1f9757d97370d6024c02f54617adb4f148e8248c0aeff
+ size 4947237288
model-00029-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fc9cb80097c5ed635ba65ac75b664873f29d9f880e3e662b5fd2942b9760f061
+ size 4947237332
model-00030-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:88c5598ea69a3396bf6a83ef9f9b4433e4313240b9b2c64b180f3728fc4811e7
+ size 4911446485
model-00031-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e923907b5e4bbf4cfc8317a82ff7e4d483092820c176afc3e42219ec01bb843d
+ size 4947237328
model-00032-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1510805d306e8dce381b3f4c81934bf57112be31d33bc805b7b0eaa124b6ae95
+ size 4947237316
model-00033-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24ff3d5995f916b35f5686200d51cb5e7c888bdf8fd4882b07b70b66590dd3dd
+ size 4947237302
model-00034-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fae8b52f08759a7fb5c7293a747b53bd5b00b9933b95b567c7a0ed80e0f90b70
+ size 4911446481
model-00035-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b9f30597203f30ea6f589346d1e7c32f7c27f4a3aba4dabed879fefbb818389b
+ size 4947237336
model-00036-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:003c0642a7bdf84ed60b64497e720df885b1d0508fc327a5969212827f28221e
+ size 4947237290
model-00037-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:026b07e8cef4b3212bffa1425694832abe7facb1de1effc24942a141df59bcc8
+ size 4947237340
model-00038-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:93dbd87f8b269cef885e3ba140f833419e632a657ef67a8aaef7d550df313946
+ size 4911446485
model-00039-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:644076c229805a95924cf8b089c5e7e9131363d248fd6638441d665a7746efb0
+ size 4947237332
model-00040-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f213f3070080b02d79dd6aae5483ef75cccc9c26700b8956a808822c406d4e75
+ size 4947237336
model-00041-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:19f5db8f803473878a8b1555a2a0d92a98928f179a564b1adce2a29d0c3439fb
+ size 4947237292
model-00042-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aeb48f795d7c69636f08d8f66c0d201b28476351f9565dcab9b3fe87b2ded6f1
+ size 4911446485
model-00043-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cf4db23fd4b6d2446332d2e8b17ef32aca3b1d66aa360bc3434c8af0cda20ba2
+ size 2182015062
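Each shard entry above is a Git LFS pointer: the actual weights are addressed by the `oid sha256` and `size`. A sketch for checking a downloaded shard against its pointer (the filename is illustrative):

```python
import hashlib
import os

def file_sha256(path, chunk_size=1 << 20):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(chunk_size):
            h.update(block)
    return h.hexdigest()

path = "model-00043-of-00043.safetensors"
print(os.path.getsize(path))  # should match the pointer's size line
print(file_sha256(path))      # should match the pointer's oid sha256
```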
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_step3p5.py ADDED
@@ -0,0 +1,900 @@
+ # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
+ #
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from dataclasses import dataclass
+ from typing import Callable, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers.activations import ACT2FN
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.generation import GenerationMixin
+ from transformers.masking_utils import (create_causal_mask,
+                                         create_sliding_window_causal_mask)
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+ from transformers.modeling_layers import GradientCheckpointingLayer
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
+ from transformers.modeling_rope_utils import (ROPE_INIT_FUNCTIONS,
+                                               dynamic_rope_update)
+ from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
+                                          PreTrainedModel)
+ from transformers.processing_utils import Unpack
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
+
+ from .configuration_step3p5 import Step3p5Config
+
+ logger = logging.get_logger(__name__)
+
+ __all__ = ["Step3p5Model", "Step3p5ForCausalLM"]
+
+
+ class Step3p5RotaryEmbedding(nn.Module):
+
+     def __init__(self, config: Step3p5Config, device=None, layer_idx=None):
+         super().__init__()
+         # BC: "rope_type" was originally "type"
+         self.layer_idx = layer_idx
+         if config.rope_parameters is not None:
+             self.rope_type = config.rope_parameters.get(
+                 "rope_type", config.rope_parameters.get("type"))
+         else:
+             self.rope_type = "default"
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         partial_rotary_factors = getattr(config, "partial_rotary_factors",
+                                          None)
+         if partial_rotary_factors is not None:
+             config.partial_rotary_factor = partial_rotary_factors[
+                 self.layer_idx]
+         else:
+             config.partial_rotary_factor = 1.0
+
+         self.rope_theta = config.rope_theta
+         if isinstance(config.rope_theta, list):
+             self.rope_theta = config.rope_theta.copy()
+             config.rope_theta = self.rope_theta[self.layer_idx]
+
+         self.config = config
+         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+         inv_freq, self.attention_scaling = self.rope_init_fn(
+             self.config, device)
+
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+         config.rope_theta = self.rope_theta
+
+     @torch.no_grad()
+     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+     def forward(self, x, position_ids):
+         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
+             position_ids.shape[0], -1, 1).to(x.device)
+         position_ids_expanded = position_ids[:, None, :].float().to(x.device)
+
+         device_type = x.device.type if isinstance(
+             x.device.type, str) and x.device.type != "mps" else "cpu"
+         with torch.autocast(device_type=device_type,
+                             enabled=False):  # Force float32
+             freqs = (inv_freq_expanded.float()
+                      @ position_ids_expanded.float()).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos() * self.attention_scaling
+             sin = emb.sin() * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., :x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2:]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`, *optional*):
+             Deprecated and unused.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and sin so that they
+             can be properly broadcasted to the dimensions of q and k. For example, cos and sin have the shape
+             [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim],
+             setting unsqueeze_dim=1 makes cos and sin broadcastable to the shapes of q and k. Similarly, if q and k
+             have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     # Broadcast cos/sin ([batch, seq, rotary_dim]) across the heads dimension
+     # so they line up with q/k ([batch, heads, seq, head_dim]).
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     rotary_dim = cos.shape[-1]
+     q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+     k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+     # Apply rotary embeddings on the first half or full tensor
+     q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+     k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+     # Concatenate back to full shape
+     q_embed = torch.cat([q_embed, q_pass], dim=-1)
+     k_embed = torch.cat([k_embed, k_pass], dim=-1)
+     return q_embed, k_embed
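With `partial_rotary_factors` of 0.5 on the full-attention layers (see config.json), only the first half of each 128-dim head is rotated and the rest passes through unchanged. A standalone sanity check of that slicing, using `apply_rotary_pos_emb` as defined above (the shapes are illustrative, not tied to the checkpoint):

```python
import torch

batch, heads, seq, head_dim, rotary_dim = 2, 4, 6, 128, 64

q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
# cos/sin come from the rotary module with shape [batch, seq, rotary_dim]
cos = torch.randn(batch, seq, rotary_dim)
sin = torch.randn(batch, seq, rotary_dim)

q_emb, k_emb = apply_rotary_pos_emb(q, k, cos, sin)
assert q_emb.shape == q.shape
# dims beyond rotary_dim are passed through untouched
assert torch.equal(q_emb[..., rotary_dim:], q[..., rotary_dim:])
```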
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(
+         batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
+                                  head_dim)
+
+
+ # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs,
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+     attn_weights = nn.functional.dropout(attn_weights,
+                                          p=dropout,
+                                          training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ @dataclass
+ class Step3p5CausalLMOutputWithPast(ModelOutput):
+     r"""
+     loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+         Language modeling loss (for next-token prediction).
+     logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+         Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+     past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+         Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+         `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+         Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+         `past_key_values` input) to speed up sequential decoding.
+     """
+
+     loss: Optional[torch.FloatTensor] = None
+     last_hidden_state: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     past_key_values: Optional[list[torch.FloatTensor]] = None
+     hidden_states: Optional[tuple[torch.FloatTensor]] = None
+     attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+ class Step3p5MLP(nn.Module):
+
+     def __init__(self, config, intermediate_size=None, swiglu_limit=None):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size,
+                                    self.intermediate_size,
+                                    bias=False)
+         self.up_proj = nn.Linear(self.hidden_size,
+                                  self.intermediate_size,
+                                  bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size,
+                                    self.hidden_size,
+                                    bias=False)
+         self.act_fn = ACT2FN["silu"]
+         self.limit = swiglu_limit
+
+     def forward(self, x):
+         up = self.up_proj(x)
+         gate = self.act_fn(self.gate_proj(x))
+         if self.limit is not None:
+             gate = gate.clamp(min=None, max=self.limit)
+             up = up.clamp(min=-self.limit, max=self.limit)
+
+         return self.down_proj(gate * up)
+
+
+ def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
+                              renormalize: bool):
+     gating_output = gating_output.float()
+     gate_prob = torch.sigmoid(gating_output)
+     gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
+     topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
+     expert_topk_weight = topk_prob
+     if renormalize:
+         expert_topk_weight = expert_topk_weight / torch.sum(
+             expert_topk_weight, dim=-1, keepdim=True)
+     return expert_topk_weight, indices
+
+
+ def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
+                              renormalize: bool):
+     gating_output = gating_output.float()
+     gate_prob = torch.softmax(gating_output, dim=-1)
+     gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
+     topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
+     expert_topk_weight = topk_prob
+     if renormalize:
+         expert_topk_weight = expert_topk_weight / torch.sum(
+             expert_topk_weight, dim=-1, keepdim=True)
+     return expert_topk_weight, indices.to(torch.int32)
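A quick numeric illustration of `sigmoid_routing_function` (toy logits, 4 experts, top-2): sigmoid scores are normalized to sum to 1, the top-k are kept, and with `renormalize=True` the kept weights are rescaled to sum to 1 again:

```python
import torch

logits = torch.tensor([[2.0, 0.0, -1.0, 1.0]])
weights, experts = sigmoid_routing_function(logits, topk=2, renormalize=True)
print(experts)               # tensor([[0, 3]]) -> the two highest-scoring experts
print(weights.sum(dim=-1))   # tensor([1.]) after renormalization
```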
+
+
+ class MoELinear(nn.Module):
+
+     def __init__(self, num_experts, in_features, out_features):
+         super().__init__()
+         self.num_experts = num_experts
+         self.in_features = in_features
+         self.out_features = out_features
+         self.weight = nn.Parameter(
+             torch.empty(num_experts, out_features, in_features))
+
+     def forward(self, x, expert_id):
+         x = F.linear(x.float(), self.weight[expert_id].float())
+         return x
+
+
+ class Step3p5MoEMLP(nn.Module):
+
+     def __init__(self, config, swiglu_limit=None):
+         super().__init__()
+         self.num_experts = config.moe_num_experts
+         self.top_k = config.moe_top_k
+         self.hidden_size = config.hidden_size
+         self.moe_intermediate_size = config.moe_intermediate_size
+
+         self.use_moe_router_bias = config.use_moe_router_bias
+         if self.use_moe_router_bias:
+             self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
+                                                         dtype=torch.float32),
+                                             requires_grad=False)
+             self.custom_routing_function = self.router_bias_func
+         elif config.moe_router_activation == "sigmoid":
+             self.custom_routing_function = sigmoid_routing_function
+         else:
+             self.custom_routing_function = None
+         self.need_fp32_gate = config.need_fp32_gate
+         self.routed_scaling_factor = getattr(config,
+                                              "moe_router_scaling_factor", 1.0)
+
+         # gating
+         self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
+
+         self.act_fn = ACT2FN["silu"]
+         self.limit = swiglu_limit
+
+         self.up_proj = MoELinear(self.num_experts, self.hidden_size,
+                                  self.moe_intermediate_size)
+         self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
+                                    self.moe_intermediate_size)
+         self.down_proj = MoELinear(self.num_experts,
+                                    self.moe_intermediate_size,
+                                    self.hidden_size)
+
+     def router_bias_func(self, gating_output: torch.Tensor, topk: int,
+                          renormalize: bool):
+         gate_prob = torch.sigmoid(gating_output.float())
+         gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
+         _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
+         topk_prob = torch.gather(gate_prob, 1, indices)
+         expert_topk_weight = topk_prob
+         if renormalize:
+             expert_topk_weight = expert_topk_weight / (
+                 torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
+         return expert_topk_weight, indices
+
+     def get_expert_output(self, inputs: torch.Tensor, expert_id):
+         up = self.up_proj(inputs, expert_id)
+         gate = self.act_fn(self.gate_proj(inputs, expert_id))
+         if self.limit is not None:
+             gate = gate.clamp(min=None, max=self.limit)
+             up = up.clamp(min=-self.limit, max=self.limit)
+
+         return self.down_proj(gate * up, expert_id)
+
+     def forward(self, hidden_states):
+         batch_size, sequence_length, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         if self.need_fp32_gate:
+             router_logits = torch.matmul(
+                 hidden_states.to(torch.float32),
+                 self.gate.weight.t().to(torch.float32))
+         else:
+             # router_logits: (batch * sequence_length, n_experts)
+             router_logits = self.gate(hidden_states)
+
+         if self.custom_routing_function:
+             routing_weights, selected_experts = self.custom_routing_function(
+                 router_logits, self.top_k, renormalize=True)
+         else:
+             routing_weights = F.softmax(router_logits,
+                                         dim=1,
+                                         dtype=torch.float)
+             routing_weights, selected_experts = torch.topk(routing_weights,
+                                                            self.top_k,
+                                                            dim=-1)
+
+         routing_weights = routing_weights * self.routed_scaling_factor
+
+         final_hidden_states = torch.zeros(
+             (batch_size * sequence_length, hidden_dim),
+             dtype=hidden_states.dtype,
+             device=hidden_states.device)
+
+         # One-hot encode the selected experts to create an expert mask;
+         # this will be used to easily index which expert is going to be solicited
+         expert_mask = torch.nn.functional.one_hot(
+             selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+         # Loop over all available experts in the model and perform the computation on each expert
+         for expert_idx in range(self.num_experts):
+             idx, top_x = torch.where(expert_mask[expert_idx])
+
+             # Index the correct hidden states and compute the expert hidden state for
+             # the current expert. We need to make sure to multiply the output hidden
+             # states by `routing_weights` on the corresponding tokens
+             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+             current_hidden_states = (
+                 self.get_expert_output(current_state, expert_idx) *
+                 routing_weights[top_x, idx, None])
+
+             # However `index_add_` only supports torch tensors for indexing, so we'll use
+             # the `top_x` tensor here.
+             final_hidden_states.index_add_(
+                 0, top_x, current_hidden_states.to(hidden_states.dtype))
+         final_hidden_states = final_hidden_states.reshape(
+             batch_size, sequence_length, hidden_dim)
+         return final_hidden_states
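The one-hot dispatch above is easiest to see on a toy case: with `selected_experts` of shape `[tokens, top_k]`, `expert_mask[e]` is a `[top_k, tokens]` grid whose nonzero columns are the tokens routed to expert `e`:

```python
import torch

selected_experts = torch.tensor([[0, 2], [2, 1]])  # 2 tokens, top_k=2
expert_mask = torch.nn.functional.one_hot(
    selected_experts, num_classes=3).permute(2, 1, 0)
idx, top_x = torch.where(expert_mask[2])
print(top_x)  # tensor([1, 0]) -> tokens 1 and 0 are both routed to expert 2
print(idx)    # tensor([0, 1]) -> via selection slot 0 and slot 1 respectively
```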
+
+
+ class Step3p5RMSNorm(nn.Module):
+
+     def __init__(
+         self,
+         hidden_size: int,
+         eps: float = 1e-5,
+     ) -> None:
+         super().__init__()
+         # zero-centered gain: the stored weight is (gamma - 1), applied as
+         # (weight + 1) below, so a fresh layer starts as plain RMS norm
+         self.weight = nn.Parameter(torch.zeros(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         dtype = x.dtype
+         x = x.float()
+         variance = x.pow(2).mean(dim=-1, keepdim=True)
+         normed = x * torch.rsqrt(variance + self.variance_epsilon)
+         normed = normed * (self.weight.float() + 1)
+         return normed.to(dtype)
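`Step3p5RMSNorm` stores a zero-centered gain (matching the config's `"zero_centered": true`): the effective scale is `1 + weight`, so with the weight at its initial value of 0 the layer acts as plain RMS normalization. A quick check:

```python
import torch

norm = Step3p5RMSNorm(8)
x = torch.randn(2, 8)
out = norm(x)
# with weight == 0, output is x / sqrt(mean(x^2) + eps)
expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
print(torch.allclose(out, expected, atol=1e-6))  # True
```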
406
+ class Step3p5Attention(nn.Module):
407
+
408
+ def __init__(self, config: Step3p5Config, layer_idx):
409
+ super().__init__()
410
+ self.config = config
411
+ self.layer_idx = layer_idx
412
+ self.num_attention_heads = config.num_attention_heads
413
+ self.num_key_value_heads = config.num_attention_groups
414
+
415
+ layer_types = getattr(config, "layer_types", [])
416
+ if layer_types:
417
+ enable_sliding_window = layer_types[
418
+ self.layer_idx] == "sliding_attention"
419
+ else:
420
+ enable_sliding_window = self.layer_idx % 2 == 0
421
+
422
+ if hasattr(config, "yarn_only_types") and layer_types[
423
+ self.layer_idx] not in config.yarn_only_types:
424
+ config.rope_parameters = None
425
+ else:
426
+ config.rope_parameters = getattr(config, "rope_scaling", None)
427
+
428
+ self.sliding_window = config.sliding_window
429
+ if enable_sliding_window:
430
+ self.num_attention_heads = config.attention_other_setting[
431
+ "num_attention_heads"]
432
+ self.num_key_value_heads = config.attention_other_setting[
433
+ "num_attention_groups"]
434
+
435
+ if self.sliding_window is not None and enable_sliding_window:
436
+ self.sliding_window = (self.sliding_window)
437
+ else:
438
+ self.sliding_window = None
439
+ self.head_dim = getattr(config, "head_dim",
440
+ config.hidden_size // self.num_attention_heads)
441
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
442
+
443
+ self.rotary_emb = Step3p5RotaryEmbedding(config, layer_idx=layer_idx)
444
+
445
+ self.q_size = self.num_attention_heads * self.head_dim
446
+ self.kv_size = self.num_key_value_heads * self.head_dim
447
+ self.scaling = self.head_dim**-0.5
448
+
449
+ self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
450
+ self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
451
+ self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
452
+ self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
453
+ self.q_norm = Step3p5RMSNorm(self.head_dim,
454
+ eps=config.rms_norm_eps)
455
+ self.k_norm = Step3p5RMSNorm(self.head_dim,
456
+ eps=config.rms_norm_eps)
457
+
458
+ self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
459
+ if self.use_head_wise_attn_gate:
460
+ self.g_proj = nn.Linear(config.hidden_size,
461
+ self.num_attention_heads,
462
+ bias=False)
463
+
464
+ self.use_rope = True
465
+ use_rope_layers = getattr(config, "use_rope_layers", None)
466
+ if use_rope_layers:
467
+ self.use_rope = use_rope_layers[self.layer_idx]
468
+
469
+ def forward(
470
+ self,
471
+ hidden_states: torch.Tensor,
472
+ attention_mask: Optional[torch.Tensor],
473
+ past_key_value: Optional[Cache] = None,
474
+ cache_position: Optional[torch.LongTensor] = None,
475
+ position_ids: Optional[torch.LongTensor] = None,
476
+ **kwargs: Unpack[FlashAttentionKwargs],
477
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
478
+ Optional[Tuple[torch.Tensor]]]:
479
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+ 
+         query_states = self.q_norm(
+             self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+         key_states = self.k_norm(
+             self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(
+             hidden_shape).transpose(1, 2)
+         if self.use_head_wise_attn_gate:
+             gate_states = self.g_proj(hidden_states)
+         cos, sin = self.rotary_emb(hidden_states, position_ids)
+ 
+         query_states, key_states = apply_rotary_pos_emb(
+             query_states, key_states, cos, sin)
+ 
+         if past_key_value is not None:
+             # sin and cos are specific to RoPE models; position_ids needed for the static cache
+             cache_kwargs = {
+                 "sin": sin,
+                 "cos": cos,
+                 "cache_position": cache_position
+             }
+             key_states, value_states = past_key_value.update(
+                 key_states, value_states, self.layer_idx, cache_kwargs)
+ 
+         attention_interface: Callable = eager_attention_forward
+         # TODO: considering FP8;
+         # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
+         # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
+         if self.config._attn_implementation != "eager":
+             attention_interface = ALL_ATTENTION_FUNCTIONS[
+                 self.config._attn_implementation]
+ 
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             sliding_window=self.sliding_window,  # main diff with Llama
+             **kwargs,
+         )
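+         # Optional head-wise gate: a sigmoid-activated scalar per attention
+         # head rescales that head's output before the output projection.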
+         attn_output = attn_output.reshape(*input_shape, -1)
+         if self.use_head_wise_attn_gate:
+             output = attn_output.view(
+                 *attn_output.shape[:-1], self.num_attention_heads,
+                 self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
+             attn_output = output.view(*attn_output.shape)
+         attn_output = self.o_proj(attn_output)
+ 
+         return attn_output, attn_weights
+ 
+ 
+ class Step3p5DecoderLayer(GradientCheckpointingLayer):
+ 
+     def __init__(self, config, layer_idx):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.layer_idx = layer_idx
+         self.self_attn = Step3p5Attention(config, layer_idx)
+         self.attention_type = config.layer_types[layer_idx]
+ 
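+         # Every layer except the first is routed through MoE unless the config
+         # names an explicit set via the comma-separated `moe_layers_enum`.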
+         moe_layers_enum = getattr(config, "moe_layers_enum", None)
+         if moe_layers_enum is not None:
+             moe_layers_idx = [
+                 int(i) for i in moe_layers_enum.strip().split(',')
+             ]
+         else:
+             moe_layers_idx = list(range(1, config.num_hidden_layers))
+         self.is_moe_layer = layer_idx in moe_layers_idx
+         self.use_moe = False
+ 
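+         # SwiGLU clamp limits are per layer; a configured value of 0 or None
+         # is normalized to None here before being handed to the MLPs.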
+         if config.swiglu_limits_shared and config.swiglu_limits_shared[
+                 layer_idx] is not None and config.swiglu_limits_shared[
+                     layer_idx] != 0:
+             swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
+         else:
+             swiglu_limit_shared = None
+         if config.swiglu_limits and config.swiglu_limits[
+                 layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
+             swiglu_limit = config.swiglu_limits[layer_idx]
+         else:
+             swiglu_limit = None
+         if self.is_moe_layer:
+             self.moe = Step3p5MoEMLP(config, swiglu_limit=swiglu_limit)
+             self.share_expert = Step3p5MLP(
+                 config,
+                 intermediate_size=config.share_expert_dim,
+                 swiglu_limit=swiglu_limit_shared)
+             self.use_moe = True
+         else:
+             self.mlp = Step3p5MLP(config,
+                                   intermediate_size=config.intermediate_size,
+                                   swiglu_limit=swiglu_limit_shared)
+ 
+         self.input_layernorm = Step3p5RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = Step3p5RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps)
+ 
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[tuple[torch.Tensor]] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs: Unpack[FlashAttentionKwargs],
+     ) -> torch.FloatTensor:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states, _ = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             cache_position=cache_position,
+             **kwargs,
+         )
+         hidden_states = residual + hidden_states
+ 
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
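+         # MoE layers add a dense shared expert to the routed experts' output;
+         # dense layers use a single MLP.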
+         if self.use_moe:
+             share_output = self.share_expert(hidden_states)
+             moe_output = self.moe(hidden_states)
+             ffn_output = moe_output + share_output
+         else:
+             ffn_output = self.mlp(hidden_states)
+         if isinstance(ffn_output, tuple):
+             hidden_states, _ = ffn_output
+         else:
+             hidden_states = ffn_output
+ 
+         hidden_states = residual + hidden_states
+         return hidden_states
+ 
+ 
+ class Step3p5PreTrainedModel(PreTrainedModel):
+     # Link this model family to its configuration class so PreTrainedModel.from_pretrained
+     # can load the config instead of failing with a NoneType error.
+     config_class = Step3p5Config
+     supports_gradient_checkpointing = True
+     _skip_keys_device_placement = ["past_key_values"]
+     _keys_to_ignore_on_load_unexpected = [
+         r"model\.layers\.45\..*",
+         r"model\.layers\.46\..*",
+         r"model\.layers\.47\..*",
+     ]
+     _supports_flash_attn = False
+     _supports_sdpa = True
+     _supports_flex_attn = True
+     _supports_static_cache = True
+     _supports_attention_backend = True
+ 
+ 
+ class Step3p5Model(Step3p5PreTrainedModel, GenerationMixin):
+     _no_split_modules = ["Step3p5DecoderLayer"]
+     base_model_prefix = "model"
+     _tied_weights_keys = ["lm_head.weight"]
+     config: Step3p5Config
+ 
+     def __init__(self, config: Step3p5Config):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+ 
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
+                                          self.padding_idx)
+         self.layers = nn.ModuleList([
+             Step3p5DecoderLayer(config, layer_idx)
+             for layer_idx in range(config.num_hidden_layers)
+         ])
+         self.norm = Step3p5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.gradient_checkpointing = False
+         self.has_sliding_layers = "sliding_attention" in self.config.layer_types
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def get_input_embeddings(self):
+         return self.embed_tokens
+ 
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+ 
+     @can_return_tuple
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> Union[tuple, BaseModelOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (output_hidden_states
+                                 if output_hidden_states is not None else
+                                 self.config.output_hidden_states)
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You must specify exactly one of input_ids or inputs_embeds")
+ 
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
+             use_cache = False
+ 
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(
+                 input_ids.to(self.embed_tokens.weight.device))
+ 
+         if use_cache and past_key_values is None:
+             past_key_values = DynamicCache()
+ 
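+         # `cache_position` tracks the absolute positions of the new tokens,
+         # continuing from however many tokens the cache has already seen.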
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length(
+             ) if past_key_values is not None else 0
+             cache_position = torch.arange(past_seen_tokens,
+                                           past_seen_tokens +
+                                           inputs_embeds.shape[1],
+                                           device=inputs_embeds.device)
+ 
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+ 
+         hidden_states = inputs_embeds
+ 
+         # It may already have been prepared by e.g. `generate`
+         if not isinstance(causal_mask_mapping := attention_mask, dict):
+             # Prepare mask arguments
+             mask_kwargs = {
+                 "config": self.config,
+                 "input_embeds": inputs_embeds,
+                 "attention_mask": attention_mask,
+                 "cache_position": cache_position,
+                 "past_key_values": past_key_values,
+                 "position_ids": position_ids,
+             }
+             # Create the masks
+             causal_mask_mapping = {
+                 "full_attention": create_causal_mask(**mask_kwargs),
+             }
+ 
+             # The sliding window alternating layers are not always activated depending on the config
+             if self.has_sliding_layers:
+                 causal_mask_mapping[
+                     "sliding_attention"] = create_sliding_window_causal_mask(
+                         **mask_kwargs)
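+         # Each decoder layer picks its mask from this mapping according to its
+         # `attention_type` ("full_attention" or "sliding_attention").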
+ 
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         for decoder_layer in self.layers[:self.config.num_hidden_layers]:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states, )
+ 
+             layer_outputs = decoder_layer(
+                 hidden_states,
+                 attention_mask=causal_mask_mapping[
+                     decoder_layer.attention_type],
+                 position_ids=position_ids,
+                 past_key_value=past_key_values,
+                 output_attentions=output_attentions,
+                 use_cache=use_cache,
+                 cache_position=cache_position,
+                 **kwargs,
+             )
+ 
+             hidden_states = layer_outputs
+ 
+         hidden_states = self.norm(hidden_states)
+ 
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=past_key_values if use_cache else None,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+ 
+ 
+ class Step3p5ForCausalLM(Step3p5PreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+     config: Step3p5Config
+ 
+     def __init__(self, config: Step3p5Config):
+         super().__init__(config)
+         self.model = Step3p5Model(config)
+         self.lm_head = nn.Linear(config.hidden_size,
+                                  config.vocab_size,
+                                  bias=False)
+ 
+         self.post_init()
+ 
+     def get_input_embeddings(self):
+         return self.model.get_input_embeddings()
+ 
+     def set_input_embeddings(self, value):
+         self.model.set_input_embeddings(value)
+ 
+     def get_output_embeddings(self):
+         return self.lm_head
+ 
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+ 
+     def set_decoder(self, decoder):
+         self.model = decoder
+ 
+     def get_decoder(self):
+         return self.model
+ 
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         num_patches=None,
+         patch_pixel_values=None,
+         patch_newline_mask=None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> Union[tuple, Step3p5CausalLMOutputWithPast]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ 
+         Example:
+ 
+         ```python
+         >>> from transformers import AutoTokenizer, Step3p5ForCausalLM
+ 
+         >>> model = Step3p5ForCausalLM.from_pretrained("stepfun-ai/Step-3.5-Flash")
+         >>> tokenizer = AutoTokenizer.from_pretrained("stepfun-ai/Step-3.5-Flash")
+ 
+         >>> prompt = "Hey, are you conscious? Can you talk to me?"
+         >>> inputs = tokenizer(prompt, return_tensors="pt")
+ 
+         >>> # Generate
+         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+         ```"""
+ 
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (output_hidden_states
+                                 if output_hidden_states is not None else
+                                 self.config.output_hidden_states)
+         outputs = self.model(
+             input_ids=input_ids,
+             num_patches=num_patches,
+             patch_pixel_values=patch_pixel_values,
+             patch_newline_mask=patch_newline_mask,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             **kwargs,
+         )
+         hidden_states = outputs.last_hidden_state
+         logits = self.lm_head(hidden_states)
+ 
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits=logits,
+                                       labels=labels,
+                                       vocab_size=self.config.vocab_size,
+                                       **kwargs)
+ 
+         return Step3p5CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+ 
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         inputs_embeds=None,
+         pixel_values=None,
+         attention_mask=None,
+         cache_position=None,
+         logits_to_keep=None,
+         **kwargs,
+     ):
+         model_inputs = super().prepare_inputs_for_generation(
+             input_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+             logits_to_keep=logits_to_keep,
+             **kwargs,
+         )
+ 
+         if cache_position[0] == 0:
+             # Pixel values are only needed for the prefill pass; during cached
+             # decoding the input ids no longer contain special image tokens.
+             model_inputs["pixel_values"] = pixel_values
+ 
+         return model_inputs
+ 
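+     # Checkpoints saved with the weights nested under a `language_model.`
+     # prefix (e.g. inside a multimodal wrapper) are remapped onto this model.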
+     def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
+         if key.startswith("language_model."):
+             return key[len("language_model."):], True
+ 
+         return key, False
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "add_prefix_space": null,
+   "backend": "tokenizers",
+   "bos_token": "<|begin▁of▁sentence|>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "is_local": true,
+   "legacy": true,
+   "model_max_length": 131072,
+   "model_specific_special_tokens": {},
+   "pad_token": "<|end▁of▁sentence|>",
+   "sp_model_kwargs": {},
+   "tokenizer_class": "TokenizersBackend",
+   "tool_parser_type": "qwen3_coder",
+   "unk_token": null,
+   "use_default_system_prompt": false
+ }