kernelpool commited on
Commit
c0a1b2a
·
verified ·
1 Parent(s): 6daed83

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. README.md +41 -0
  2. chat_template.jinja +89 -0
  3. config.json +677 -0
  4. configuration_step3p7.py +207 -0
  5. model-00001-of-00043.safetensors +3 -0
  6. model-00002-of-00043.safetensors +3 -0
  7. model-00003-of-00043.safetensors +3 -0
  8. model-00004-of-00043.safetensors +3 -0
  9. model-00005-of-00043.safetensors +3 -0
  10. model-00006-of-00043.safetensors +3 -0
  11. model-00007-of-00043.safetensors +3 -0
  12. model-00008-of-00043.safetensors +3 -0
  13. model-00009-of-00043.safetensors +3 -0
  14. model-00011-of-00043.safetensors +3 -0
  15. model-00012-of-00043.safetensors +3 -0
  16. model-00013-of-00043.safetensors +3 -0
  17. model-00014-of-00043.safetensors +3 -0
  18. model-00016-of-00043.safetensors +3 -0
  19. model-00017-of-00043.safetensors +3 -0
  20. model-00018-of-00043.safetensors +3 -0
  21. model-00019-of-00043.safetensors +3 -0
  22. model-00020-of-00043.safetensors +3 -0
  23. model-00021-of-00043.safetensors +3 -0
  24. model-00022-of-00043.safetensors +3 -0
  25. model-00023-of-00043.safetensors +3 -0
  26. model-00024-of-00043.safetensors +3 -0
  27. model-00025-of-00043.safetensors +3 -0
  28. model-00026-of-00043.safetensors +3 -0
  29. model-00027-of-00043.safetensors +3 -0
  30. model-00029-of-00043.safetensors +3 -0
  31. model-00030-of-00043.safetensors +3 -0
  32. model-00031-of-00043.safetensors +3 -0
  33. model-00032-of-00043.safetensors +3 -0
  34. model-00033-of-00043.safetensors +3 -0
  35. model-00034-of-00043.safetensors +3 -0
  36. model-00035-of-00043.safetensors +3 -0
  37. model-00036-of-00043.safetensors +3 -0
  38. model-00037-of-00043.safetensors +3 -0
  39. model-00038-of-00043.safetensors +3 -0
  40. model-00039-of-00043.safetensors +3 -0
  41. model-00040-of-00043.safetensors +3 -0
  42. model-00041-of-00043.safetensors +3 -0
  43. model-00042-of-00043.safetensors +3 -0
  44. model-00043-of-00043.safetensors +3 -0
  45. model.safetensors.index.json +0 -0
  46. modeling_step3p7.py +1405 -0
  47. processing_step3.py +475 -0
  48. tokenizer.json +0 -0
  49. tokenizer_config.json +16 -0
  50. vision_encoder.py +452 -0
README.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: mlx
4
+ pipeline_tag: text-generation
5
+ language:
6
+ - en
7
+ tags:
8
+ - vision-language
9
+ - multimodal
10
+ - moe
11
+ - mlx
12
+ base_model: stepfun-ai/Step-3.7-Flash
13
+ ---
14
+
15
+ # mlx-community/Step-3.7-Flash-8bit
16
+
17
+ This model [mlx-community/Step-3.7-Flash-8bit](https://huggingface.co/mlx-community/Step-3.7-Flash-8bit) was
18
+ converted to MLX format from [stepfun-ai/Step-3.7-Flash](https://huggingface.co/stepfun-ai/Step-3.7-Flash)
19
+ using mlx-lm version **0.31.3**.
20
+
21
+ ## Use with mlx
22
+
23
+ ```bash
24
+ pip install mlx-lm
25
+ ```
26
+
27
+ ```python
28
+ from mlx_lm import load, generate
29
+
30
+ model, tokenizer = load("mlx-community/Step-3.7-Flash-8bit")
31
+
32
+ prompt = "hello"
33
+
34
+ if tokenizer.chat_template is not None:
35
+ messages = [{"role": "user", "content": prompt}]
36
+ prompt = tokenizer.apply_chat_template(
37
+ messages, add_generation_prompt=True, return_dict=False,
38
+ )
39
+
40
+ response = generate(model, tokenizer, prompt=prompt, verbose=True)
41
+ ```
chat_template.jinja ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% macro render_message_content(message) %}{% if message.content is none %}{{- '' }}{% elif message.content is string %}{{- message.content }}{% elif message.content is mapping %}{{- message.content['value'] if 'value' in message.content else message.content['text'] }}{% elif message.content is iterable %}{% set ns = namespace(needs_text_separator=false) %}{% for item in message.content %}{% if item.type == 'text' %}{% if ns.needs_text_separator %}{{- ' ' }}{% endif %}{{- item['value'] if 'value' in item else item['text'] }}{% set ns.needs_text_separator = true %}{% elif item.type == 'image' %}<im_patch>{% set ns.needs_text_separator = false %}{% endif %}{% endfor %}{% endif %}{% endmacro %}
2
+ {{bos_token}}{%- if tools %}
3
+ {{- '<|im_start|>system\n' }}
4
+ {%- if reasoning_effort is defined %}
5
+ {{- "Reasoning: " + reasoning_effort + '\n\n' }}
6
+ {%- endif %}
7
+ {%- if messages[0].role == 'system' %}
8
+ {{- render_message_content(messages[0]) + '\n\n' }}
9
+ {%- endif %}
10
+ {{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n<tools>" }}
11
+ {%- for tool in tools %}
12
+ {{- "\n" }}
13
+ {{- tool | tojson(ensure_ascii=False) }}
14
+ {%- endfor %}
15
+ {{- "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...>\n...\n</function> block must be nested within <tool_call>\n...\n</tool_call> XML tags\n- Required parameters MUST be specified\n</IMPORTANT><|im_end|>\n" }}
16
+ {%- else %}
17
+ {%- if messages[0].role == 'system' %}
18
+ {{- '<|im_start|>system\n' }}
19
+ {%- if reasoning_effort is defined %}
20
+ {{- "Reasoning: " + reasoning_effort + '\n\n' }}
21
+ {%- endif %}
22
+ {{- render_message_content(messages[0]) + '<|im_end|>\n' }}
23
+ {%- elif reasoning_effort is defined %}
24
+ {{- '<|im_start|>system\n' + "Reasoning: " + reasoning_effort + '\n\n' + '<|im_end|>\n' }}
25
+ {%- endif %}
26
+ {%- endif %}
27
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
28
+ {%- for message in messages[::-1] %}
29
+ {%- set index = (messages|length - 1) - loop.index0 %}
30
+ {%- if ns.multi_step_tool and message.role == "user" and render_message_content(message) is string and not(render_message_content(message).startswith('<tool_response>') and render_message_content(message).endswith('</tool_response>')) %}
31
+ {%- set ns.multi_step_tool = false %}
32
+ {%- set ns.last_query_index = index %}
33
+ {%- endif %}
34
+ {%- endfor %}
35
+ {%- for message in messages %}
36
+ {%- set content = render_message_content(message) %}
37
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
38
+ {%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
39
+ {{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
40
+ {%- elif message.role == "assistant" %}
41
+ {%- if message.reasoning_content is string %}
42
+ {%- set reasoning_content = message.reasoning_content %}
43
+ {%- else %}
44
+ {%- if '</think>' in content %}
45
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
46
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
47
+ {%- else %}
48
+ {%- set reasoning_content = '' %}
49
+ {%- endif %}
50
+ {%- endif %}
51
+ {%- if loop.index0 > ns.last_query_index %}
52
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
53
+ {%- else %}
54
+ {{- '<|im_start|>' + message.role + '\n' + content }}
55
+ {%- endif %}
56
+ {%- if message.tool_calls %}
57
+ {%- for tool_call in message.tool_calls %}
58
+ {%- if tool_call.function is defined %}
59
+ {%- set tool_call = tool_call.function %}
60
+ {%- endif %}
61
+ {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
62
+ {%- if tool_call.arguments is defined %}
63
+ {%- set arguments = tool_call.arguments | fromjson if tool_call.arguments is string else tool_call.arguments %}
64
+ {%- for args_name, args_value in arguments|items %}
65
+ {{- '<parameter=' + args_name + '>\n' }}
66
+ {%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
67
+ {{- args_value }}
68
+ {{- '\n</parameter>\n' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '</function>\n</tool_call>' }}
72
+ {%- endfor %}
73
+ {%- endif %}
74
+ {{- '<|im_end|>\n' }}
75
+ {%- elif message.role == "tool" %}
76
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
77
+ {{- '<|im_start|>tool_response\n' }}
78
+ {%- endif %}
79
+ {{- '<tool_response>' }}
80
+ {{- content }}
81
+ {{- '</tool_response>' }}
82
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
83
+ {{- '<|im_end|>\n' }}
84
+ {%- endif %}
85
+ {%- endif %}
86
+ {%- endfor %}
87
+ {%- if add_generation_prompt %}
88
+ {{- '<|im_start|>assistant\n<think>\n' }}
89
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Step3p7ForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_step3p7.Step3p7Config",
7
+ "AutoProcessor": "processing_step3.Step3VLProcessor",
8
+ "AutoModelForCausalLM": "modeling_step3p7.Step3p7ForConditionalGeneration"
9
+ },
10
+ "im_end_token": "<im_end>",
11
+ "im_patch_token": "<im_patch>",
12
+ "im_start_token": "<im_start>",
13
+ "image_token_id": 128001,
14
+ "image_token_len": 169,
15
+ "model_type": "step3p7",
16
+ "patch_token_len": 81,
17
+ "projector_bias": false,
18
+ "quantization": {
19
+ "group_size": 64,
20
+ "bits": 8,
21
+ "mode": "affine",
22
+ "language_model.model.layers.3.mlp.gate.gate": {
23
+ "group_size": 64,
24
+ "bits": 8
25
+ },
26
+ "language_model.model.layers.4.mlp.gate.gate": {
27
+ "group_size": 64,
28
+ "bits": 8
29
+ },
30
+ "language_model.model.layers.5.mlp.gate.gate": {
31
+ "group_size": 64,
32
+ "bits": 8
33
+ },
34
+ "language_model.model.layers.6.mlp.gate.gate": {
35
+ "group_size": 64,
36
+ "bits": 8
37
+ },
38
+ "language_model.model.layers.7.mlp.gate.gate": {
39
+ "group_size": 64,
40
+ "bits": 8
41
+ },
42
+ "language_model.model.layers.8.mlp.gate.gate": {
43
+ "group_size": 64,
44
+ "bits": 8
45
+ },
46
+ "language_model.model.layers.9.mlp.gate.gate": {
47
+ "group_size": 64,
48
+ "bits": 8
49
+ },
50
+ "language_model.model.layers.10.mlp.gate.gate": {
51
+ "group_size": 64,
52
+ "bits": 8
53
+ },
54
+ "language_model.model.layers.11.mlp.gate.gate": {
55
+ "group_size": 64,
56
+ "bits": 8
57
+ },
58
+ "language_model.model.layers.12.mlp.gate.gate": {
59
+ "group_size": 64,
60
+ "bits": 8
61
+ },
62
+ "language_model.model.layers.13.mlp.gate.gate": {
63
+ "group_size": 64,
64
+ "bits": 8
65
+ },
66
+ "language_model.model.layers.14.mlp.gate.gate": {
67
+ "group_size": 64,
68
+ "bits": 8
69
+ },
70
+ "language_model.model.layers.15.mlp.gate.gate": {
71
+ "group_size": 64,
72
+ "bits": 8
73
+ },
74
+ "language_model.model.layers.16.mlp.gate.gate": {
75
+ "group_size": 64,
76
+ "bits": 8
77
+ },
78
+ "language_model.model.layers.17.mlp.gate.gate": {
79
+ "group_size": 64,
80
+ "bits": 8
81
+ },
82
+ "language_model.model.layers.18.mlp.gate.gate": {
83
+ "group_size": 64,
84
+ "bits": 8
85
+ },
86
+ "language_model.model.layers.19.mlp.gate.gate": {
87
+ "group_size": 64,
88
+ "bits": 8
89
+ },
90
+ "language_model.model.layers.20.mlp.gate.gate": {
91
+ "group_size": 64,
92
+ "bits": 8
93
+ },
94
+ "language_model.model.layers.21.mlp.gate.gate": {
95
+ "group_size": 64,
96
+ "bits": 8
97
+ },
98
+ "language_model.model.layers.22.mlp.gate.gate": {
99
+ "group_size": 64,
100
+ "bits": 8
101
+ },
102
+ "language_model.model.layers.23.mlp.gate.gate": {
103
+ "group_size": 64,
104
+ "bits": 8
105
+ },
106
+ "language_model.model.layers.24.mlp.gate.gate": {
107
+ "group_size": 64,
108
+ "bits": 8
109
+ },
110
+ "language_model.model.layers.25.mlp.gate.gate": {
111
+ "group_size": 64,
112
+ "bits": 8
113
+ },
114
+ "language_model.model.layers.26.mlp.gate.gate": {
115
+ "group_size": 64,
116
+ "bits": 8
117
+ },
118
+ "language_model.model.layers.27.mlp.gate.gate": {
119
+ "group_size": 64,
120
+ "bits": 8
121
+ },
122
+ "language_model.model.layers.28.mlp.gate.gate": {
123
+ "group_size": 64,
124
+ "bits": 8
125
+ },
126
+ "language_model.model.layers.29.mlp.gate.gate": {
127
+ "group_size": 64,
128
+ "bits": 8
129
+ },
130
+ "language_model.model.layers.30.mlp.gate.gate": {
131
+ "group_size": 64,
132
+ "bits": 8
133
+ },
134
+ "language_model.model.layers.31.mlp.gate.gate": {
135
+ "group_size": 64,
136
+ "bits": 8
137
+ },
138
+ "language_model.model.layers.32.mlp.gate.gate": {
139
+ "group_size": 64,
140
+ "bits": 8
141
+ },
142
+ "language_model.model.layers.33.mlp.gate.gate": {
143
+ "group_size": 64,
144
+ "bits": 8
145
+ },
146
+ "language_model.model.layers.34.mlp.gate.gate": {
147
+ "group_size": 64,
148
+ "bits": 8
149
+ },
150
+ "language_model.model.layers.35.mlp.gate.gate": {
151
+ "group_size": 64,
152
+ "bits": 8
153
+ },
154
+ "language_model.model.layers.36.mlp.gate.gate": {
155
+ "group_size": 64,
156
+ "bits": 8
157
+ },
158
+ "language_model.model.layers.37.mlp.gate.gate": {
159
+ "group_size": 64,
160
+ "bits": 8
161
+ },
162
+ "language_model.model.layers.38.mlp.gate.gate": {
163
+ "group_size": 64,
164
+ "bits": 8
165
+ },
166
+ "language_model.model.layers.39.mlp.gate.gate": {
167
+ "group_size": 64,
168
+ "bits": 8
169
+ },
170
+ "language_model.model.layers.40.mlp.gate.gate": {
171
+ "group_size": 64,
172
+ "bits": 8
173
+ },
174
+ "language_model.model.layers.41.mlp.gate.gate": {
175
+ "group_size": 64,
176
+ "bits": 8
177
+ },
178
+ "language_model.model.layers.42.mlp.gate.gate": {
179
+ "group_size": 64,
180
+ "bits": 8
181
+ },
182
+ "language_model.model.layers.43.mlp.gate.gate": {
183
+ "group_size": 64,
184
+ "bits": 8
185
+ },
186
+ "language_model.model.layers.44.mlp.gate.gate": {
187
+ "group_size": 64,
188
+ "bits": 8
189
+ }
190
+ },
191
+ "quantization_config": {
192
+ "group_size": 64,
193
+ "bits": 8,
194
+ "mode": "affine",
195
+ "language_model.model.layers.3.mlp.gate.gate": {
196
+ "group_size": 64,
197
+ "bits": 8
198
+ },
199
+ "language_model.model.layers.4.mlp.gate.gate": {
200
+ "group_size": 64,
201
+ "bits": 8
202
+ },
203
+ "language_model.model.layers.5.mlp.gate.gate": {
204
+ "group_size": 64,
205
+ "bits": 8
206
+ },
207
+ "language_model.model.layers.6.mlp.gate.gate": {
208
+ "group_size": 64,
209
+ "bits": 8
210
+ },
211
+ "language_model.model.layers.7.mlp.gate.gate": {
212
+ "group_size": 64,
213
+ "bits": 8
214
+ },
215
+ "language_model.model.layers.8.mlp.gate.gate": {
216
+ "group_size": 64,
217
+ "bits": 8
218
+ },
219
+ "language_model.model.layers.9.mlp.gate.gate": {
220
+ "group_size": 64,
221
+ "bits": 8
222
+ },
223
+ "language_model.model.layers.10.mlp.gate.gate": {
224
+ "group_size": 64,
225
+ "bits": 8
226
+ },
227
+ "language_model.model.layers.11.mlp.gate.gate": {
228
+ "group_size": 64,
229
+ "bits": 8
230
+ },
231
+ "language_model.model.layers.12.mlp.gate.gate": {
232
+ "group_size": 64,
233
+ "bits": 8
234
+ },
235
+ "language_model.model.layers.13.mlp.gate.gate": {
236
+ "group_size": 64,
237
+ "bits": 8
238
+ },
239
+ "language_model.model.layers.14.mlp.gate.gate": {
240
+ "group_size": 64,
241
+ "bits": 8
242
+ },
243
+ "language_model.model.layers.15.mlp.gate.gate": {
244
+ "group_size": 64,
245
+ "bits": 8
246
+ },
247
+ "language_model.model.layers.16.mlp.gate.gate": {
248
+ "group_size": 64,
249
+ "bits": 8
250
+ },
251
+ "language_model.model.layers.17.mlp.gate.gate": {
252
+ "group_size": 64,
253
+ "bits": 8
254
+ },
255
+ "language_model.model.layers.18.mlp.gate.gate": {
256
+ "group_size": 64,
257
+ "bits": 8
258
+ },
259
+ "language_model.model.layers.19.mlp.gate.gate": {
260
+ "group_size": 64,
261
+ "bits": 8
262
+ },
263
+ "language_model.model.layers.20.mlp.gate.gate": {
264
+ "group_size": 64,
265
+ "bits": 8
266
+ },
267
+ "language_model.model.layers.21.mlp.gate.gate": {
268
+ "group_size": 64,
269
+ "bits": 8
270
+ },
271
+ "language_model.model.layers.22.mlp.gate.gate": {
272
+ "group_size": 64,
273
+ "bits": 8
274
+ },
275
+ "language_model.model.layers.23.mlp.gate.gate": {
276
+ "group_size": 64,
277
+ "bits": 8
278
+ },
279
+ "language_model.model.layers.24.mlp.gate.gate": {
280
+ "group_size": 64,
281
+ "bits": 8
282
+ },
283
+ "language_model.model.layers.25.mlp.gate.gate": {
284
+ "group_size": 64,
285
+ "bits": 8
286
+ },
287
+ "language_model.model.layers.26.mlp.gate.gate": {
288
+ "group_size": 64,
289
+ "bits": 8
290
+ },
291
+ "language_model.model.layers.27.mlp.gate.gate": {
292
+ "group_size": 64,
293
+ "bits": 8
294
+ },
295
+ "language_model.model.layers.28.mlp.gate.gate": {
296
+ "group_size": 64,
297
+ "bits": 8
298
+ },
299
+ "language_model.model.layers.29.mlp.gate.gate": {
300
+ "group_size": 64,
301
+ "bits": 8
302
+ },
303
+ "language_model.model.layers.30.mlp.gate.gate": {
304
+ "group_size": 64,
305
+ "bits": 8
306
+ },
307
+ "language_model.model.layers.31.mlp.gate.gate": {
308
+ "group_size": 64,
309
+ "bits": 8
310
+ },
311
+ "language_model.model.layers.32.mlp.gate.gate": {
312
+ "group_size": 64,
313
+ "bits": 8
314
+ },
315
+ "language_model.model.layers.33.mlp.gate.gate": {
316
+ "group_size": 64,
317
+ "bits": 8
318
+ },
319
+ "language_model.model.layers.34.mlp.gate.gate": {
320
+ "group_size": 64,
321
+ "bits": 8
322
+ },
323
+ "language_model.model.layers.35.mlp.gate.gate": {
324
+ "group_size": 64,
325
+ "bits": 8
326
+ },
327
+ "language_model.model.layers.36.mlp.gate.gate": {
328
+ "group_size": 64,
329
+ "bits": 8
330
+ },
331
+ "language_model.model.layers.37.mlp.gate.gate": {
332
+ "group_size": 64,
333
+ "bits": 8
334
+ },
335
+ "language_model.model.layers.38.mlp.gate.gate": {
336
+ "group_size": 64,
337
+ "bits": 8
338
+ },
339
+ "language_model.model.layers.39.mlp.gate.gate": {
340
+ "group_size": 64,
341
+ "bits": 8
342
+ },
343
+ "language_model.model.layers.40.mlp.gate.gate": {
344
+ "group_size": 64,
345
+ "bits": 8
346
+ },
347
+ "language_model.model.layers.41.mlp.gate.gate": {
348
+ "group_size": 64,
349
+ "bits": 8
350
+ },
351
+ "language_model.model.layers.42.mlp.gate.gate": {
352
+ "group_size": 64,
353
+ "bits": 8
354
+ },
355
+ "language_model.model.layers.43.mlp.gate.gate": {
356
+ "group_size": 64,
357
+ "bits": 8
358
+ },
359
+ "language_model.model.layers.44.mlp.gate.gate": {
360
+ "group_size": 64,
361
+ "bits": 8
362
+ }
363
+ },
364
+ "text_config": {
365
+ "architectures": [
366
+ "Step3p5ForCausalLM"
367
+ ],
368
+ "rope_scaling": {
369
+ "rope_type": "llama3",
370
+ "factor": 2.0,
371
+ "original_max_position_embeddings": 131072,
372
+ "low_freq_factor": 1.0,
373
+ "high_freq_factor": 32.0
374
+ },
375
+ "yarn_only_types": [
376
+ "full_attention"
377
+ ],
378
+ "model_type": "step3p5",
379
+ "hidden_size": 4096,
380
+ "intermediate_size": 11264,
381
+ "num_hidden_layers": 45,
382
+ "max_seq_len": 262144,
383
+ "max_position_embeddings": 262144,
384
+ "vocab_size": 128896,
385
+ "torch_dtype": "bfloat16",
386
+ "use_qk_norm": false,
387
+ "moe_layers_enum": "3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44",
388
+ "use_mfa": false,
389
+ "num_attention_heads": 64,
390
+ "num_attention_groups": 8,
391
+ "head_dim": 128,
392
+ "use_moe": true,
393
+ "moe_num_experts": 288,
394
+ "moe_top_k": 8,
395
+ "moe_intermediate_size": 1280,
396
+ "share_expert_dim": 1280,
397
+ "moe_layer_offset": 0,
398
+ "moe_every_n_layer": 1,
399
+ "norm_expert_weight": true,
400
+ "moe_router_activation": "sigmoid",
401
+ "moe_router_scaling_factor": 3.0,
402
+ "att_impl_type": "GQA",
403
+ "num_nextn_predict_layers": 3,
404
+ "rope_theta": [
405
+ 5000000.0,
406
+ 10000.0,
407
+ 10000.0,
408
+ 10000.0,
409
+ 5000000.0,
410
+ 10000.0,
411
+ 10000.0,
412
+ 10000.0,
413
+ 5000000.0,
414
+ 10000.0,
415
+ 10000.0,
416
+ 10000.0,
417
+ 5000000.0,
418
+ 10000.0,
419
+ 10000.0,
420
+ 10000.0,
421
+ 5000000.0,
422
+ 10000.0,
423
+ 10000.0,
424
+ 10000.0,
425
+ 5000000.0,
426
+ 10000.0,
427
+ 10000.0,
428
+ 10000.0,
429
+ 5000000.0,
430
+ 10000.0,
431
+ 10000.0,
432
+ 10000.0,
433
+ 5000000.0,
434
+ 10000.0,
435
+ 10000.0,
436
+ 10000.0,
437
+ 5000000.0,
438
+ 10000.0,
439
+ 10000.0,
440
+ 10000.0,
441
+ 5000000.0,
442
+ 10000.0,
443
+ 10000.0,
444
+ 10000.0,
445
+ 5000000.0,
446
+ 10000.0,
447
+ 10000.0,
448
+ 10000.0,
449
+ 5000000.0,
450
+ 10000.0,
451
+ 10000.0,
452
+ 10000.0
453
+ ],
454
+ "use_head_wise_attn_gate": true,
455
+ "sliding_window": 512,
456
+ "use_moe_router_bias": true,
457
+ "need_fp32_gate": true,
458
+ "sink": false,
459
+ "layer_types": [
460
+ "full_attention",
461
+ "sliding_attention",
462
+ "sliding_attention",
463
+ "sliding_attention",
464
+ "full_attention",
465
+ "sliding_attention",
466
+ "sliding_attention",
467
+ "sliding_attention",
468
+ "full_attention",
469
+ "sliding_attention",
470
+ "sliding_attention",
471
+ "sliding_attention",
472
+ "full_attention",
473
+ "sliding_attention",
474
+ "sliding_attention",
475
+ "sliding_attention",
476
+ "full_attention",
477
+ "sliding_attention",
478
+ "sliding_attention",
479
+ "sliding_attention",
480
+ "full_attention",
481
+ "sliding_attention",
482
+ "sliding_attention",
483
+ "sliding_attention",
484
+ "full_attention",
485
+ "sliding_attention",
486
+ "sliding_attention",
487
+ "sliding_attention",
488
+ "full_attention",
489
+ "sliding_attention",
490
+ "sliding_attention",
491
+ "sliding_attention",
492
+ "full_attention",
493
+ "sliding_attention",
494
+ "sliding_attention",
495
+ "sliding_attention",
496
+ "full_attention",
497
+ "sliding_attention",
498
+ "sliding_attention",
499
+ "sliding_attention",
500
+ "full_attention",
501
+ "sliding_attention",
502
+ "sliding_attention",
503
+ "sliding_attention",
504
+ "full_attention",
505
+ "sliding_attention",
506
+ "sliding_attention",
507
+ "sliding_attention"
508
+ ],
509
+ "use_rope_layers": [],
510
+ "partial_rotary_factors": [
511
+ 0.5,
512
+ 1.0,
513
+ 1.0,
514
+ 1.0,
515
+ 0.5,
516
+ 1.0,
517
+ 1.0,
518
+ 1.0,
519
+ 0.5,
520
+ 1.0,
521
+ 1.0,
522
+ 1.0,
523
+ 0.5,
524
+ 1.0,
525
+ 1.0,
526
+ 1.0,
527
+ 0.5,
528
+ 1.0,
529
+ 1.0,
530
+ 1.0,
531
+ 0.5,
532
+ 1.0,
533
+ 1.0,
534
+ 1.0,
535
+ 0.5,
536
+ 1.0,
537
+ 1.0,
538
+ 1.0,
539
+ 0.5,
540
+ 1.0,
541
+ 1.0,
542
+ 1.0,
543
+ 0.5,
544
+ 1.0,
545
+ 1.0,
546
+ 1.0,
547
+ 0.5,
548
+ 1.0,
549
+ 1.0,
550
+ 1.0,
551
+ 0.5,
552
+ 1.0,
553
+ 1.0,
554
+ 1.0,
555
+ 0.5,
556
+ 1.0,
557
+ 1.0,
558
+ 1.0
559
+ ],
560
+ "eos_token_id": [
561
+ 1,
562
+ 2,
563
+ 128007
564
+ ],
565
+ "bos_token_id": 0,
566
+ "attention_other_setting": {
567
+ "attention_type": "sliding_attention",
568
+ "num_attention_heads": 96,
569
+ "num_attention_groups": 8,
570
+ "head_dim": 128,
571
+ "true_head_dim": 128
572
+ },
573
+ "swiglu_limits": [
574
+ 0.0,
575
+ 0.0,
576
+ 0.0,
577
+ 0.0,
578
+ 0.0,
579
+ 0.0,
580
+ 0.0,
581
+ 0.0,
582
+ 0.0,
583
+ 0.0,
584
+ 0.0,
585
+ 0.0,
586
+ 0.0,
587
+ 0.0,
588
+ 0.0,
589
+ 0.0,
590
+ 0.0,
591
+ 0.0,
592
+ 0.0,
593
+ 0.0,
594
+ 0.0,
595
+ 0.0,
596
+ 0.0,
597
+ 0.0,
598
+ 0.0,
599
+ 0.0,
600
+ 0.0,
601
+ 0.0,
602
+ 0.0,
603
+ 0.0,
604
+ 0.0,
605
+ 0.0,
606
+ 0.0,
607
+ 0.0,
608
+ 0.0,
609
+ 0.0,
610
+ 0.0,
611
+ 0.0,
612
+ 0.0,
613
+ 0.0,
614
+ 0.0,
615
+ 0.0,
616
+ 0.0,
617
+ 7,
618
+ 7,
619
+ 0.0,
620
+ 0.0,
621
+ 0.0
622
+ ],
623
+ "swiglu_limits_shared": [
624
+ 0.0,
625
+ 0.0,
626
+ 0.0,
627
+ 0.0,
628
+ 0.0,
629
+ 0.0,
630
+ 0.0,
631
+ 0.0,
632
+ 0.0,
633
+ 0.0,
634
+ 0.0,
635
+ 0.0,
636
+ 0.0,
637
+ 0.0,
638
+ 0.0,
639
+ 0.0,
640
+ 0.0,
641
+ 0.0,
642
+ 0.0,
643
+ 0.0,
644
+ 0.0,
645
+ 0.0,
646
+ 0.0,
647
+ 0.0,
648
+ 0.0,
649
+ 0.0,
650
+ 0.0,
651
+ 0.0,
652
+ 0.0,
653
+ 0.0,
654
+ 0.0,
655
+ 0.0,
656
+ 0.0,
657
+ 0.0,
658
+ 0.0,
659
+ 0.0,
660
+ 0.0,
661
+ 0.0,
662
+ 0.0,
663
+ 0.0,
664
+ 0.0,
665
+ 0.0,
666
+ 0.0,
667
+ 16,
668
+ 16,
669
+ 0.0,
670
+ 0.0,
671
+ 0.0
672
+ ]
673
+ },
674
+ "understand_projector_stride": 2,
675
+ "use_im_start_end": "true",
676
+ "vision_select_layer": -1
677
+ }
configuration_step3p7.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Sequence, Union
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+ class StepRoboticsVisionEncoderConfig(PretrainedConfig):
6
+ model_type = "perception_encoder"
7
+
8
+ def __init__(
9
+ self,
10
+ width=1536,
11
+ layers=47,
12
+ heads=16,
13
+ num_channels=3,
14
+ image_size=728,
15
+ mlp_ratio = 8960/1536,
16
+ patch_size=14,
17
+ hidden_act="quick_gelu",
18
+ layer_norm_eps=1e-5,
19
+ ues_cls_token=False,
20
+ use_cls_token: Optional[bool] = None,
21
+ use_ln_pre=True,
22
+ use_ln_post=False,
23
+ use_abs_posemb=True,
24
+ use_rope2d=True,
25
+ ls_init_value=0.1,
26
+ **kwargs,
27
+ ):
28
+ self.width = width
29
+ self.layers = layers
30
+ self.heads = heads
31
+ self.num_channels = num_channels
32
+ self.patch_size = patch_size
33
+ self.image_size = image_size
34
+ self.mlp_ratio = mlp_ratio
35
+ self.layer_norm_eps = layer_norm_eps
36
+ self.hidden_act = hidden_act
37
+ if use_cls_token is None:
38
+ use_cls_token = ues_cls_token
39
+ self.ues_cls_token = use_cls_token
40
+ self.use_cls_token = use_cls_token
41
+ self.use_ln_pre = use_ln_pre
42
+ self.ls_init_value = ls_init_value
43
+ self.use_ln_post = use_ln_post
44
+ self.use_abs_posemb = use_abs_posemb
45
+ self.use_rope2d = use_rope2d
46
+ super().__init__(**kwargs)
47
+
48
+
49
+ class Step3p7TextConfig(PretrainedConfig):
50
+ model_type = "step3p5"
51
+ architectures = ["Step3p5ForCausalLM"]
52
+
53
+ def __init__(
54
+ self,
55
+ hidden_size: int = 4096,
56
+ intermediate_size: int = 11264,
57
+ num_attention_heads: int = 64,
58
+ num_attention_groups: int = 8,
59
+ num_hidden_layers: int = 45,
60
+ max_seq_len: int = 128000,
61
+ vocab_size: int = 128815,
62
+ rms_norm_eps: float = 1e-5,
63
+ moe_intermediate_size: int = 1280,
64
+ moe_num_experts: int = 288,
65
+ moe_top_k: int = 8,
66
+ rope_theta: float = 10000,
67
+ rope_scaling: Optional[dict[str, Any]] = None,
68
+ max_position_embeddings: int = 128000,
69
+ share_expert_dims: int = 1280,
70
+ share_expert_dim: Optional[int] = None,
71
+ head_dim: int = 128,
72
+ norm_expert_weight: bool = True,
73
+ layer_types: list[str] = None,
74
+ sliding_window: Optional[int] = None,
75
+ pad_token_id: int = 1,
76
+ attention_dropout: float = 0.0,
77
+ use_head_wise_attn_gate: bool = False,
78
+ use_moe_router_bias: bool = False,
79
+ moe_router_activation: str = "softmax",
80
+ moe_router_scaling_factor: float = 1.0,
81
+ need_fp32_gate: bool = False,
82
+ attention_other_setting: Optional[dict[str, Any]] = None,
83
+ swiglu_limits: Optional[list[Optional[float]]] = None,
84
+ swiglu_limits_shared: Optional[list[Optional[float]]] = None,
85
+ use_rope_layers: Optional[list[bool]] = None,
86
+ yarn_only_types: Optional[list[str]] = None,
87
+ moe_layers_enum: tuple[int] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
88
+ 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
89
+ 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
90
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
91
+ **kwargs,
92
+ ) -> None:
93
+ torch_dtype = kwargs.get("torch_dtype")
94
+ trim_layer_types = _normalize_per_layer_values(layer_types,
95
+ num_hidden_layers)
96
+ if isinstance(rope_scaling, dict):
97
+ rope_scaling = dict(rope_scaling)
98
+ if share_expert_dim is None:
99
+ share_expert_dim = share_expert_dims
100
+ self.hidden_size = hidden_size
101
+ self.intermediate_size = intermediate_size
102
+ self.num_attention_heads = num_attention_heads
103
+ self.num_attention_groups = num_attention_groups
104
+ self.num_hidden_layers = num_hidden_layers
105
+ self.max_seq_len = max_seq_len
106
+ self.vocab_size = vocab_size
107
+ self.rms_norm_eps = rms_norm_eps
108
+ self.moe_intermediate_size = moe_intermediate_size
109
+ self.moe_num_experts = moe_num_experts
110
+ self.moe_top_k = moe_top_k
111
+ self.rope_theta = rope_theta
112
+ self.rope_scaling = rope_scaling
113
+ self.max_position_embeddings = max_position_embeddings
114
+ self.share_expert_dim = share_expert_dim
115
+ self.head_dim = head_dim
116
+ self.norm_expert_weight = norm_expert_weight
117
+ self.moe_layers_enum = moe_layers_enum
118
+ self.layer_types = trim_layer_types
119
+ self.sliding_window = sliding_window
120
+ self.pad_token_id = pad_token_id
121
+ self.attention_dropout = attention_dropout
122
+ self.use_head_wise_attn_gate = use_head_wise_attn_gate
123
+ self.use_moe_router_bias = use_moe_router_bias
124
+ self.moe_router_activation = moe_router_activation
125
+ self.moe_router_scaling_factor = moe_router_scaling_factor
126
+ self.need_fp32_gate = need_fp32_gate
127
+ self.attention_other_setting = attention_other_setting
128
+ self.swiglu_limits = swiglu_limits
129
+ self.swiglu_limits_shared = swiglu_limits_shared
130
+ self.use_rope_layers = use_rope_layers
131
+ self.yarn_only_types = yarn_only_types
132
+ super().__init__(**kwargs)
133
+ if torch_dtype is not None:
134
+ self.torch_dtype = torch_dtype
135
+ self.layer_types = layer_types
136
+
137
+ def to_dict(self):
138
+ output = super().to_dict()
139
+ torch_dtype = getattr(self, "torch_dtype", None)
140
+ if torch_dtype is not None:
141
+ output["torch_dtype"] = torch_dtype
142
+ return output
143
+
144
+
145
+ def _normalize_per_layer_values(
146
+ values: Optional[Sequence[Any]],
147
+ num_hidden_layers: int,
148
+ ) -> Optional[list[Any]]:
149
+ if values is None:
150
+ return None
151
+ normalized = list(values)
152
+ if not normalized:
153
+ return normalized
154
+ if len(normalized) < num_hidden_layers:
155
+ normalized.extend([normalized[-1]] *
156
+ (num_hidden_layers - len(normalized)))
157
+ # Some checkpoints keep MTP/spec layer entries after the decoder layers.
158
+ # This config only builds num_hidden_layers decoder layers, and HF strict
159
+ # validation requires per-layer fields to match that decoder count.
160
+ return normalized[:num_hidden_layers]
161
+
162
+ class Step3p7Config(PretrainedConfig):
163
+ # This loader is a compatibility shim for original Step VL checkpoints
164
+ # whose top-level config model_type is `step3p7`.
165
+ model_type = "step3p7"
166
+
167
+ def __init__(
168
+ self,
169
+ vision_config: Optional[Union[dict, StepRoboticsVisionEncoderConfig]] = None,
170
+ text_config: Optional[Union[dict, Step3p7TextConfig]] = None,
171
+ understand_projector_stride: int = 2,
172
+ projector_bias: bool = False,
173
+ image_token_id: int = 151679,
174
+ **kwargs,
175
+ ) -> None:
176
+ shared_rope_scaling = kwargs.get("rope_scaling")
177
+ if isinstance(shared_rope_scaling, dict):
178
+ shared_rope_scaling = dict(shared_rope_scaling)
179
+
180
+ if vision_config is None:
181
+ vision_config = StepRoboticsVisionEncoderConfig()
182
+ elif isinstance(vision_config, dict):
183
+ vision_config = StepRoboticsVisionEncoderConfig(**vision_config)
184
+ self.vision_config = vision_config
185
+
186
+ if text_config is None:
187
+ text_config = Step3p7TextConfig(rope_scaling=shared_rope_scaling)
188
+ elif isinstance(text_config, dict):
189
+ text_config = dict(text_config)
190
+ if shared_rope_scaling is not None and "rope_scaling" not in text_config:
191
+ text_config["rope_scaling"] = shared_rope_scaling
192
+ text_config = Step3p7TextConfig(**text_config)
193
+ elif shared_rope_scaling is not None and text_config.rope_scaling is None:
194
+ text_config.rope_scaling = dict(shared_rope_scaling)
195
+ self.text_config = text_config
196
+
197
+ rope_scaling = kwargs.get("rope_scaling")
198
+ if isinstance(rope_scaling, dict):
199
+ kwargs["rope_scaling"] = dict(rope_scaling)
200
+
201
+ self.understand_projector_stride = understand_projector_stride
202
+ self.projector_bias = projector_bias
203
+ self.hidden_size = text_config.hidden_size
204
+ self.max_position_embeddings = text_config.max_position_embeddings
205
+ self.image_token_id = image_token_id
206
+ # Help Auto classes find the correct implementation when saving/loading.
207
+ super().__init__(**kwargs)
model-00001-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4f76cf00710f7bc94fa22f35b8c0f02b0629a7cdf4a4be57b2961a885633455
3
+ size 4641457723
model-00002-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11935fcdf99e3801534c21ed384465ecf1ff1aa3fd25258293fd5f790c04ee98
3
+ size 4911447063
model-00003-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1777c6fc299f32eb73c4bd0a57fa4662490f1e22a24477404cc2a0c4f41c75b6
3
+ size 4947237906
model-00004-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b75c3605604da29ee564ea9f0f9b9411716c286aa4c615aa611adbf421dd1ed
3
+ size 4947237914
model-00005-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b713ede9be5a34d1ed04c5879d780c164c35890ecb290b37ea08ebbcde87ad0
3
+ size 4947237894
model-00006-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b018696af7e4979971e707fe661f8f5de1efb92af2b0a07fefb727acfc85831
3
+ size 4911447063
model-00007-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e14c8afa89f67fa932aba497d02811cc9ff5144e9641d4f6eb0ae193804f013
3
+ size 4947237910
model-00008-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b67fcbd2b5f39908ad321afa370e564bacc0652b306fad4bb67197283e37f9f
3
+ size 4947237941
model-00009-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22f41ab3d20ad70eb118e389005e329afa95942efeb63abf75197d9c833ff5af
3
+ size 4947237947
model-00011-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:098818ce9abcb4ae50fc15cf7751ae900d98f29b9e1a248bad920a072394dfad
3
+ size 4947237955
model-00012-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a029a7a29122f7e74e56a98453fc4bb99ae14c93955d82f85df93966a0e8696d
3
+ size 4947237951
model-00013-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ad6a505bee54199514b7b45893c6c0cc0aa62b07fae0225bdbd2c11b9daedd1
3
+ size 4947237951
model-00014-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:549d437e3fc9f6e3c88d99efc37dc00fa8b789ab380701ed7d26e2fd508316aa
3
+ size 4911447100
model-00016-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1cc69b57e34c51137d19ea9e06497c52ca4abc29ed0bfdf4d5d9dc360a61013
3
+ size 4947237931
model-00017-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:455b4e547fd1011e9f68a6e30788099e9fbeb00efa0979563432c1063ff91db9
3
+ size 4947237951
model-00018-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aec432b46a946bb5d09d4fbfa6b5487ef00b489183392c2960891b83bff774c7
3
+ size 4911447104
model-00019-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f9d0a735291662a45e934df5cd08ddd6409765fbef559d7d02d502011452cbe
3
+ size 4947237889
model-00020-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:385be52b69b548947cbdaed618390e4bbcfbca95819d20254d4bc168b1a30d13
3
+ size 4947237935
model-00021-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e804766c1b0277a723001cd81a02d8f7cb996fe7f5ddc6d737c58c717fe736b2
3
+ size 4947237951
model-00022-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d426927ed7d481a789d9987a4509290429ea06af573f4e56666e17744bd2f5c2
3
+ size 4911447100
model-00023-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa87e502abd510931ee189faee65a369d775f5a2cfe0ea0af88840fe0ab15cb7
3
+ size 4947237951
model-00024-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f07d666ba5abf6e58091e0b7df6ac5144fca51a37952078888f8da5cde82e5ea
3
+ size 4947237937
model-00025-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a40581350005a6de1b9f87b83900fe51654ae02b4179e6ced23eed79f021d5cd
3
+ size 4947237955
model-00026-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b405f01f0d66696b7b27b301b853b9fd93d6c8208b134c21e693baf465eaf7e1
3
+ size 4911447096
model-00027-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8a233abaca4172de2f4e6854e63b5c925313e3d6595219e8b6fac35e80ef558
3
+ size 4947237951
model-00029-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d886c2fe5317e315aca114f7afb680959eba9554077c5cdd2b4949aa0e69ac1
3
+ size 4947237891
model-00030-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:666e5ddc4f5ec7c06060d5eb5200da12a4b9a0f6a3d230e948dd1bf0315a8c05
3
+ size 4911447096
model-00031-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99deea271c6332765dfff9038c43c5424855ef558bf7e28fc44ecf81a761393f
3
+ size 4947237951
model-00032-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b83b1e8f2b534a6178f3a091d72b020928f622dc79cd1a6a7a4e27fd28e57374
3
+ size 4947237947
model-00033-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1473a8308b19decb10239436fdbbd9ee68a3f171a582a2ed10f4c6838721122b
3
+ size 4947237947
model-00034-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96710cd297545249e6295fa1facecd0eedfc6b23b5cfea25292ddc3f65414422
3
+ size 4911447100
model-00035-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb8e15c926bca2aa15288e8fd9fe77348e25890ed69684dd69330f7e6c80ceea
3
+ size 4947237951
model-00036-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97f7133d01017beddb2451b8905ef21f9716607eda2c5e664989f7b82c63aa65
3
+ size 4947237909
model-00037-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7742ab540626ef393055e0b61a74f86c0342a37e527c1d1fc73295908e23f43e
3
+ size 4947237879
model-00038-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fbc10cca62227b9af89b9e8fb817da7777e6abf9c68cc08feb8742f5af4586a
3
+ size 4911447092
model-00039-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f89c343d6aaa8bf5f6b229ea1a72108e055c7ee1f2f6edf99996bf11e04e35a0
3
+ size 4947237897
model-00040-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:028d0ce4d04ed883844e1821f10f11acd104b61e9db070d8f538513a50288846
3
+ size 4947237933
model-00041-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6296b2954c4d38fdebd0df64a4a1e38be34d08b62f694e67734771bac151784
3
+ size 4947237947
model-00042-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f45f6e3f0bfc8fd6ce761f7246dd1cf943b393e61bca5fd877cf6730216fd141
3
+ size 4911447092
model-00043-of-00043.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36fe1bd2eca255d6991989538a33e99d78a44e8c18e80574ee89fa01c9cb9233
3
+ size 2182015326
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_step3p7.py ADDED
@@ -0,0 +1,1405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import copy
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Callable, Literal, Optional, Tuple, TypedDict, Union
19
+
20
+ from PIL import Image
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from transformers.activations import ACT2FN
26
+ from transformers.cache_utils import Cache, DynamicCache
27
+ from transformers.generation import GenerationMixin
28
+ from transformers.masking_utils import (
29
+ create_causal_mask,
30
+ create_sliding_window_causal_mask,
31
+ )
32
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
33
+ from transformers.modeling_layers import GradientCheckpointingLayer
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
35
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
36
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
+ from transformers.processing_utils import Unpack
38
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
39
+ from .configuration_step3p7 import Step3p7Config, Step3p7TextConfig
40
+ from .vision_encoder import StepRoboticsVisionEncoder
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+ _MASK_INPUT_EMBEDS_ARG = (
45
+ "inputs_embeds"
46
+ if "inputs_embeds" in inspect.signature(create_causal_mask).parameters
47
+ else "input_embeds"
48
+ )
49
+
50
+ __all__ = [
51
+ "Step3p7Model",
52
+ ]
53
+
54
+
55
+ class StepVLImagePixelInputs(TypedDict):
56
+ type: Literal["pixel_values"]
57
+ pixel_values: torch.Tensor
58
+ patch_pixel_values: Optional[torch.Tensor]
59
+ num_patches: list[int]
60
+
61
+
62
+ class StepVLImageEmbeddingInputs(TypedDict):
63
+ type: Literal["image_embeds"]
64
+ image_embeds: torch.Tensor
65
+
66
+
67
+ StepVLImageInputs = Union[StepVLImagePixelInputs, StepVLImageEmbeddingInputs]
68
+
69
+
70
+ def _flatten_embeddings(embeddings) -> torch.Tensor:
71
+ """
72
+ Recursively flattens and concatenates NestedTensors on all but the last
73
+ dimension.
74
+ """
75
+
76
+ if isinstance(embeddings, torch.Tensor):
77
+ # Flatten all but the last dimension.
78
+ return embeddings.flatten(0, -2)
79
+
80
+ return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
81
+
82
+ def _embedding_count_expression(embeddings) -> str:
83
+ """
84
+ Constructs a debugging representation of the number of embeddings in the
85
+ NestedTensors.
86
+ """
87
+
88
+ if isinstance(embeddings, torch.Tensor):
89
+ return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
90
+
91
+ return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
92
+
93
+
94
+ def _merge_multimodal_embeddings(
95
+ inputs_embeds: torch.Tensor,
96
+ is_multimodal: torch.Tensor,
97
+ multimodal_embeddings,
98
+ ) -> torch.Tensor:
99
+ """
100
+ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
101
+ positions in ``inputs_embeds`` corresponding to placeholder tokens in
102
+ ``input_ids``.
103
+ Note:
104
+ This updates ``inputs_embeds`` in place.
105
+ """
106
+ num_expected_tokens = is_multimodal.sum().item()
107
+ assert isinstance(num_expected_tokens, int)
108
+
109
+ flattened = _flatten_embeddings(multimodal_embeddings)
110
+ if flattened.shape[0] != num_expected_tokens:
111
+ expr = _embedding_count_expression(multimodal_embeddings)
112
+ raise ValueError(
113
+ f"Attempted to assign {expr} = {flattened.shape[0]} "
114
+ f"multimodal tokens to {num_expected_tokens} placeholders"
115
+ )
116
+
117
+ is_multimodal = is_multimodal.to(inputs_embeds.device)
118
+ flattened = flattened.to(inputs_embeds.device)
119
+ inputs_embeds[is_multimodal] = flattened
120
+ return inputs_embeds
121
+
122
+ def merge_multimodal_embeddings(
123
+ input_ids: torch.Tensor,
124
+ inputs_embeds: torch.Tensor,
125
+ multimodal_embeddings,
126
+ placeholder_token_id: Union[int, list[int]],
127
+ ) -> torch.Tensor:
128
+ """
129
+ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
130
+ positions in ``inputs_embeds`` corresponding to placeholder tokens in
131
+ ``input_ids``.
132
+
133
+ ``placeholder_token_id`` can be a list of token ids (e.g, token ids
134
+ of img_start, img_break, and img_end tokens) when needed: This means
135
+ the order of these tokens in the ``input_ids`` MUST MATCH the order of
136
+ their embeddings in ``multimodal_embeddings`` since we need to
137
+ slice-merge instead of individually scattering.
138
+ For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
139
+ - T is text token
140
+ - S is image start token
141
+ - I is image embedding token
142
+ - B is image break token
143
+ - E is image end token.
144
+
145
+ Then the image embeddings (that correspond to I's) from vision encoder
146
+ must be padded with embeddings of S, B, and E in the same order of
147
+ input_ids for a correct embedding merge.
148
+ Note:
149
+ This updates ``inputs_embeds`` in place.
150
+ """
151
+ if isinstance(placeholder_token_id, list):
152
+ placeholder_token_id = torch.tensor(
153
+ placeholder_token_id, device=input_ids.device
154
+ )
155
+ return _merge_multimodal_embeddings(
156
+ inputs_embeds,
157
+ torch.isin(input_ids, placeholder_token_id),
158
+ multimodal_embeddings,
159
+ )
160
+
161
+ return _merge_multimodal_embeddings(
162
+ inputs_embeds,
163
+ (input_ids == placeholder_token_id),
164
+ multimodal_embeddings,
165
+ )
166
+
167
+
168
+ class Step3p7PreTrainedModel(PreTrainedModel):
169
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
170
+ # can load the config instead of failing with a NoneType error.
171
+ config_class = Step3p7Config
172
+ supports_gradient_checkpointing = True
173
+ _skip_keys_device_placement = ["past_key_values"]
174
+ _keys_to_ignore_on_load_unexpected = [
175
+ r"model\.layers\.45\.*",
176
+ r"model\.layers\.46\.*",
177
+ r"model\.layers\.47\.*",
178
+ ]
179
+ _supports_flash_attn = False
180
+ _supports_sdpa = True
181
+ _supports_flex_attn = True
182
+ _supports_static_cache = True
183
+ _supports_attention_backend = True
184
+
185
+ @classmethod
186
+ def from_pretrained(
187
+ cls, pretrained_model_name_or_path, *model_args, **kwargs
188
+ ):
189
+ key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
190
+ if key_mapping is not None and kwargs.get("key_mapping") is None:
191
+ # Transformers only applies checkpoint renaming when key_mapping is
192
+ # passed explicitly; inheriting the class attribute alone is not enough.
193
+ kwargs["key_mapping"] = copy.deepcopy(key_mapping)
194
+ return super().from_pretrained(
195
+ pretrained_model_name_or_path, *model_args, **kwargs
196
+ )
197
+
198
+
199
+ class Step3p7RotaryEmbedding(nn.Module):
200
+ def __init__(self, config: Step3p7TextConfig, device=None, layer_idx=None):
201
+ super().__init__()
202
+ self.layer_idx = layer_idx
203
+ self.max_seq_len_cached = config.max_position_embeddings
204
+ self.original_max_seq_len = config.max_position_embeddings
205
+
206
+ rope_theta = config.rope_theta
207
+ if isinstance(rope_theta, list):
208
+ rope_theta = rope_theta[0 if layer_idx is None else layer_idx]
209
+
210
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
211
+ partial_rotary_factors = getattr(config, "partial_rotary_factors", None)
212
+ if partial_rotary_factors is not None:
213
+ partial_rotary_factor = partial_rotary_factors[
214
+ 0 if layer_idx is None else layer_idx
215
+ ]
216
+
217
+ self.rope_theta = rope_theta
218
+ self.partial_rotary_factor = partial_rotary_factor
219
+
220
+ self.config = copy.copy(config)
221
+ self.config.rope_theta = rope_theta
222
+ self.config.partial_rotary_factor = partial_rotary_factor
223
+
224
+ if config.rope_parameters is not None:
225
+ self.config.rope_parameters = copy.deepcopy(config.rope_parameters)
226
+ self.config.rope_parameters["rope_theta"] = rope_theta
227
+ self.config.rope_parameters["partial_rotary_factor"] = (
228
+ partial_rotary_factor
229
+ )
230
+ self.rope_type = self.config.rope_parameters.get(
231
+ "rope_type", self.config.rope_parameters.get("type")
232
+ )
233
+ else:
234
+ self.rope_type = "default"
235
+
236
+ self.rope_init_fn = self.compute_default_rope_parameters
237
+ if self.rope_type != "default":
238
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
239
+ inv_freq, self.attention_scaling = self.rope_init_fn(
240
+ self.config, device
241
+ )
242
+
243
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
244
+ self.original_inv_freq = self.inv_freq
245
+
246
+ @torch.no_grad()
247
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
248
+ def forward(self, x, position_ids):
249
+ inv_freq_expanded = (
250
+ self.inv_freq[None, :, None]
251
+ .float()
252
+ .expand(position_ids.shape[0], -1, 1)
253
+ .to(x.device)
254
+ )
255
+ position_ids_expanded = position_ids[:, None, :].float().to(x.device)
256
+
257
+ device_type = (
258
+ x.device.type
259
+ if isinstance(x.device.type, str) and x.device.type != "mps"
260
+ else "cpu"
261
+ )
262
+ with torch.autocast(
263
+ device_type=device_type, enabled=False
264
+ ): # Force float32
265
+ freqs = (
266
+ inv_freq_expanded.float() @ position_ids_expanded.float()
267
+ ).transpose(1, 2)
268
+ emb = torch.cat((freqs, freqs), dim=-1)
269
+ cos = emb.cos() * self.attention_scaling
270
+ sin = emb.sin() * self.attention_scaling
271
+
272
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
273
+
274
+ @staticmethod
275
+ def compute_default_rope_parameters(
276
+ config: Step3p7TextConfig | None = None,
277
+ device: Optional["torch.device"] = None,
278
+ ) -> tuple["torch.Tensor", float]:
279
+ """
280
+ Computes the inverse frequencies according to the original RoPE implementation
281
+ Args:
282
+ config ([`~transformers.PreTrainedConfig`]):
283
+ The model configuration.
284
+ device (`torch.device`):
285
+ The device to use for initialization of the inverse frequencies.
286
+ seq_len (`int`, *optional*):
287
+ The current sequence length. Unused for this type of RoPE.
288
+ Returns:
289
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
290
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
291
+ """
292
+ base = config.rope_theta
293
+ partial_rotary_factor = getattr(
294
+ config, "partial_rotary_factor", 1.0
295
+ )
296
+ head_dim = (
297
+ getattr(config, "head_dim", None)
298
+ or config.hidden_size // config.num_attention_heads
299
+ )
300
+ dim = int(head_dim * partial_rotary_factor)
301
+
302
+ attention_factor = 1.0 # Unused in this type of RoPE
303
+
304
+ # Compute the inverse frequencies
305
+ inv_freq = 1.0 / (
306
+ base
307
+ ** (
308
+ torch.arange(0, dim, 2, dtype=torch.int64).to(
309
+ device=device, dtype=torch.float
310
+ )
311
+ / dim
312
+ )
313
+ )
314
+ return inv_freq, attention_factor
315
+
316
+ def rotate_half(x):
317
+ """Rotates half the hidden dims of the input."""
318
+ x1 = x[..., :x.shape[-1] // 2]
319
+ x2 = x[..., x.shape[-1] // 2:]
320
+ return torch.cat((-x2, x1), dim=-1)
321
+
322
+
323
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
324
+ """Applies Rotary Position Embedding to the query and key tensors.
325
+
326
+ Args:
327
+ q (`torch.Tensor`): The query tensor.
328
+ k (`torch.Tensor`): The key tensor.
329
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
330
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
331
+ position_ids (`torch.Tensor`, *optional*):
332
+ Deprecated and unused.
333
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
334
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
335
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
336
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
337
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
338
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
339
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
340
+ Returns:
341
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
342
+ """
343
+ rotary_dim = cos.shape[-1]
344
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
345
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
346
+
347
+ # Apply rotary embeddings on the first half or full tensor
348
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
349
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
350
+
351
+ # Concatenate back to full shape
352
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
353
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
354
+ return q_embed, k_embed
355
+
356
+
357
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
358
+ """
359
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
360
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
361
+ """
362
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
363
+ if n_rep == 1:
364
+ return hidden_states
365
+ hidden_states = hidden_states[:, :, None, :, :].expand(
366
+ batch, num_key_value_heads, n_rep, slen, head_dim
367
+ )
368
+ return hidden_states.reshape(
369
+ batch, num_key_value_heads * n_rep, slen, head_dim
370
+ )
371
+
372
+
373
+ # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward.
374
+ # Llama4 does not cast attention weights to fp32 here.
375
+ def eager_attention_forward(
376
+ module: nn.Module,
377
+ query: torch.Tensor,
378
+ key: torch.Tensor,
379
+ value: torch.Tensor,
380
+ attention_mask: Optional[torch.Tensor],
381
+ scaling: float,
382
+ dropout: float = 0.0,
383
+ **kwargs,
384
+ ):
385
+ key_states = repeat_kv(key, module.num_key_value_groups)
386
+ value_states = repeat_kv(value, module.num_key_value_groups)
387
+ # breakpoint()
388
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
389
+ if attention_mask is not None:
390
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
391
+ attn_weights = attn_weights + causal_mask
392
+
393
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
394
+ attn_weights = nn.functional.dropout(
395
+ attn_weights, p=dropout, training=module.training
396
+ )
397
+ attn_output = torch.matmul(attn_weights, value_states)
398
+ attn_output = attn_output.transpose(1, 2).contiguous()
399
+
400
+ return attn_output, attn_weights
401
+
402
+
403
+ @dataclass
404
+ class Step3p7CausalLMOutputWithPast(ModelOutput):
405
+ r"""
406
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
407
+ Language modeling loss (for next-token prediction).
408
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
409
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
410
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
411
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
412
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
413
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
414
+ `past_key_values` input) to speed up sequential decoding.
415
+ """
416
+
417
+ loss: Optional[torch.FloatTensor] = None
418
+ last_hidden_state: Optional[torch.FloatTensor] = None
419
+ logits: torch.FloatTensor = None
420
+ past_key_values: Optional[list[torch.FloatTensor]] = None
421
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
422
+ attentions: Optional[tuple[torch.FloatTensor]] = None
423
+
424
+
425
+ class Step3p7MLP(nn.Module):
426
+ def __init__(self, config, intermediate_size=None, swiglu_limit=None):
427
+ super().__init__()
428
+ self.config = config
429
+ self.hidden_size = config.hidden_size
430
+ self.intermediate_size = (
431
+ intermediate_size
432
+ if intermediate_size is not None
433
+ else config.intermediate_size
434
+ )
435
+ self.gate_proj = nn.Linear(self.hidden_size,
436
+ self.intermediate_size,
437
+ bias=False)
438
+ self.up_proj = nn.Linear(self.hidden_size,
439
+ self.intermediate_size,
440
+ bias=False)
441
+ self.down_proj = nn.Linear(self.intermediate_size,
442
+ self.hidden_size,
443
+ bias=False)
444
+ self.act_fn = ACT2FN["silu"]
445
+ self.limit = swiglu_limit
446
+
447
+ def forward(self, x):
448
+ up = self.up_proj(x)
449
+ gate = self.act_fn(self.gate_proj(x))
450
+ if self.limit is not None:
451
+ gate = gate.clamp(min=None, max=self.limit)
452
+ up = up.clamp(min=-self.limit, max=self.limit)
453
+
454
+ return self.down_proj(gate * up)
455
+
456
+
457
+ def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
458
+ renormalize: bool):
459
+ gating_output = gating_output.float()
460
+ gate_prob = torch.sigmoid(gating_output)
461
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
462
+ topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
463
+ expert_topk_weight = topk_prob
464
+ if renormalize:
465
+ expert_topk_weight = expert_topk_weight / torch.sum(
466
+ expert_topk_weight, dim=-1, keepdim=True)
467
+ return expert_topk_weight, indices
468
+
469
+
470
+ def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
471
+ renormalize: bool):
472
+ gating_output = gating_output.float()
473
+ gate_prob = torch.softmax(gating_output, dim=-1)
474
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
475
+ topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
476
+ expert_topk_weight = topk_prob
477
+ if renormalize:
478
+ expert_topk_weight = expert_topk_weight / torch.sum(
479
+ expert_topk_weight, dim=-1, keepdim=True)
480
+ return expert_topk_weight, indices.to(torch.int32)
481
+
482
+
483
+ class MoELinear(nn.Module):
484
+
485
+ def __init__(self, num_experts, in_features, out_features):
486
+ super().__init__()
487
+ self.num_experts = num_experts
488
+ self.in_features = in_features
489
+ self.out_features = out_features
490
+ self.weight = nn.Parameter(
491
+ torch.empty(num_experts, out_features, in_features))
492
+
493
+ def forward(self, x, expert_id):
494
+ x = F.linear(x.float(), self.weight[expert_id].float())
495
+ return x
496
+
497
+
498
+ class Step3p7MoEMLP(nn.Module):
499
+
500
+ def __init__(self, config, swiglu_limit=None):
501
+ super().__init__()
502
+ self.num_experts = config.moe_num_experts
503
+ self.top_k = config.moe_top_k
504
+ self.hidden_size = config.hidden_size
505
+ self.moe_intermediate_size = config.moe_intermediate_size
506
+
507
+ self.use_moe_router_bias = config.use_moe_router_bias
508
+ if self.use_moe_router_bias:
509
+ self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
510
+ dtype=torch.float32),
511
+ requires_grad=False)
512
+ self.custom_routing_function = self.router_bias_func
513
+ elif config.moe_router_activation == "sigmoid":
514
+ self.custom_routing_function = sigmoid_routing_function
515
+ else:
516
+ self.custom_routing_function = None
517
+ self.need_fp32_gate = config.need_fp32_gate
518
+ self.routed_scaling_factor = getattr(config,
519
+ "moe_router_scaling_factor", 1.0)
520
+
521
+ # gating
522
+ self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
523
+
524
+ self.act_fn = ACT2FN["silu"]
525
+ self.limit = swiglu_limit
526
+
527
+ self.up_proj = MoELinear(self.num_experts, self.hidden_size,
528
+ self.moe_intermediate_size)
529
+ self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
530
+ self.moe_intermediate_size)
531
+ self.down_proj = MoELinear(self.num_experts,
532
+ self.moe_intermediate_size,
533
+ self.hidden_size)
534
+
535
+ def router_bias_func(self, gating_output: torch.Tensor, topk: int,
536
+ renormalize: bool):
537
+ gate_prob = torch.sigmoid(gating_output.float())
538
+ gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
539
+ _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
540
+ topk_prob = torch.gather(gate_prob, 1, indices)
541
+ expert_topk_weight = topk_prob
542
+ if renormalize:
543
+ expert_topk_weight = expert_topk_weight / (
544
+ torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
545
+ return expert_topk_weight, indices
546
+
547
+ def get_expert_output(self, inputs: torch.Tensor, expert_id):
548
+ #if self.limit is None:
549
+ up = self.up_proj(inputs, expert_id)
550
+ gate = self.act_fn(self.gate_proj(inputs, expert_id))
551
+ if self.limit is not None:
552
+ gate = gate.clamp(min=None, max=self.limit)
553
+ up = up.clamp(min=-self.limit, max=self.limit)
554
+
555
+ return self.down_proj(gate * up, expert_id)
556
+
557
+ def forward(self, hidden_states):
558
+ """ """
559
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
560
+ hidden_states = hidden_states.view(-1, hidden_dim)
561
+ if self.need_fp32_gate:
562
+ router_logits = torch.matmul(
563
+ hidden_states.to(torch.float32),
564
+ self.gate.weight.t().to(torch.float32),
565
+ )
566
+ else:
567
+ # router_logits: (batch * sequence_length, n_experts)
568
+ router_logits = self.gate(hidden_states)
569
+
570
+ if self.custom_routing_function:
571
+ routing_weights, selected_experts = self.custom_routing_function(
572
+ router_logits, self.top_k, renormalize=True)
573
+ else:
574
+ routing_weights = F.softmax(router_logits,
575
+ dim=1,
576
+ dtype=torch.float)
577
+ routing_weights, selected_experts = torch.topk(routing_weights,
578
+ self.top_k,
579
+ dim=-1)
580
+
581
+ routing_weights = routing_weights * self.routed_scaling_factor
582
+
583
+ final_hidden_states = torch.zeros(
584
+ (batch_size * sequence_length, hidden_dim),
585
+ dtype=hidden_states.dtype,
586
+ device=hidden_states.device)
587
+
588
+ # One hot encode the selected experts to create an expert mask
589
+ # this will be used to easily index which expert is going to be sollicitated
590
+ expert_mask = torch.nn.functional.one_hot(
591
+ selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
592
+
593
+ # Loop over all available experts in the model and perform the computation on each expert
594
+ for expert_idx in range(self.num_experts):
595
+ idx, top_x = torch.where(expert_mask[expert_idx])
596
+
597
+ # Index the correct hidden states and compute the expert hidden state for
598
+ # the current expert. We need to make sure to multiply the output hidden
599
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
600
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
601
+ current_hidden_states = (
602
+ self.get_expert_output(current_state, expert_idx) *
603
+ routing_weights[top_x, idx, None])
604
+
605
+ # However `index_add_` only support torch tensors for indexing so we'll use
606
+ # the `top_x` tensor here.
607
+ final_hidden_states.index_add_(
608
+ 0, top_x, current_hidden_states.to(hidden_states.dtype))
609
+ final_hidden_states = final_hidden_states.reshape(
610
+ batch_size, sequence_length, hidden_dim)
611
+ return final_hidden_states
612
+
613
+
614
+ class Step3p7RMSNorm(nn.Module):
615
+
616
+ def __init__(
617
+ self,
618
+ hidden_size: int,
619
+ eps: float = 1e-5,
620
+ ) -> None:
621
+ super().__init__()
622
+ self.weight = nn.Parameter(torch.ones(hidden_size))
623
+ self.variance_epsilon = eps
624
+
625
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
626
+ dtype = x.dtype
627
+ x = x.float()
628
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
629
+ normed = x * torch.rsqrt(variance + self.variance_epsilon)
630
+ normed = normed * (self.weight.float() + 1)
631
+ return normed.to(dtype)
632
+ class Step3p7Attention(nn.Module):
633
+
634
+ def __init__(self, config: Step3p7TextConfig, layer_idx):
635
+ super().__init__()
636
+ self.config = config
637
+ self.layer_idx = layer_idx
638
+ self.num_attention_heads = config.num_attention_heads
639
+ self.num_key_value_heads = config.num_attention_groups
640
+
641
+ layer_types = getattr(config, "layer_types", [])
642
+ if layer_types:
643
+ enable_sliding_window = layer_types[
644
+ self.layer_idx] == "sliding_attention"
645
+ else:
646
+ enable_sliding_window = self.layer_idx % 2 == 0
647
+
648
+ yarn_only_types = getattr(config, "yarn_only_types", None)
649
+ if yarn_only_types and layer_types[
650
+ self.layer_idx] not in yarn_only_types:
651
+ config.rope_parameters = None
652
+ else:
653
+ config.rope_parameters = getattr(config, "rope_scaling", None)
654
+
655
+ self.sliding_window = config.sliding_window
656
+ if enable_sliding_window:
657
+ self.num_attention_heads = config.attention_other_setting[
658
+ "num_attention_heads"]
659
+ self.num_key_value_heads = config.attention_other_setting[
660
+ "num_attention_groups"]
661
+
662
+ if self.sliding_window is not None and enable_sliding_window:
663
+ self.sliding_window = (self.sliding_window)
664
+ else:
665
+ self.sliding_window = None
666
+ self.head_dim = getattr(config, "head_dim",
667
+ config.hidden_size // self.num_attention_heads)
668
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
669
+
670
+ self.rotary_emb = Step3p7RotaryEmbedding(config, layer_idx=layer_idx)
671
+
672
+ self.q_size = self.num_attention_heads * self.head_dim
673
+ self.kv_size = self.num_key_value_heads * self.head_dim
674
+ self.scaling = self.head_dim**-0.5
675
+
676
+ self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
677
+ self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
678
+ self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
679
+ self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
680
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)
681
+ self.q_norm = Step3p7RMSNorm(self.head_dim,
682
+ eps=config.rms_norm_eps)
683
+ self.k_norm = Step3p7RMSNorm(self.head_dim,
684
+ eps=config.rms_norm_eps)
685
+
686
+ self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
687
+ if self.use_head_wise_attn_gate:
688
+ self.g_proj = nn.Linear(config.hidden_size,
689
+ self.num_attention_heads,
690
+ bias=False)
691
+
692
+ self.use_rope = True
693
+ use_rope_layers = getattr(config, "use_rope_layers", None)
694
+ if use_rope_layers:
695
+ self.use_rope = use_rope_layers[self.layer_idx]
696
+
697
+ def forward(
698
+ self,
699
+ hidden_states: torch.Tensor,
700
+ attention_mask: Optional[torch.Tensor],
701
+ past_key_value: Optional[Cache] = None,
702
+ cache_position: Optional[torch.LongTensor] = None,
703
+ position_ids: Optional[torch.LongTensor] = None,
704
+ **kwargs: Unpack[FlashAttentionKwargs],
705
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
706
+ Optional[Tuple[torch.Tensor]]]:
707
+ input_shape = hidden_states.shape[:-1]
708
+ hidden_shape = (*input_shape, -1, self.head_dim)
709
+
710
+ query_states = self.q_norm(
711
+ self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
712
+ key_states = self.k_norm(
713
+ self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
714
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
715
+ 1, 2)
716
+ if self.use_head_wise_attn_gate:
717
+ gate_states = self.g_proj(hidden_states)
718
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
719
+
720
+ # cos, sin = position_embeddings
721
+ query_states, key_states = apply_rotary_pos_emb(
722
+ query_states, key_states, cos, sin)
723
+
724
+ # query_states, key_states = apply_rotary_pos_emb(query_norm_states, key_norm_states, cos, sin)
725
+ if past_key_value is not None:
726
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
727
+ cache_kwargs = {
728
+ "sin": sin,
729
+ "cos": cos,
730
+ "cache_position": cache_position
731
+ }
732
+ key_states, value_states = past_key_value.update(
733
+ key_states, value_states, self.layer_idx, cache_kwargs)
734
+
735
+ attention_interface: Callable = eager_attention_forward
736
+ # TODO: considering FP8;
737
+ # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
738
+ # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
739
+ if self.config._attn_implementation != "eager":
740
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
741
+ self.config._attn_implementation]
742
+
743
+ attn_output, attn_weights = attention_interface(
744
+ self,
745
+ query_states,
746
+ key_states,
747
+ value_states,
748
+ attention_mask,
749
+ dropout=0.0 if not self.training else self.attention_dropout,
750
+ scaling=self.scaling,
751
+ sliding_window=self.sliding_window, # main diff with Llama
752
+ **kwargs,
753
+ )
754
+ attn_output = attn_output.reshape(*input_shape, -1)
755
+ if self.use_head_wise_attn_gate:
756
+ output = attn_output.view(
757
+ *attn_output.shape[:-1], self.num_attention_heads,
758
+ self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
759
+ attn_output = output.view(*attn_output.shape)
760
+ attn_output = self.o_proj(attn_output)
761
+
762
+ return attn_output, attn_weights
763
+
764
+
765
+ class Step3p7DecoderLayer(GradientCheckpointingLayer):
766
+
767
+ def __init__(self, config, layer_idx):
768
+ super().__init__()
769
+ self.hidden_size = config.hidden_size
770
+ self.layer_idx = layer_idx
771
+ self.self_attn = Step3p7Attention(config, layer_idx)
772
+ layer_types = getattr(config, "layer_types", None) or []
773
+ if layer_types:
774
+ self.attention_type = layer_types[layer_idx]
775
+ else:
776
+ self.attention_type = (
777
+ "sliding_attention" if layer_idx % 2 == 0 else "full_attention"
778
+ )
779
+
780
+ moe_layers_enum = getattr(config, "moe_layers_enum", None)
781
+ if moe_layers_enum is not None:
782
+ if isinstance(moe_layers_enum, str):
783
+ moe_layers_idx = [
784
+ int(i) for i in moe_layers_enum.split(',') if i.strip()
785
+ ]
786
+ else:
787
+ moe_layers_idx = [int(i) for i in moe_layers_enum]
788
+ else:
789
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
790
+ self.is_moe_layer = layer_idx in moe_layers_idx
791
+ self.use_moe = False
792
+
793
+ if config.swiglu_limits_shared and config.swiglu_limits_shared[
794
+ layer_idx] is not None and config.swiglu_limits_shared[
795
+ layer_idx] != 0:
796
+ swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
797
+ else:
798
+ swiglu_limit_shared = None
799
+ if config.swiglu_limits and config.swiglu_limits[
800
+ layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
801
+ swiglu_limit = config.swiglu_limits[layer_idx]
802
+ else:
803
+ swiglu_limit = None
804
+ if self.is_moe_layer:
805
+ self.moe = Step3p7MoEMLP(config, swiglu_limit=swiglu_limit) #
806
+ self.share_expert = Step3p7MLP(
807
+ config,
808
+ intermediate_size=config.share_expert_dim,
809
+ swiglu_limit=swiglu_limit_shared)
810
+ self.use_moe = True
811
+ else:
812
+ self.mlp = Step3p7MLP(config,
813
+ intermediate_size=config.intermediate_size,
814
+ swiglu_limit=swiglu_limit_shared)
815
+
816
+ self.input_layernorm = Step3p7RMSNorm(
817
+ config.hidden_size,
818
+ eps=config.rms_norm_eps)
819
+ self.post_attention_layernorm = Step3p7RMSNorm(
820
+ config.hidden_size,
821
+ eps=config.rms_norm_eps)
822
+
823
+ def forward(
824
+ self,
825
+ hidden_states: torch.Tensor,
826
+ attention_mask: Optional[torch.Tensor] = None,
827
+ position_ids: Optional[torch.LongTensor] = None,
828
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
829
+ cache_position: Optional[torch.LongTensor] = None,
830
+ **kwargs: Unpack[FlashAttentionKwargs],
831
+ ) -> torch.FloatTensor:
832
+ residual = hidden_states
833
+ hidden_states = self.input_layernorm(hidden_states)
834
+ hidden_states, _ = self.self_attn(
835
+ hidden_states=hidden_states,
836
+ attention_mask=attention_mask,
837
+ position_ids=position_ids,
838
+ past_key_value=past_key_value,
839
+ cache_position=cache_position,
840
+ **kwargs,
841
+ )
842
+ hidden_states = residual + hidden_states
843
+
844
+ # Fully Connected
845
+ residual = hidden_states
846
+ hidden_states = self.post_attention_layernorm(hidden_states)
847
+ if self.use_moe:
848
+ share_output = self.share_expert(hidden_states)
849
+ moe_output = self.moe(hidden_states)
850
+ ffn_output = moe_output + share_output
851
+ else:
852
+ ffn_output = self.mlp(hidden_states)
853
+ if isinstance(ffn_output, tuple):
854
+ hidden_states, _ = ffn_output
855
+ else:
856
+ hidden_states = ffn_output
857
+
858
+ hidden_states = residual + hidden_states
859
+ return hidden_states
860
+
861
+
862
+ class Step3p7TextPreTrainedModel(PreTrainedModel):
863
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
864
+ # can load the config instead of failing with a NoneType error.
865
+ config_class = Step3p7TextConfig
866
+ supports_gradient_checkpointing = True
867
+ _skip_keys_device_placement = ["past_key_values"]
868
+ _keys_to_ignore_on_load_unexpected = [
869
+ r"model\.layers\.45\.*",
870
+ r"model\.layers\.46\.*",
871
+ r"model\.layers\.47\.*",
872
+ ]
873
+ _supports_flash_attn = False
874
+ _supports_sdpa = True
875
+ _supports_flex_attn = True
876
+ _supports_static_cache = True
877
+ _supports_attention_backend = True
878
+
879
+
880
+ class Step3p7TextModel(Step3p7TextPreTrainedModel, GenerationMixin):
881
+ _no_split_modules = ["Step3p7DecoderLayer"]
882
+ base_model_prefix = "model"
883
+ _tied_weights_keys = ["lm_head.weight"]
884
+ config: Step3p7TextConfig
885
+
886
+ def __init__(self, config: Step3p7TextConfig):
887
+ super().__init__(config)
888
+ self.padding_idx = config.pad_token_id
889
+ self.vocab_size = config.vocab_size
890
+
891
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
892
+ self.padding_idx)
893
+ self.layers = nn.ModuleList([
894
+ Step3p7DecoderLayer(config, layer_idx)
895
+ for layer_idx in range(config.num_hidden_layers)
896
+ ])
897
+ self.norm = Step3p7RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
898
+ self.gradient_checkpointing = False
899
+ layer_types = self.config.layer_types or []
900
+ self.has_sliding_layers = (not layer_types or
901
+ "sliding_attention" in layer_types)
902
+
903
+ # Initialize weights and apply final processing
904
+ self.post_init()
905
+
906
+
907
+ def get_input_embeddings(self, input_ids):
908
+ return self.embed_tokens(input_ids)
909
+
910
+ @can_return_tuple
911
+ def forward(
912
+ self,
913
+ input_ids: torch.LongTensor = None,
914
+ attention_mask: Optional[torch.Tensor] = None,
915
+ position_ids: Optional[torch.LongTensor] = None,
916
+ past_key_values: Optional[Cache] = None,
917
+ inputs_embeds: Optional[torch.FloatTensor] = None,
918
+ use_cache: Optional[bool] = None,
919
+ output_attentions: Optional[bool] = None,
920
+ output_hidden_states: Optional[bool] = None,
921
+ return_dict: Optional[bool] = None,
922
+ cache_position: Optional[torch.LongTensor] = None,
923
+ **kwargs: Unpack[TransformersKwargs],
924
+ ) -> Union[tuple, BaseModelOutputWithPast]:
925
+ output_attentions = (
926
+ output_attentions
927
+ if output_attentions is not None
928
+ else self.config.output_attentions
929
+ )
930
+ output_hidden_states = (
931
+ output_hidden_states
932
+ if output_hidden_states is not None
933
+ else self.config.output_hidden_states
934
+ )
935
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
936
+ return_dict = (
937
+ return_dict
938
+ if return_dict is not None
939
+ else getattr(self.config, "return_dict", True)
940
+ )
941
+ if (input_ids is None) ^ (inputs_embeds is not None):
942
+ raise ValueError(
943
+ "You must specify exactly one of input_ids or inputs_embeds")
944
+
945
+ if self.gradient_checkpointing and self.training and use_cache:
946
+ logger.warning_once(
947
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
948
+ )
949
+ use_cache = False
950
+
951
+ if inputs_embeds is None:
952
+ inputs_embeds = self.embed_tokens(
953
+ input_ids.to(self.embed_tokens.weight.device))
954
+
955
+ if use_cache and past_key_values is None:
956
+ past_key_values = DynamicCache()
957
+
958
+ if cache_position is None:
959
+ past_seen_tokens = past_key_values.get_seq_length(
960
+ ) if past_key_values is not None else 0
961
+ cache_position = torch.arange(past_seen_tokens,
962
+ past_seen_tokens +
963
+ inputs_embeds.shape[1],
964
+ device=inputs_embeds.device)
965
+
966
+ if position_ids is None:
967
+ position_ids = cache_position.unsqueeze(0)
968
+
969
+ hidden_states = inputs_embeds
970
+
971
+ # It may already have been prepared by e.g. `generate`
972
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
973
+ # Prepare mask arguments
974
+ mask_kwargs = {
975
+ "config": self.config,
976
+ "attention_mask": attention_mask,
977
+ "past_key_values": past_key_values,
978
+ "position_ids": position_ids,
979
+ }
980
+ mask_kwargs[_MASK_INPUT_EMBEDS_ARG] = inputs_embeds
981
+ # Create the masks
982
+ causal_mask_mapping = {
983
+ "full_attention": create_causal_mask(**mask_kwargs),
984
+ }
985
+
986
+ # The sliding window alternating layers are not always activated depending on the config
987
+ if self.has_sliding_layers:
988
+ causal_mask_mapping[
989
+ "sliding_attention"] = create_sliding_window_causal_mask(
990
+ **mask_kwargs)
991
+
992
+ # # create position embeddings to be shared across the decoder layers
993
+ # decoder layers
994
+ all_hidden_states = () if output_hidden_states else None
995
+ all_self_attns = () if output_attentions else None
996
+ for decoder_layer in self.layers[:self.config.num_hidden_layers]:
997
+ if output_hidden_states:
998
+ all_hidden_states += (hidden_states, )
999
+
1000
+ layer_outputs = decoder_layer(
1001
+ hidden_states,
1002
+ attention_mask=causal_mask_mapping[
1003
+ decoder_layer.attention_type],
1004
+ position_ids=position_ids,
1005
+ past_key_value=past_key_values,
1006
+ output_attentions=output_attentions,
1007
+ use_cache=use_cache,
1008
+ cache_position=cache_position,
1009
+ **kwargs,
1010
+ )
1011
+
1012
+ hidden_states = layer_outputs
1013
+
1014
+ hidden_states = self.norm(hidden_states)
1015
+
1016
+ return BaseModelOutputWithPast(
1017
+ last_hidden_state=hidden_states,
1018
+ past_key_values=past_key_values if use_cache else None,
1019
+ hidden_states=all_hidden_states,
1020
+ attentions=all_self_attns,
1021
+ )
1022
+
1023
+
1024
+ class Step3p7Model(Step3p7PreTrainedModel, GenerationMixin):
1025
+ config: Step3p7Config
1026
+ _tied_weights_keys = ["lm_head.weight"]
1027
+ base_model_prefix = ""
1028
+
1029
+ def __init__(self, config: Step3p7Config):
1030
+ super().__init__(config)
1031
+ self.vision_model = StepRoboticsVisionEncoder(config.vision_config)
1032
+ self.language_model = Step3p7TextModel(config.text_config)
1033
+ self.vocab_size = config.text_config.vocab_size
1034
+ self.vit_large_projector = nn.Linear(
1035
+ config.vision_config.width * 4,
1036
+ config.text_config.hidden_size,
1037
+ bias=config.projector_bias)
1038
+ self.image_placeholder_token_id = config.image_token_id
1039
+
1040
+ # Initialize weights and apply final processing
1041
+ self.post_init()
1042
+
1043
+ def get_input_embeddings(
1044
+ self,
1045
+ input_ids: torch.Tensor,
1046
+ multimodal_embeddings = None,
1047
+ ) -> torch.Tensor:
1048
+ # breakpoint()
1049
+ input_ids = input_ids.squeeze(0)
1050
+ if multimodal_embeddings is None:
1051
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
1052
+ else:
1053
+ is_text = input_ids != self.config.image_token_id
1054
+ text_ids = input_ids[is_text]
1055
+ text_embeds = self.language_model.get_input_embeddings(text_ids)
1056
+
1057
+ inputs_embeds = torch.empty(input_ids.shape[0],
1058
+ text_embeds.shape[-1],
1059
+ dtype=text_embeds.dtype,
1060
+ device=text_embeds.device)
1061
+ inputs_embeds[is_text] = text_embeds
1062
+ inputs_embeds = merge_multimodal_embeddings(
1063
+ input_ids, inputs_embeds, multimodal_embeddings,
1064
+ self.config.image_token_id)
1065
+ inputs_embeds = inputs_embeds.unsqueeze(0)
1066
+ return inputs_embeds
1067
+
1068
+
1069
+ def set_input_embeddings(self, value):
1070
+ return self.language_model.set_input_embeddings(value)
1071
+
1072
+ def set_decoder(self, decoder):
1073
+ self.language_model = decoder
1074
+
1075
+ def get_decoder(self):
1076
+ return self.language_model
1077
+
1078
+ def _parse_and_validate_image_input(
1079
+ self, **kwargs: object) -> Optional[StepVLImageInputs]:
1080
+ pixel_values = kwargs.pop("pixel_values", None)
1081
+ patch_pixel_values = kwargs.pop("patch_pixel_values", None)
1082
+ num_patches = kwargs.pop("num_patches", None)
1083
+ image_embeds = kwargs.pop("image_embeds", None)
1084
+
1085
+ if pixel_values is None and image_embeds is None:
1086
+ return None
1087
+
1088
+ if pixel_values is not None:
1089
+ # pixel_values = flatten_bn(pixel_values, concat=True)
1090
+ if pixel_values.dim() >= 3:
1091
+ pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
1092
+ if patch_pixel_values is not None:
1093
+ # patch_pixel_values = flatten_bn(patch_pixel_values,
1094
+ # concat=True)
1095
+ patch_pixel_values = patch_pixel_values.view(
1096
+ -1, *patch_pixel_values.shape[-3:])
1097
+ # Handle empty patch_pixel_values by setting to None
1098
+ if patch_pixel_values.shape[0] == 0:
1099
+ patch_pixel_values = None
1100
+
1101
+ return StepVLImagePixelInputs(
1102
+ type="pixel_values",
1103
+ pixel_values=pixel_values.to(self.dtype).to(self.device),
1104
+ patch_pixel_values=patch_pixel_values.to(self.dtype).to(
1105
+ self.device) if patch_pixel_values is not None else None,
1106
+ num_patches=num_patches,
1107
+ )
1108
+
1109
+ if image_embeds is not None:
1110
+ if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
1111
+ image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
1112
+ else:
1113
+ raise ValueError(
1114
+ f"Unexpected shape for image_embeds: {image_embeds.shape}")
1115
+
1116
+ return StepVLImageEmbeddingInputs(
1117
+ type="image_embeds",
1118
+ image_embeds=image_embeds.to(self.dtype).to(self.device),
1119
+ )
1120
+ return None
1121
+
1122
+ def _process_image_features(self,
1123
+ image_features: torch.Tensor) -> torch.Tensor:
1124
+ B, P = image_features.shape[:2]
1125
+ HW = int(P ** 0.5)
1126
+ image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
1127
+ image_features = self.vision_model.vit_downsampler1(image_features)
1128
+ image_features = self.vision_model.vit_downsampler2(image_features)
1129
+
1130
+ B, C, HW, HW = image_features.shape
1131
+ image_features = image_features.view(B, -1, HW * HW).permute(0, 2, 1)
1132
+ image_features = self.vit_large_projector(image_features)
1133
+ return image_features
1134
+
1135
+ def _get_vision_model_output(self,
1136
+ input_tensor: torch.Tensor) -> torch.Tensor:
1137
+ return self.vision_model(input_tensor)
1138
+
1139
+ def _process_image_input(
1140
+ self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
1141
+
1142
+ if image_input["type"] == "image_embeds":
1143
+ image_features = image_input["image_embeds"]
1144
+ else:
1145
+ image_features = self._get_vision_model_output(
1146
+ image_input["pixel_values"])
1147
+ patch_image_features = self._get_vision_model_output(
1148
+ image_input["patch_pixel_values"]
1149
+ ) if image_input["patch_pixel_values"] is not None else None
1150
+ num_patches = image_input["num_patches"]
1151
+
1152
+ image_features = self._process_image_features(image_features)
1153
+ patch_image_features = self._process_image_features(
1154
+ patch_image_features) if patch_image_features is not None else None
1155
+
1156
+ merged_image_features = []
1157
+ cur_patch_idx = 0
1158
+ for i, num_patch in enumerate(num_patches):
1159
+ cur_feature = []
1160
+ if num_patch > 0:
1161
+ patch_slice = patch_image_features[
1162
+ cur_patch_idx:cur_patch_idx + num_patch]
1163
+ cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
1164
+ cur_feature.append(image_features[i].view(
1165
+ -1, image_features.shape[-1]))
1166
+ cur_patch_idx += num_patch
1167
+ merged_image_features.append(
1168
+ torch.cat(cur_feature) if len(cur_feature) >
1169
+ 1 else cur_feature[0])
1170
+
1171
+ return merged_image_features
1172
+
1173
+ def get_multimodal_embeddings(self, **kwargs):
1174
+ # breakpoint()
1175
+ image_input = self._parse_and_validate_image_input(**kwargs)
1176
+ if image_input is None:
1177
+ return None
1178
+ vision_embeddings = self._process_image_input(image_input)
1179
+ return vision_embeddings
1180
+
1181
+ @can_return_tuple
1182
+ def forward(
1183
+ self,
1184
+ input_ids: torch.LongTensor = None,
1185
+ attention_mask: Optional[torch.Tensor] = None,
1186
+ position_ids: Optional[torch.LongTensor] = None,
1187
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
1188
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1189
+ labels: Optional[torch.LongTensor] = None,
1190
+ use_cache: Optional[bool] = None,
1191
+ output_attentions: Optional[bool] = None,
1192
+ output_hidden_states: Optional[bool] = None,
1193
+ return_dict: Optional[bool] = None,
1194
+ cache_position: Optional[torch.LongTensor] = None,
1195
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1196
+ images: Optional[list[Image.Image]] = None,
1197
+ **kwargs: Unpack[TransformersKwargs],
1198
+ ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
1199
+ r"""
1200
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1201
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1202
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1203
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1204
+ Example:
1205
+ ```python
1206
+ >>> from transformers import AutoTokenizer, Llama4ForCausalLM
1207
+ >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1208
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1209
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1210
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1211
+ >>> # Generate
1212
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1213
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1214
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1215
+ ```"""
1216
+ output_attentions = (
1217
+ output_attentions
1218
+ if output_attentions is not None
1219
+ else self.config.output_attentions
1220
+ )
1221
+ output_hidden_states = (
1222
+ output_hidden_states
1223
+ if output_hidden_states is not None
1224
+ else self.config.output_hidden_states
1225
+ )
1226
+ return_dict = (
1227
+ return_dict if return_dict is not None else self.config.use_return_dict
1228
+ )
1229
+
1230
+ if inputs_embeds is None:
1231
+ input_ids = input_ids
1232
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
1233
+ inputs_embeds = self.get_input_embeddings(input_ids,
1234
+ vision_embeddings)
1235
+ input_ids = None
1236
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1237
+ outputs = self.language_model(
1238
+ input_ids=None,
1239
+ position_ids=position_ids,
1240
+ attention_mask=attention_mask,
1241
+ past_key_values=past_key_values,
1242
+ inputs_embeds=inputs_embeds,
1243
+ use_cache=use_cache,
1244
+ output_attentions=output_attentions,
1245
+ output_hidden_states=output_hidden_states,
1246
+ return_dict=True,
1247
+ cache_position=cache_position,
1248
+ **kwargs,
1249
+ )
1250
+
1251
+ output = Step3p7CausalLMOutputWithPast(
1252
+ last_hidden_state=outputs.last_hidden_state,
1253
+ past_key_values=outputs.past_key_values,
1254
+ attentions=outputs.attentions,
1255
+ )
1256
+ return output if return_dict else output.to_tuple()
1257
+
1258
+
1259
+ class Step3p7ForConditionalGeneration(Step3p7PreTrainedModel, GenerationMixin):
1260
+ _checkpoint_conversion_mapping = {
1261
+ "^vision_model": "model.vision_model",
1262
+ r"^model(?!\.(language_model|vision_model))": "model.language_model",
1263
+ "^vit_large_projector": "model.vit_large_projector",
1264
+ }
1265
+ _tied_weights_keys = ["lm_head.weight"]
1266
+ config: Step3p7Config
1267
+
1268
+ def __init__(self, config: Step3p7Config):
1269
+ super().__init__(config)
1270
+ self.model = Step3p7Model(config)
1271
+ self.lm_head = nn.Linear(config.hidden_size,
1272
+ config.text_config.vocab_size,
1273
+ bias=False)
1274
+
1275
+ self.post_init()
1276
+
1277
+ def get_input_embeddings(self):
1278
+ return self.model.get_input_embeddings()
1279
+
1280
+ def set_input_embeddings(self, value):
1281
+ self.model.set_input_embeddings(value)
1282
+
1283
+ def get_output_embeddings(self):
1284
+ return self.model.get_output_embeddings()
1285
+
1286
+ def set_output_embeddings(self, new_embeddings):
1287
+ self.model.set_output_embeddings(new_embeddings)
1288
+
1289
+ def set_decoder(self, decoder):
1290
+ self.model.set_decoder(decoder)
1291
+
1292
+ def get_decoder(self):
1293
+ return self.model.get_decoder()
1294
+
1295
+ @property
1296
+ def language_model(self):
1297
+ return self.model.language_model
1298
+
1299
+ @property
1300
+ def visual(self):
1301
+ return self.model.vision_model
1302
+
1303
+ def forward(
1304
+ self,
1305
+ input_ids: torch.LongTensor = None,
1306
+ pixel_values: Optional[torch.Tensor] = None,
1307
+ num_patches=None,
1308
+ patch_pixel_values=None,
1309
+ patch_newline_mask=None,
1310
+ image_embeds: Optional[torch.FloatTensor] = None,
1311
+ attention_mask: Optional[torch.Tensor] = None,
1312
+ position_ids: Optional[torch.LongTensor] = None,
1313
+ past_key_values: Optional[Cache] = None,
1314
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1315
+ labels: Optional[torch.LongTensor] = None,
1316
+ use_cache: Optional[bool] = None,
1317
+ output_attentions: Optional[bool] = None,
1318
+ output_hidden_states: Optional[bool] = None,
1319
+ return_dict: Optional[bool] = None,
1320
+ cache_position: Optional[torch.LongTensor] = None,
1321
+ **kwargs: Unpack[TransformersKwargs],
1322
+ ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
1323
+ output_attentions = (
1324
+ output_attentions
1325
+ if output_attentions is not None
1326
+ else self.config.output_attentions
1327
+ )
1328
+ output_hidden_states = (
1329
+ output_hidden_states
1330
+ if output_hidden_states is not None
1331
+ else self.config.output_hidden_states
1332
+ )
1333
+
1334
+ outputs = self.model(
1335
+ input_ids=input_ids,
1336
+ num_patches=num_patches,
1337
+ patch_pixel_values=patch_pixel_values,
1338
+ patch_newline_mask=patch_newline_mask,
1339
+ position_ids=position_ids,
1340
+ attention_mask=attention_mask,
1341
+ past_key_values=past_key_values,
1342
+ inputs_embeds=inputs_embeds,
1343
+ use_cache=use_cache,
1344
+ output_attentions=output_attentions,
1345
+ output_hidden_states=output_hidden_states,
1346
+ return_dict=return_dict,
1347
+ cache_position=cache_position,
1348
+ **kwargs,
1349
+ )
1350
+
1351
+ hidden_states = outputs.last_hidden_state
1352
+ logits = self.lm_head(hidden_states)
1353
+
1354
+ los = None
1355
+ if labels is not None:
1356
+ loss = self.loss_function(
1357
+ logits=logits, labels=labels, vocab_size=self.config.vocab_size
1358
+ )
1359
+
1360
+ return Step3p7CausalLMOutputWithPast(
1361
+ logits=logits,
1362
+ )
1363
+
1364
+
1365
+ def prepare_inputs_for_generation(
1366
+ self,
1367
+ input_ids,
1368
+ past_key_values=None,
1369
+ inputs_embeds=None,
1370
+ pixel_values=None,
1371
+ patch_pixel_values=None,
1372
+ num_patches=None,
1373
+ image_embeds=None,
1374
+ attention_mask=None,
1375
+ cache_position=None,
1376
+ logits_to_keep=None,
1377
+ **kwargs,
1378
+ ):
1379
+ model_inputs = super().prepare_inputs_for_generation(
1380
+ input_ids,
1381
+ past_key_values=past_key_values,
1382
+ inputs_embeds=inputs_embeds,
1383
+ attention_mask=attention_mask,
1384
+ cache_position=cache_position,
1385
+ logits_to_keep=logits_to_keep,
1386
+ **kwargs,
1387
+ )
1388
+
1389
+ generation_cache_position = model_inputs.get("cache_position", cache_position)
1390
+ is_prefill = past_key_values is None
1391
+ if generation_cache_position is not None and generation_cache_position.numel() > 0:
1392
+ is_prefill = generation_cache_position[0].item() == 0
1393
+
1394
+ if is_prefill:
1395
+ # During cached decoding, input ids no longer contain image tokens,
1396
+ # so pixel values should only be passed at the first step.
1397
+ model_inputs["pixel_values"] = pixel_values
1398
+
1399
+ return model_inputs
1400
+
1401
+ def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
1402
+ if key.startswith("language_model."):
1403
+ return key[len("language_model.") :], True
1404
+
1405
+ return key, False
processing_step3.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BaseImageProcessor, ImageProcessingMixin
2
+ from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
3
+ import math
4
+ from typing import Iterable, Optional, Tuple, List, TypedDict, Literal, Union, overload
5
+
6
+ from PIL import Image
7
+ import torch
8
+ import numpy as np
9
+ import torchvision
10
+ from torch import nn
11
+ from torch.nn import functional as F, LayerNorm
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from transformers.activations import ACT2FN
14
+ from torchvision import transforms
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from transformers.feature_extraction_utils import BatchFeature, TensorType
17
+ from transformers.image_utils import ImageInput
18
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
19
+ from transformers.tokenization_utils_tokenizers import TokenizersBackend
20
+ from math import ceil
21
+ from itertools import product
22
+
23
+
24
+
25
+ MAX_IMAGE_SIZE: int = 3024
26
+
27
+ class Step3VLImagePixelInputs(TypedDict):
28
+ type: Literal["pixel_values"]
29
+ pixel_values: torch.Tensor
30
+ patch_pixel_values: Optional[torch.Tensor]
31
+ num_patches: list[int]
32
+
33
+
34
+ class Step3VLImageEmbeddingInputs(TypedDict):
35
+ type: Literal["image_embeds"]
36
+ image_embeds: torch.Tensor
37
+
38
+
39
+ ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
40
+
41
+
42
+ class GPUToTensor(torch.nn.Module):
43
+
44
+ def forward(self, raw_image: Union[np.ndarray,
45
+ Image.Image]) -> torch.Tensor:
46
+ if isinstance(raw_image, Image.Image):
47
+ return transforms.ToTensor()(raw_image)
48
+ if raw_image.ndim == 2:
49
+ raw_image = raw_image[:, :, None].repeat(3, -1)
50
+ if torch.cuda.is_available():
51
+ device = torch.device("cuda")
52
+ else:
53
+ device = torch.device("cpu")
54
+ image_tensor = torch.from_numpy(raw_image).to(device)
55
+ image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
56
+ if image_tensor.dtype == torch.uint8:
57
+ image_tensor = image_tensor.to(torch.float32).div(255)
58
+ return image_tensor
59
+
60
+ class Step3VisionProcessor(BaseImageProcessor):
61
+
62
+ def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
63
+ mean = [0.48145466, 0.4578275, 0.40821073]
64
+ std = [0.26862954, 0.26130258, 0.27577711]
65
+ patch_size = patch_size if patch_size is not None else size
66
+
67
+ self.transform = transforms.Compose([
68
+ GPUToTensor(),
69
+ transforms.Normalize(mean, std),
70
+ transforms.Resize(
71
+ (size, size),
72
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
73
+ == "bicubic" else InterpolationMode.BILINEAR,
74
+ antialias=True),
75
+ ])
76
+
77
+ self.patch_transform = transforms.Compose([
78
+ GPUToTensor(),
79
+ transforms.Normalize(mean, std),
80
+ transforms.Resize(
81
+ (patch_size, patch_size),
82
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
83
+ == "bicubic" else InterpolationMode.BILINEAR,
84
+ antialias=True),
85
+ ]) if patch_size is not None else None
86
+
87
+ def __call__(self, image, is_patch=False):
88
+ if is_patch:
89
+ return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
90
+ else:
91
+ return {"pixel_values": self.transform(image).unsqueeze(0)}
92
+
93
+ class ImagePatcher:
94
+ def determine_window_size(self, long: int, short: int) -> int:
95
+ if long <= 728:
96
+ return short if long / short > 1.5 else 0
97
+ return min(short, 504) if long / short > 4 else 504
98
+ def slide_window(
99
+ self,
100
+ width: int,
101
+ height: int,
102
+ sizes: list[tuple[int, int]],
103
+ steps: list[tuple[int, int]],
104
+ img_rate_thr: float = 0.6,
105
+ ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
106
+ assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
107
+ windows = []
108
+ # Sliding windows.
109
+ for size, step in zip(sizes, steps):
110
+ size_w, size_h = size
111
+ step_w, step_h = step
112
+
113
+ x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
114
+ 1)
115
+ x_start = [step_w * i for i in range(x_num)]
116
+ if len(x_start) > 1 and x_start[-1] + size_w > width:
117
+ x_start[-1] = width - size_w
118
+
119
+ y_num = 1 if height <= size_h else ceil((height - size_h) /
120
+ step_h + 1)
121
+ y_start = [step_h * i for i in range(y_num)]
122
+ if len(y_start) > 1 and y_start[-1] + size_h > height:
123
+ y_start[-1] = height - size_h
124
+
125
+ start = np.array(list(product(y_start, x_start)), dtype=int)
126
+ start[:, [0, 1]] = start[:, [1, 0]]
127
+ windows.append(np.concatenate([start, start + size], axis=1))
128
+ windows = np.concatenate(windows, axis=0)
129
+
130
+ return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
131
+ int(box[3] - box[1])) for box in windows], (x_num, y_num)
132
+
133
+ def square_pad(self, img: Image.Image) -> Image.Image:
134
+ w, h = img.size
135
+ if w == h:
136
+ return img
137
+ size = max(w, h)
138
+ padded = Image.new(img.mode, (size, size), 0)
139
+ padded.paste(img, (0, 0))
140
+ return padded
141
+
142
+ def get_image_size_for_padding(self, img_width: int,
143
+ img_height: int) -> tuple[int, int]:
144
+ ratio = img_width / img_height
145
+ if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
146
+ new_size = max(img_height, img_width)
147
+ return new_size, new_size
148
+ return img_width, img_height
149
+
150
+ def get_image_size_for_preprocess(self, img_width: int,
151
+ img_height: int) -> tuple[int, int]:
152
+
153
+ if max(img_height, img_width) > MAX_IMAGE_SIZE:
154
+ scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
155
+ img_width = int(img_width * scale_factor)
156
+ img_height = int(img_height * scale_factor)
157
+ return img_width, img_height
158
+
159
+ def get_image_size_for_crop(self, img_width: int, img_height: int,
160
+ window_size: int):
161
+ w_ratio = img_width / window_size
162
+ h_ratio = img_height / window_size
163
+
164
+ if w_ratio < 1:
165
+ width_new = img_width
166
+ else:
167
+ decimal_w = w_ratio - img_width // window_size
168
+ w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
169
+ width_new = window_size * w_ratio
170
+ if h_ratio < 1:
171
+ height_new = img_height
172
+ else:
173
+ decimal_h = h_ratio - img_height // window_size
174
+ h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
175
+ height_new = window_size * h_ratio
176
+ return int(width_new), int(height_new)
177
+
178
+ def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
179
+ target = img.crop((j, i, j + tw, i + th))
180
+ return target
181
+
182
+ def get_num_patches(self, img_width: int,
183
+ img_height: int) -> tuple[int, int]:
184
+ img_width, img_height = self.get_image_size_for_padding(
185
+ img_width, img_height)
186
+ img_width, img_height = self.get_image_size_for_preprocess(
187
+ img_width, img_height)
188
+ window_size = self.determine_window_size(max(img_height, img_width),
189
+ min(img_height, img_width))
190
+ if window_size == 0:
191
+ return 0, 0
192
+ else:
193
+ img_width, img_height = self.get_image_size_for_crop(
194
+ img_width, img_height, window_size)
195
+ center_list, (x_num, y_num) = self.slide_window(
196
+ img_width, img_height, [(window_size, window_size)],
197
+ [(window_size, window_size)])
198
+ full_rows = (len(center_list) - 1) // x_num + 1
199
+ if len(center_list) > 0 and len(center_list) % x_num == 0:
200
+ full_rows -= 1
201
+ return len(center_list), full_rows
202
+
203
+ def __call__(
204
+ self, img: Image.Image
205
+ ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
206
+ img_width, img_height = img.size
207
+ new_img_width, new_img_height = self.get_image_size_for_padding(
208
+ img_width, img_height)
209
+ if new_img_width != img_width or new_img_height != img_height:
210
+ img = self.square_pad(img)
211
+ img_width, img_height = img.size
212
+
213
+ new_img_width, new_img_height = self.get_image_size_for_preprocess(
214
+ img_width, img_height)
215
+ img = img.resize((new_img_width, new_img_height),
216
+ Image.Resampling.BILINEAR)
217
+ window_size = self.determine_window_size(
218
+ max(new_img_height, new_img_width),
219
+ min(new_img_height, new_img_width))
220
+ # return img, [], None
221
+ if window_size == 0:
222
+ return img, [], None
223
+ else:
224
+ new_img_width, new_img_height = self.get_image_size_for_crop(
225
+ new_img_width, new_img_height, window_size)
226
+ if (new_img_width, new_img_height) != (img_width, img_height):
227
+ img_for_crop = img.resize((new_img_width, new_img_height),
228
+ Image.Resampling.BILINEAR)
229
+ else:
230
+ img_for_crop = img
231
+
232
+ patches = []
233
+ newlines = []
234
+ center_list, (x_num, y_num) = self.slide_window(
235
+ new_img_width, new_img_height, [(window_size, window_size)],
236
+ [(window_size, window_size)])
237
+ for patch_id, center_lf_point in enumerate(center_list):
238
+ x, y, patch_w, patch_h = center_lf_point
239
+ big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
240
+ patch_w)
241
+ patches.append(big_patch)
242
+ if (patch_id + 1) % x_num == 0:
243
+ newlines.append(patch_id)
244
+
245
+ if newlines and newlines[-1] == len(patches) - 1:
246
+ newlines.pop()
247
+
248
+ return img, patches, [i in newlines for i in range(len(patches))] if len(patches) > 0 else None
249
+
250
+
251
+
252
+
253
+ class Step3VLProcessor(ProcessorMixin):
254
+ # Align ProcessorMixin with our custom components.
255
+ # We only have an image processor (not a feature extractor) plus a tokenizer.
256
+ attributes = ["tokenizer"]
257
+ tokenizer_class = "AutoTokenizer"
258
+
259
+ @classmethod
260
+ def _load_tokenizer_from_pretrained(
261
+ cls, sub_processor_type, pretrained_model_name_or_path, subfolder="", **kwargs
262
+ ):
263
+ return TokenizersBackend.from_pretrained(
264
+ pretrained_model_name_or_path,
265
+ subfolder=subfolder,
266
+ **kwargs,
267
+ )
268
+
269
+ def __init__(
270
+ self,
271
+ tokenizer=None,
272
+ chat_template=None,
273
+ **kwargs
274
+ ) -> None:
275
+ self.image_size = 728
276
+ self.patch_size = 504
277
+
278
+ self.image_preprocessor = Step3VisionProcessor(self.image_size,
279
+ "bilinear",
280
+ self.patch_size)
281
+
282
+ self.num_image_feature_size = 169
283
+ self.num_patch_feature_size = 81
284
+ self.image_token = "<im_patch>"
285
+ self.image_feature_placeholder = (self.image_token *
286
+ self.num_image_feature_size)
287
+ self.patch_feature_placeholder = (self.image_token *
288
+ self.num_patch_feature_size)
289
+ super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
290
+ self.patcher = ImagePatcher()
291
+
292
+ @property
293
+ def image_token_id(self) -> int:
294
+ return self.tokenizer.get_vocab()[self.image_token]
295
+
296
+ def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
297
+ num_patches, num_newlines = self.patcher.get_num_patches(
298
+ img_width, img_height)
299
+
300
+ return num_patches * (
301
+ self.num_patch_feature_size +
302
+ 2) + self.num_image_feature_size + 2 + num_newlines
303
+
304
+ def _split_images(self,
305
+ images: list[Image.Image]) -> list[ImageWithPatches]:
306
+ result = []
307
+ for img in images:
308
+ result.append(self.patcher(img))
309
+ return result
310
+
311
+ def _convert_images_to_pixel_values(
312
+ self,
313
+ images: list[Image.Image],
314
+ is_patch: bool = False,
315
+ ) -> list[torch.Tensor]:
316
+ return [
317
+ self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
318
+ for img in images
319
+ ]
320
+
321
+ def _get_patch_repl(
322
+ self,
323
+ num_patches: int,
324
+ patch_newline_mask: list[bool] | None,
325
+ ) -> tuple[str, list[int]]:
326
+ text = ""
327
+ token_ids = []
328
+ for i in range(num_patches):
329
+ assert len(patch_newline_mask) == num_patches
330
+ text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
331
+ token_ids.extend(
332
+ [self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
333
+ [self.image_token_id] * self.num_patch_feature_size +
334
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")])
335
+ if patch_newline_mask and patch_newline_mask[i]:
336
+ text += "<patch_newline>"
337
+ token_ids.append(
338
+ self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
339
+ return text, token_ids
340
+
341
+ def _get_image_repl(
342
+ self,
343
+ num_images: int,
344
+ ) -> tuple[str, list[int]]:
345
+ text = f"<im_start>{self.image_feature_placeholder}<im_end>"
346
+ token_ids = [
347
+ self.tokenizer.convert_tokens_to_ids("<im_start>")
348
+ ] + [self.image_token_id] * self.num_image_feature_size + [
349
+ self.tokenizer.convert_tokens_to_ids("<im_end>")
350
+ ]
351
+ return text * num_images, token_ids * num_images
352
+
353
+ def _get_image_repl_features(
354
+ self,
355
+ num_images: int,
356
+ num_patches: int,
357
+ patch_new_line_idx: Optional[list[bool]],
358
+ ) -> tuple[str, list[int]]:
359
+ if num_patches > 0:
360
+ patch_repl, patch_repl_ids = self._get_patch_repl(
361
+ num_patches, patch_new_line_idx)
362
+ else:
363
+ patch_repl = ""
364
+ patch_repl_ids = []
365
+ image_repl, image_repl_ids = self._get_image_repl(num_images)
366
+ return patch_repl + image_repl, patch_repl_ids + image_repl_ids
367
+
368
+ def replace_placeholder(self, text: str, placeholder: str,
369
+ repls: list[str]) -> str:
370
+ parts = text.split(placeholder)
371
+
372
+ if len(parts) - 1 != len(repls):
373
+ raise ValueError(
374
+ "The number of placeholders does not match the number of replacements." # noqa: E501
375
+ )
376
+
377
+ result = [parts[0]]
378
+ for i, repl in enumerate(repls):
379
+ result.append(repl)
380
+ result.append(parts[i + 1])
381
+
382
+ return "".join(result)
383
+
384
+ def __call__(
385
+ self,
386
+ text: Optional[Union[str, list[str]]] = None,
387
+ images: ImageInput | None = None,
388
+ return_tensors: Optional[Union[str, TensorType]] = None,
389
+ **kwargs,
390
+ ) -> BatchFeature:
391
+
392
+ if images is not None:
393
+ images = self.image_preprocessor.fetch_images(images)
394
+ if text is None:
395
+ text = []
396
+ if not isinstance(text, list):
397
+ text = [text]
398
+ if images is None:
399
+ images = []
400
+ elif not isinstance(images, list):
401
+ images = [images]
402
+ elif isinstance(images[0], list):
403
+ images = images[0]
404
+
405
+ if len(images) == 0:
406
+ image_inputs = {}
407
+ text_inputs = self.tokenizer(text)
408
+ else:
409
+ splitted_images_data = self._split_images(images)
410
+ pixel_values_lst = []
411
+ patch_pixel_values_lst = []
412
+ patch_newline_mask_lst = []
413
+ image_repl_str_lst = []
414
+ image_repl_ids_lst = []
415
+ num_patches = []
416
+ for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
417
+ pixel_values_lst.extend(
418
+ self._convert_images_to_pixel_values([raw_img]))
419
+
420
+ if len(img_patches) > 0:
421
+ patch_pixel_values_lst.extend(
422
+ self._convert_images_to_pixel_values(img_patches,
423
+ is_patch=True))
424
+ num_patches.append(len(img_patches))
425
+
426
+ image_repl_str, image_repl_ids = self._get_image_repl_features(
427
+ 1, len(img_patches), patch_newline_mask)
428
+ image_repl_str_lst.append(image_repl_str)
429
+ image_repl_ids_lst.extend(image_repl_ids)
430
+
431
+ if patch_newline_mask is not None:
432
+ patch_newline_mask_lst.extend(patch_newline_mask)
433
+
434
+ image_inputs = {
435
+ "pixel_values": torch.cat(pixel_values_lst),
436
+ "num_patches": num_patches,
437
+ }
438
+ if patch_pixel_values_lst:
439
+ image_inputs["patch_pixel_values"] = torch.cat(
440
+ patch_pixel_values_lst)
441
+ if patch_newline_mask_lst:
442
+ image_inputs["patch_newline_mask"] = torch.tensor(
443
+ patch_newline_mask_lst, dtype=torch.bool)
444
+
445
+ text = [
446
+ self.replace_placeholder(t, self.image_token,
447
+ image_repl_str_lst) for t in text
448
+ ]
449
+ text_inputs = self.tokenizer(text)
450
+
451
+ return BatchFeature(
452
+ {
453
+ **text_inputs,
454
+ **image_inputs,
455
+ },
456
+ tensor_type=return_tensors,
457
+ )
458
+
459
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
460
+ def batch_decode(self, *args, **kwargs):
461
+ """
462
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
463
+ refer to the docstring of this method for more information.
464
+ """
465
+ return self.tokenizer.batch_decode(*args, **kwargs)
466
+
467
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
468
+ def decode(self, *args, **kwargs):
469
+ """
470
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
471
+ the docstring of this method for more information.
472
+ """
473
+ return self.tokenizer.decode(*args, **kwargs)
474
+
475
+ __all__ = ["Step3VLProcessor"]
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": null,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<|begin▁of▁sentence|>",
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "<|im_end|>",
7
+ "is_local": true,
8
+ "legacy": true,
9
+ "model_max_length": 131072,
10
+ "pad_token": "<|end▁of▁sentence|>",
11
+ "sp_model_kwargs": {},
12
+ "tokenizer_class": "TokenizersBackend",
13
+ "tool_parser_type": "qwen3_coder",
14
+ "unk_token": null,
15
+ "use_default_system_prompt": false
16
+ }
vision_encoder.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers.activations import ACT2FN
7
+
8
+
9
+ from .configuration_step3p7 import StepRoboticsVisionEncoderConfig
10
+
11
+
12
+
13
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
14
+ """Rotate last dimension halves (used by RoPE)."""
15
+ x = x.reshape(*x.shape[:-1], -1, 2)
16
+ x1, x2 = x.unbind(dim=-1)
17
+ x = torch.stack((-x2, x1), dim=-1)
18
+ return x.reshape(*x.shape[:-2], -1)
19
+
20
+
21
+ def apply_rotary_emb(freqs: torch.Tensor,
22
+ t: torch.Tensor,
23
+ start_index: int = 0,
24
+ scale: float = 1.0,
25
+ seq_dim: int = -2) -> torch.Tensor:
26
+ """Apply 2D rotary embeddings to queries / keys."""
27
+ dtype = t.dtype
28
+
29
+ if t.ndim == 3:
30
+ seq_len = t.shape[seq_dim]
31
+ freqs = freqs[-seq_len:]
32
+
33
+ rot_dim = freqs.shape[-1]
34
+ end_index = start_index + rot_dim
35
+ assert rot_dim <= t.shape[-1], (
36
+ f"feature dimension {t.shape[-1]} is too small for rot_dim {rot_dim}")
37
+
38
+ t_left, t, t_right = (
39
+ t[..., :start_index],
40
+ t[..., start_index:end_index],
41
+ t[..., end_index:],
42
+ )
43
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
44
+ out = torch.cat((t_left, t, t_right), dim=-1)
45
+ return out.type(dtype)
46
+
47
+
48
+ class EncoderRope2D(nn.Module):
49
+ """Cacheable 2D rotary positional embedding."""
50
+
51
+ def __init__(
52
+ self,
53
+ dim: int,
54
+ max_grid_height: int,
55
+ max_grid_width: int,
56
+ use_cls_token: bool = False,
57
+ theta: Union[int, float] = 10000,
58
+ max_freq: int = 10,
59
+ num_freqs: int = 1,
60
+ theta_rescale_factor: float = 1.0,
61
+ ):
62
+ super().__init__()
63
+ self.dim = dim
64
+ self.max_grid_height = max_grid_height
65
+ self.max_grid_width = max_grid_width
66
+ self.use_cls_token = use_cls_token
67
+ self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
68
+ self.max_freq = max_freq
69
+ self.num_freqs = num_freqs
70
+ cache = self._compute_2d_freqs()
71
+ self.register_buffer("freqs_cache", cache, persistent=False)
72
+
73
+ def _compute_inv_freq(self, base: Union[int, float],
74
+ dim: int) -> torch.Tensor:
75
+
76
+ freqs = 1.0 / (base**(
77
+ torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
78
+ return freqs
79
+
80
+ def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
81
+ freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype),
82
+ inv_freq)
83
+ freqs = freqs.repeat_interleave(2, dim=-1)
84
+ return freqs
85
+
86
+ def _compute_2d_freqs(self) -> torch.Tensor:
87
+ grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
88
+ grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
89
+ if self.use_cls_token:
90
+ grid_h_range += 1
91
+ grid_w_range += 1
92
+ inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
93
+ freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
94
+ self.max_grid_height, self.max_grid_width, -1)
95
+ freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
96
+ self.max_grid_height, self.max_grid_width, -1)
97
+ freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
98
+ self.max_grid_height * self.max_grid_width, -1)
99
+ if self.use_cls_token:
100
+ freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
101
+ freqs = freqs[None, None, ...]
102
+ return freqs
103
+
104
+ def forward(self, q: torch.Tensor, k: torch.Tensor,
105
+ grid_hw: tuple[int, int]):
106
+ # If grid matches cached shape we reuse directly to avoid recomputation.
107
+ if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width:
108
+ rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1)
109
+ cols = torch.arange(grid_hw[1], device=q.device).view(1, -1)
110
+ positions = (rows * self.max_grid_width + cols).reshape(-1).to(
111
+ torch.long)
112
+ if self.use_cls_token:
113
+ positions = torch.cat(
114
+ [torch.zeros(1, device=q.device), positions + 1], dim=0)
115
+ freqs = self.freqs_cache.index_select(2, positions)
116
+ else:
117
+ freqs = self.freqs_cache
118
+ q = apply_rotary_emb(freqs, q)
119
+ k = apply_rotary_emb(freqs, k)
120
+ return q, k
121
+
122
+
123
+ class EncoderLayerScale(nn.Module):
124
+ """Per-channel residual scaling used when ls_init_value is set."""
125
+
126
+ def __init__(self, dim: int, init_values: float):
127
+ super().__init__()
128
+ self.gamma = nn.Parameter(torch.full((dim,), init_values))
129
+
130
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # (B, L, D)
131
+ return hidden_states * self.gamma
132
+
133
+
134
+ class EncoderMLP(nn.Module):
135
+ """Feed-forward network used inside each transformer block."""
136
+
137
+ def __init__(self, hidden_size: int, intermediate_size: int,
138
+ hidden_act: str):
139
+ super().__init__()
140
+ self.c_fc = nn.Linear(hidden_size, intermediate_size, bias=True)
141
+ self.act_fn = ACT2FN[hidden_act]
142
+ self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
143
+
144
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
145
+
146
+ hidden_states = self.c_proj(self.act_fn(self.c_fc(hidden_states)))
147
+ return hidden_states
148
+
149
+
150
+ class EncoderVisionAttention(nn.Module):
151
+ """Multi-head self attention with optional 2D RoPE."""
152
+
153
+ def __init__(
154
+ self,
155
+ hidden_size: int,
156
+ num_heads: int,
157
+ max_grid_height: int,
158
+ max_grid_width: int,
159
+ use_cls_token: bool = False,
160
+ use_rope2d: bool = True,
161
+ rope_theta: Union[int, float] = 10000,
162
+ rope_max_freq: int = 10,
163
+ rope_num_freqs: int = 1,
164
+ rope_theta_rescale_factor: float = 1.0,
165
+ rope_freqs_for: Literal["lang", "pixel", "constant"] = "lang",
166
+ ):
167
+ super().__init__()
168
+ if hidden_size % num_heads != 0:
169
+ raise ValueError(
170
+ f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})."
171
+ )
172
+ self.num_heads = num_heads
173
+ self.head_dim = hidden_size // num_heads
174
+ self.scale = self.head_dim**-0.5
175
+ self.in_proj_weight = nn.Parameter(torch.zeros(hidden_size * 3, hidden_size))
176
+ self.in_proj_bias = nn.Parameter(torch.zeros(hidden_size * 3))
177
+ self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
178
+
179
+ self.rope = None
180
+ if use_rope2d:
181
+ self.rope = EncoderRope2D(
182
+ dim=self.head_dim,
183
+ max_grid_height=max_grid_height,
184
+ max_grid_width=max_grid_width,
185
+ use_cls_token=use_cls_token,
186
+ theta=rope_theta,
187
+ max_freq=rope_max_freq,
188
+ num_freqs=rope_num_freqs,
189
+ theta_rescale_factor=rope_theta_rescale_factor,
190
+ )
191
+
192
+ def forward(self, hidden_states: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
193
+ bsz, seq_len, _ = hidden_states.shape
194
+ qkv = F.linear(
195
+ hidden_states,
196
+ self.in_proj_weight,
197
+ self.in_proj_bias,
198
+ )
199
+ q, k, v = qkv.chunk(3, dim=-1)
200
+
201
+ q = q.view(bsz, seq_len, self.num_heads,
202
+ self.head_dim).transpose(1, 2)
203
+ k = k.view(bsz, seq_len, self.num_heads,
204
+ self.head_dim).transpose(1, 2)
205
+ if self.rope is not None:
206
+ q, k = self.rope(q, k, grid_hw=grid_hw)
207
+ v = v.view(bsz, seq_len, self.num_heads,
208
+ self.head_dim).transpose(1, 2)
209
+
210
+ attn_output = F.scaled_dot_product_attention(
211
+ q, k, v, is_causal=False, scale=self.scale)
212
+ attn_output = attn_output.transpose(1, 2).reshape(
213
+ bsz, seq_len, self.num_heads * self.head_dim)
214
+ return self.out_proj(attn_output)
215
+
216
+
217
+ class EncoderVisionBlock(nn.Module):
218
+ """A single Vision Transformer block (self-attention + MLP)."""
219
+
220
+ def __init__(
221
+ self,
222
+ hidden_size: int,
223
+ num_heads: int,
224
+ mlp_ratio: float,
225
+ hidden_act: str,
226
+ layer_norm_eps: float,
227
+ ls_init_value: Optional[float] = None,
228
+ max_grid_height: Optional[int] = None,
229
+ max_grid_width: Optional[int] = None,
230
+ use_cls_token: bool = False,
231
+ use_rope2d: bool = True,
232
+ rope_kwargs: Optional[dict] = None,
233
+ ):
234
+ super().__init__()
235
+ rope_kwargs = rope_kwargs or {}
236
+ self.attn = EncoderVisionAttention(
237
+ hidden_size,
238
+ num_heads,
239
+ max_grid_height=max_grid_height,
240
+ max_grid_width=max_grid_width,
241
+ use_cls_token=use_cls_token,
242
+ use_rope2d=use_rope2d,
243
+ **rope_kwargs,
244
+ )
245
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
246
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
247
+
248
+ intermediate = int(hidden_size * mlp_ratio)
249
+ self.mlp = EncoderMLP(hidden_size, intermediate, hidden_act)
250
+
251
+ self.ls_1 = EncoderLayerScale(hidden_size, ls_init_value)
252
+ self.ls_2 = EncoderLayerScale(hidden_size, ls_init_value)
253
+
254
+ def forward(self, hidden_states: torch.Tensor,
255
+ grid_hw: tuple[int, int]) -> torch.Tensor:
256
+ # breakpoint()
257
+ residual = hidden_states
258
+ hidden_states = self.ln_1(hidden_states)
259
+ hidden_states = self.attn(hidden_states, grid_hw=grid_hw)
260
+ hidden_states = residual + self.ls_1(hidden_states)
261
+
262
+ residual = hidden_states
263
+ hidden_states = self.ln_2(hidden_states)
264
+ hidden_states = self.mlp(hidden_states)
265
+ hidden_states = residual + self.ls_2(hidden_states)
266
+ return hidden_states
267
+
268
+
269
+ class EncoderVisionTransformer(nn.Module):
270
+ """Stack of encoder blocks parameterised by Step35VisionEncoderConfig."""
271
+
272
+ def __init__(
273
+ self,
274
+ embed_dim: int,
275
+ depth: int,
276
+ num_heads: int,
277
+ mlp_ratio: float,
278
+ hidden_act: str,
279
+ layer_norm_eps: float,
280
+ ls_init_value: Optional[float] = None,
281
+ max_grid_height: Optional[int] = None,
282
+ max_grid_width: Optional[int] = None,
283
+ use_cls_token: bool = False,
284
+ use_rope2d: bool = True,
285
+ rope_kwargs: Optional[dict] = None,
286
+ ):
287
+ super().__init__()
288
+ self.layers = depth
289
+ rope_kwargs = rope_kwargs or {}
290
+ self.resblocks = nn.ModuleList([
291
+ EncoderVisionBlock(embed_dim, num_heads, mlp_ratio, hidden_act,
292
+ layer_norm_eps,
293
+ max_grid_height=max_grid_height,
294
+ max_grid_width=max_grid_width,
295
+ use_cls_token=use_cls_token,
296
+ use_rope2d=use_rope2d,
297
+ ls_init_value=ls_init_value,
298
+ rope_kwargs=rope_kwargs)
299
+ for _ in range(depth)
300
+ ])
301
+
302
+ def forward(self,
303
+ hidden_states: torch.Tensor,
304
+ grid_hw: tuple[int, int]) -> torch.Tensor:
305
+ for block in self.resblocks:
306
+ hidden_states = block(hidden_states, grid_hw=grid_hw)
307
+ return hidden_states
308
+
309
+
310
+ class StepRoboticsVisionEncoder(nn.Module):
311
+ """
312
+ Vision encoder built from StepRoboticsVisionEncoderConfig.
313
+
314
+ The encoder performs patch embedding followed by a stack of transformer
315
+ blocks. Only the config fields defined in StepRoboticsVisionEncoderConfig (and
316
+ StepRoboticVLConfig.vision_config) are expected.
317
+ """
318
+
319
+ def __init__(self, config: StepRoboticsVisionEncoderConfig):
320
+ super().__init__()
321
+ self.config = config
322
+
323
+ # Align commonly used attributes so downstream code (e.g. StepRoboticVL)
324
+ # can access them without extra renaming.
325
+ self.hidden_size = config.width
326
+ self.num_heads = config.heads
327
+ self.num_hidden_layers = config.layers
328
+ self.patch_size = config.patch_size
329
+ self.image_size = config.image_size
330
+ self.use_cls_token = getattr(config, "use_cls_token", False)
331
+ self.use_rope2d = getattr(config, "use_rope2d", True)
332
+ self.use_abs_posemb = getattr(config, "use_abs_posemb", True)
333
+ self.layer_norm_eps = config.layer_norm_eps
334
+ self.mlp_ratio = getattr(config, "mlp_ratio", 8960 / 1536)
335
+ self.ls_init_value = getattr(config, "ls_init_value", None)
336
+ self.hidden_act = config.hidden_act
337
+ self.use_ln_pre = getattr(config, "use_ln_pre", False)
338
+ self.use_ln_post = getattr(config, "use_ln_post", True)
339
+
340
+ # Patch embedding.
341
+ self.conv1 = nn.Conv2d(in_channels=config.num_channels,
342
+ out_channels=self.hidden_size,
343
+ kernel_size=self.patch_size,
344
+ stride=self.patch_size,
345
+ bias=False)
346
+
347
+ self.ln_pre = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_pre else nn.Identity()
348
+ self.ln_post = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_post else nn.Identity()
349
+
350
+ grid_size = self.image_size // self.patch_size
351
+ self.base_grid = (grid_size, grid_size)
352
+
353
+ if self.use_cls_token:
354
+ self.class_embedding = nn.Parameter(
355
+ torch.randn(self.hidden_size) * (self.hidden_size**-0.5))
356
+ else:
357
+ self.class_embedding = None
358
+
359
+ if self.use_abs_posemb:
360
+ self.posemb_grid_size = self.image_size // self.patch_size
361
+ self.positional_embedding = nn.Parameter(
362
+ (self.hidden_size**-0.5) * torch.randn(
363
+ int(self.use_cls_token) + self.posemb_grid_size**2,
364
+ self.hidden_size,
365
+ ))
366
+
367
+ self.transformer = EncoderVisionTransformer(
368
+ embed_dim=self.hidden_size,
369
+ depth=self.num_hidden_layers,
370
+ num_heads=self.num_heads,
371
+ mlp_ratio=self.mlp_ratio,
372
+ hidden_act=self.hidden_act,
373
+ layer_norm_eps=self.layer_norm_eps,
374
+ ls_init_value=self.ls_init_value,
375
+ max_grid_height=self.base_grid[0],
376
+ max_grid_width=self.base_grid[1],
377
+ use_cls_token=self.use_cls_token,
378
+ use_rope2d=self.use_rope2d,
379
+ rope_kwargs={
380
+ "rope_theta": getattr(config, "rope_theta", 10000),
381
+ "rope_max_freq": getattr(config, "rope_max_freq", 10),
382
+ "rope_num_freqs": getattr(config, "rope_num_freqs", 1),
383
+ "rope_theta_rescale_factor":
384
+ getattr(config, "rope_theta_rescale_factor", 1.0),
385
+ "rope_freqs_for": getattr(config, "rope_freqs_for", "lang"),
386
+ },
387
+ )
388
+ self.vit_downsampler1 = nn.Conv2d(self.hidden_size,
389
+ self.hidden_size * 2,
390
+ kernel_size=3,
391
+ stride=2,
392
+ padding=1)
393
+ self.vit_downsampler2 = nn.Conv2d(self.hidden_size * 2,
394
+ self.hidden_size * 4,
395
+ kernel_size=3,
396
+ stride=2,
397
+ padding=1)
398
+
399
+
400
+ def sample_abs_posemb(self, grid_h: int, grid_w: int):
401
+ if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
402
+ return self.positional_embedding[None, ...]
403
+
404
+ pos_embed = self.positional_embedding
405
+ if self.use_cls_token:
406
+ cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
407
+
408
+ pos_embed = (pos_embed.reshape(1, self.posemb_grid_size,
409
+ self.posemb_grid_size,
410
+ -1).permute(0, 3, 1, 2).contiguous())
411
+ pos_embed = F.interpolate(pos_embed,
412
+ size=(grid_h, grid_w),
413
+ mode="bilinear",
414
+ align_corners=False)
415
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.hidden_size)
416
+
417
+ if self.use_cls_token:
418
+ pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
419
+
420
+ return pos_embed[None, ...]
421
+
422
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
423
+ """
424
+ Args:
425
+ pixel_values: Image tensor of shape (B, C, H, W).
426
+ layer_idx: Negative indices stop after a given block (e.g., -1 uses all blocks).
427
+ strip_cls_token: If True and cls token is used, remove it from output.
428
+ """
429
+ bsz, _, height, width = pixel_values.shape
430
+ grid_h, grid_w = height // self.patch_size, width // self.patch_size
431
+
432
+ hidden_state = self.conv1(pixel_values) # (B, D, Gh, Gw)
433
+ hidden_state = hidden_state.flatten(2).transpose(1, 2) # (B, Gh*Gw, D)
434
+
435
+ if self.use_cls_token:
436
+ cls_token = self.class_embedding.view(1, 1,
437
+ -1).expand(bsz, -1, -1)
438
+ hidden_state = torch.cat([cls_token, hidden_state], dim=1)
439
+
440
+ if self.use_abs_posemb:
441
+ pos_emb = self.sample_abs_posemb(grid_h, grid_w)
442
+ hidden_state = hidden_state + pos_emb
443
+ hidden_state = self.ln_pre(hidden_state)
444
+ hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w))
445
+
446
+ if self.use_ln_post:
447
+ hidden_state = self.ln_post(hidden_state)
448
+
449
+ if self.use_cls_token:
450
+ hidden_state = hidden_state[:, 1:, :]
451
+
452
+ return hidden_state