danielhanchen committed on
Commit
6c0b16e
·
verified ·
1 Parent(s): 26a7625

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,421 @@
1
+ ---
2
+ base_model:
3
+ - stepfun-ai/Step-3.7-Flash
4
+ license: apache-2.0
5
+ library_name: transformers
6
+ pipeline_tag: image-text-to-text
7
+ language:
8
+ - en
9
+ tags:
10
+ - vision-language
11
+ - unsloth
12
+ - multimodal
13
+ - moe
14
+ ---
15
+ <div>
16
+ <p style="margin-top: 0;margin-bottom: 0;">
17
+ <em><a href="https://docs.unsloth.ai/basics/unsloth-dynamic-v2.0-gguf">Unsloth Dynamic 2.0</a> achieves superior accuracy & outperforms other leading quants.</em>
18
+ </p>
19
+ <div style="display: flex; gap: 5px; align-items: center; ">
20
+ <a href="https://github.com/unslothai/unsloth/">
21
+ <img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="133">
22
+ </a>
23
+ <a href="https://discord.gg/unsloth">
24
+ <img src="https://github.com/unslothai/unsloth/raw/main/images/Discord%20button.png" width="173">
25
+ </a>
26
+ <a href="https://docs.unsloth.ai/">
27
+ <img src="https://raw.githubusercontent.com/unslothai/unsloth/refs/heads/main/images/documentation%20green%20button.png" width="143">
28
+ </a>
29
+ </div>
30
+ </div>
31
+
32
+
33
+ **[ModelPage]**: https://static.stepfun.com/blog/step-3.7-flash/
34
+
35
+ ## 1. Introduction
36
+
37
+ Step 3.7 Flash is a 198B-parameter sparse Mixture-of-Experts (MoE) vision-language model that combines a 196B-parameter language backbone with a 1.8B-parameter vision encoder for native image understanding. Engineered for high-frequency production workloads, it activates approximately 11B parameters per token and delivers a throughput of up to 400 tokens per second. Step 3.7 Flash supports a 256k context window and offers three selectable reasoning levels (low, medium, and high) so developers can easily balance speed, cost, and cognitive depth.
38
+
39
+ We built Step 3.7 Flash for developers who need to scale agentic workflows that combine perception, search, and reasoning. It is designed to handle intensive tasks such as parsing massive financial reports in one pass, running multi-step search loops with cross-source verification, or operating concurrent coding agents in high-throughput pipelines.
40
+
41
+ ## 2. Capabilities & Performance
42
+
43
+ ### Multimodal Perception and Verification
44
+
45
+ The model delivers top-tier visual intelligence, ranking first on SimpleVQA (Search) with a score of 79.2 and reaching frontier parity on V* (Python) at 95.3. These metrics reflect strong visual grounding and retrieval-augmented reasoning beyond basic image description. The model accurately processes dense visual interfaces, such as UI wireframes, application GUIs, and data charts, and maps them into structured code. When it encounters an incomplete visual asset, it can independently identify missing data and execute lookups to verify context before returning a factually verified conclusion.
46
+
47
+ ### Workflow Integrity and Tool Orchestration
48
+
49
+ Execution reliability is critical for autonomous agents. Step 3.7 Flash leads the ClawEval-1.1 benchmark with a score of 67.1, which significantly outperforms the next closest competitor at 59.8. This performance demonstrates high resistance to adversarial traps and strict adherence to system policies during multi-turn orchestration. Backed by scores of 49.5 on Toolathlon and 48.1 on HLE w. Tool, this profile ensures high trajectory integrity. Step 3.7 Flash reliably interacts with external APIs and executes long-horizon workflows without drifting from instructions or violating system constraints.
50
+
51
+ ### Code Engineering and Professional Baselines
52
+
53
+ Step 3.7 Flash is built for live engineering tasks and secured a definitive second-place finish on SWE-Bench PRO with a score of 56.3. It can independently trace multi-file repositories, isolate bugs from raw issue reports, and generate functional patches that pass automated unit tests. While evaluations like Terminal-Bench 2.1 (59.5) and GDPVal-AA (45.8) show clear areas for future optimization compared to the absolute peak of the cohort, they establish a dependable baseline for system interactions and structured professional deliverables.
54
+
55
+ ![Step 3.7 Flash benchmark results across General Agent, Agentic Coding, and Multimodal evaluations](assets/benchmarks.png)
56
+
57
+ ## 3. Pricing
58
+
59
+ | Token Type | Price |
60
+ |---|---|
61
+ | Input (cache miss) | $0.20 / M tokens |
62
+ | Input (cache hit) | $0.04 / M tokens |
63
+ | Output | $1.15 / M tokens |
64
+
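As a rough illustration of how the three prices in the table combine, the sketch below estimates the cost of a single request. The function name `estimate_cost_usd` and the example token counts are hypothetical; only the per-million-token prices come from the table, and actual billing may differ.

```python
# Prices in USD per million tokens, copied from the pricing table above.
PRICE_PER_M_TOKENS = {
    "input_cache_miss": 0.20,
    "input_cache_hit": 0.04,
    "output": 1.15,
}


def estimate_cost_usd(input_miss_tokens, input_hit_tokens, output_tokens):
    """Estimate the cost of one request in USD (hypothetical helper, not an official API)."""
    total = (
        input_miss_tokens * PRICE_PER_M_TOKENS["input_cache_miss"]
        + input_hit_tokens * PRICE_PER_M_TOKENS["input_cache_hit"]
        + output_tokens * PRICE_PER_M_TOKENS["output"]
    )
    # Prices are quoted per million tokens, so scale back down.
    return total / 1_000_000


# Example: 200k uncached input tokens, 800k cached input tokens, 50k output tokens.
example_cost = estimate_cost_usd(200_000, 800_000, 50_000)
print(f"Estimated cost: ${example_cost:.4f}")
```

In this example most of the input is served from cache, so the output tokens end up as the largest cost component even though they are the smallest token count.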
65
+ ## 4. Availability, Deployment, and Ecosystem
66
+ - Availability: Step 3.7 Flash is available on the StepFun Open Platform — [platform.stepfun.ai](https://platform.stepfun.ai) (Global) and [platform.stepfun.com](https://platform.stepfun.com) (China), OpenRouter, and NVIDIA NIM. StepFun is also partnering with DeepInfra, Fireworks AI, and Modal to expand availability soon.
67
+ - Deployment: Step 3.7 Flash supports flexible deployment across cloud, data center, and local environments. For large-scale production and enterprise use cases, Step 3.7 Flash can be deployed on modern data center infrastructure. For local and workstation scenarios, it can also run on high-memory devices such as NVIDIA DGX Station, AMD Ryzen AI Max+ 395-based systems, and Mac Studio / MacBook Pro devices with at least 128GB unified memory.
68
+ - Ecosystem: Step 3.7 Flash is supported across popular open-source infrastructure for both inference and model development. For inference and serving, developers can use vLLM, SGLang, Hugging Face Transformers, and llama.cpp. For model development and customization workflows, StepFun model support has landed in the NVIDIA NeMo ecosystem, including AutoModel, Megatron Core, and Megatron Bridge. Step 3.7 Flash is also available as an NVIDIA NIM inference microservice for on-prem, cloud, or hybrid deployment.
69
+
70
+ ## 5. Examples
71
+
72
+ You can get started with Step 3.7 Flash in minutes using StepFun's API or via other inference providers.
73
+
74
+ > Pick the right `base_url` for your region. StepFun operates two regional platforms with separate API hosts. The `base_url` you pass to the OpenAI client must match the platform where your API key was issued, otherwise requests will be rejected as unauthorized.
75
+ >
76
+ > - **Global**: [platform.stepfun.ai](https://platform.stepfun.ai) — `base_url=https://api.stepfun.ai/v1`
77
+ > - **China**: [platform.stepfun.com](https://platform.stepfun.com) — `base_url=https://api.stepfun.com/v1`
78
+ >
79
+ > To avoid hard-coding the wrong region, the examples below read both the API key and base URL from environment variables. Export them once before running:
80
+ >
81
+ > ```bash
82
+ > export STEP_API_KEY="sk-..."
83
+ > export STEP_BASE_URL="https://api.stepfun.ai/v1" # use https://api.stepfun.com/v1 for the China platform
84
+ > ```
85
+
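The region rule above can also be enforced in code. The following is a minimal sketch, assuming a hypothetical helper `resolve_base_url` and a hypothetical `STEP_REGION` environment variable (neither is part of StepFun's SDK); the two URLs are the ones listed above.

```python
import os

# API hosts for the two regional platforms, as listed above.
STEP_BASE_URLS = {
    "global": "https://api.stepfun.ai/v1",
    "china": "https://api.stepfun.com/v1",
}


def resolve_base_url(region=None, env=None):
    """Pick the API base URL (hypothetical helper).

    An explicit STEP_BASE_URL always wins. Otherwise the region argument, or the
    assumed STEP_REGION variable, selects one of the two known hosts.
    """
    env = os.environ if env is None else env
    explicit = env.get("STEP_BASE_URL")
    if explicit:
        return explicit
    key = (region or env.get("STEP_REGION", "global")).lower()
    if key not in STEP_BASE_URLS:
        raise ValueError(f"Unknown region {key!r}; expected one of {sorted(STEP_BASE_URLS)}")
    return STEP_BASE_URLS[key]


print(resolve_base_url("china", env={}))
```

Failing fast on an unknown region is a deliberate choice here: a mismatched host would otherwise only surface later as an unauthorized response.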
86
+ ### 5.1 Chat Example
87
+
88
+ ```python
89
+ import os
90
+ from openai import OpenAI
91
+
92
+ client = OpenAI(
93
+ api_key=os.environ["STEP_API_KEY"],
94
+ base_url=os.environ["STEP_BASE_URL"],
95
+ )
96
+
97
+ completion = client.chat.completions.create(
98
+ model="step-3.7-flash",
99
+ messages=[
100
+ {
101
+ "role": "system",
102
+ "content": "You are an AI assistant provided by StepFun. You are good at Chinese, English, and many other languages, and you can see, think, and act to help users get things done.",
103
+ },
104
+ {
105
+ "role": "user",
106
+ "content": "Introduce StepFun's artificial intelligence capabilities."
107
+ },
108
+ ],
109
+ )
110
+
111
+ print(completion.choices[0].message.content)  # print only the assistant's reply
112
+ ```
113
+
114
+ ### 5.2 Text and Image Input Example
115
+
116
+ ```python
117
+ import os
118
+ from openai import OpenAI
119
+
120
+ client = OpenAI(
121
+ api_key=os.environ["STEP_API_KEY"],
122
+ base_url=os.environ["STEP_BASE_URL"],
123
+ )
124
+
125
+ completion = client.chat.completions.create(
126
+ model="step-3.7-flash",
127
+ messages=[
128
+ {
129
+ "role": "user",
130
+ "content": [
131
+ {"type": "text", "text": "What is in this picture?"},
132
+ {
133
+ "type": "image_url",
134
+ "image_url": {"url": "https://example.com/photo.jpg"},
135
+ },
136
+ ],
137
+ },
138
+ ],
139
+ )
140
+
141
+ print(completion.choices[0].message.content)  # print only the assistant's reply
142
+ ```
143
+
144
+ ## 6. Local Deployment
145
+
146
+ Step 3.7 Flash is optimized for local inference and supports industry-standard backends including vLLM, SGLang, Hugging Face Transformers and llama.cpp.
147
+
148
+ ### 6.1 vLLM
149
+
150
+ We recommend using StepFun's prebuilt vLLM Docker image with Step 3.7 support.
151
+
152
+ 1. Install vLLM.
153
+
154
+ ```bash
155
+ # via Docker
156
+ docker pull vllm/vllm-openai:stepfun37
157
+ ```
158
+
159
+ 2. Launch the server.
160
+
161
+ - For FP8 model
162
+ ```bash
163
+ vllm serve <MODEL_PATH_OR_HF_ID> \
164
+ --served-model-name step3p7-flash \
165
+ --tensor-parallel-size 8 \
166
+ --enable-expert-parallel \
167
+ --disable-cascade-attn \
168
+ --reasoning-parser step3p5 \
169
+ --enable-auto-tool-choice \
170
+ --tool-call-parser step3p5 \
171
+ --speculative_config '{"method": "mtp", "num_speculative_tokens": 3}' \
172
+ --trust-remote-code
173
+ ```
174
+ - For BF16 model
175
+ ```bash
176
+ vllm serve <MODEL_PATH_OR_HF_ID> \
177
+ --served-model-name step3p7-flash-bf16 \
178
+ --tensor-parallel-size 8 \
179
+ --enable-expert-parallel \
180
+ --disable-cascade-attn \
181
+ --reasoning-parser step3p5 \
182
+ --enable-auto-tool-choice \
183
+ --tool-call-parser step3p5 \
184
+ --speculative_config '{"method": "mtp", "num_speculative_tokens": 3}' \
185
+ --trust-remote-code
186
+ ```
187
+
188
+ - For NVFP4 model
189
+ Unlike the FP8 and BF16 variants, the NVFP4 quantized checkpoint must be launched with `--quantization modelopt` and an FP8 KV cache (`--kv-cache-dtype fp8`).
190
+ ```bash
191
+ python3 -m vllm.entrypoints.openai.api_server \
192
+ --host 0.0.0.0 \
193
+ --port ${PORT} \
194
+ --model stepfun-ai/Step-3.7-Flash-NVFP4 \
195
+ --served-model-name step3p7 \
196
+ --tensor-parallel-size 4 \
197
+ --gpu-memory-utilization 0.9 \
198
+ --enable-expert-parallel \
199
+ --trust-remote-code \
200
+ --quantization modelopt \
201
+ --kv-cache-dtype fp8 \
202
+ --max-model-len 8192 \
203
+ --reasoning-parser step3p5 \
204
+ --enable-auto-tool-choice \
205
+ --tool-call-parser step3p5 \
206
+ --async-scheduling
207
+ ```
208
+
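Once one of the servers above is running, it can be queried through the OpenAI-compatible chat completions endpoint. The sketch below only builds the HTTP request, using the standard library. It assumes the server listens on `http://localhost:8000/v1` (vLLM's default port when `--port` is not set) and uses the `step3p7-flash` served model name from the FP8 command; adjust both to match your launch flags. The helper name `build_chat_request` is hypothetical.

```python
import json
import urllib.request


def build_chat_request(prompt, model="step3p7-flash", base_url="http://localhost:8000/v1"):
    """Build a POST request for an OpenAI-compatible /chat/completions endpoint."""
    payload = {
        "model": model,  # must match --served-model-name
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 128,
    }
    return urllib.request.Request(
        url=f"{base_url}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_chat_request("Say hello in one sentence.")
print(request.full_url)

# To actually send it (requires a running server):
# with urllib.request.urlopen(request) as response:
#     print(json.load(response)["choices"][0]["message"]["content"])
```

For the BF16 and NVFP4 commands, pass `model="step3p7-flash-bf16"` or `model="step3p7"` respectively, since those are the names set by `--served-model-name`.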
209
+ ### 6.2 SGLang
210
+
211
+ 1. Install SGLang.
212
+
213
+ ```bash
214
+ # via Docker
215
+ docker pull lmsysorg/sglang:dev-step-3.7-flash
216
+
217
+ # or from source (pip)
218
+ pip install "sglang[all] @ git+https://github.com/sgl-project/sglang.git"
219
+ ```
220
+
221
+ 2. Launch the server.
222
+
223
+ > **Note:** For Blackwell GPUs, `--mm-attention-backend fa4` may be used.
224
+
225
+ - For BF16 model
226
+
227
+ ```bash
228
+ sglang serve --model-path stepfun-ai/Step-3.7-Flash \
229
+ --tp 8 \
230
+ --reasoning-parser step3p5 \
231
+ --tool-call-parser step3p5 \
232
+ --enable-multimodal \
233
+ --speculative-algorithm EAGLE \
234
+ --speculative-num-steps 3 \
235
+ --speculative-eagle-topk 1 \
236
+ --speculative-num-draft-tokens 4 \
237
+ --enable-multi-layer-eagle \
238
+ --trust-remote-code \
239
+ --host 0.0.0.0 \
240
+ --port 8000
241
+ ```
242
+
243
+ - For FP8 model
244
+
245
+ ```bash
246
+ sglang serve --model-path stepfun-ai/Step-3.7-Flash-FP8 \
247
+ --tp 8 \
248
+ --ep 4 \
249
+ --reasoning-parser step3p5 \
250
+ --tool-call-parser step3p5 \
251
+ --enable-multimodal \
252
+ --speculative-algorithm EAGLE \
253
+ --speculative-num-steps 3 \
254
+ --speculative-eagle-topk 1 \
255
+ --speculative-num-draft-tokens 4 \
256
+ --enable-multi-layer-eagle \
257
+ --trust-remote-code \
258
+ --host 0.0.0.0 \
259
+ --port 8000
260
+ ```
261
+
262
+ - For NVFP4 model
263
+
264
+ ```bash
265
+ sglang serve --model-path stepfun-ai/Step-3.7-Flash-NVFP4 \
266
+ --tp 4 --ep 4 \
267
+ --moe-runner-backend flashinfer_trtllm \
268
+ --kv-cache-dtype fp8_e4m3 \
269
+ --quantization modelopt_fp4 \
270
+ --trust-remote-code \
271
+ --reasoning-parser step3p5 \
272
+ --tool-call-parser step3p5 \
273
+ --attention-backend trtllm_mha
274
+ ```
275
+
276
+ ### 6.3 Transformers (Debug / Verification)
277
+
278
+ Use this snippet for quick functional verification. For high-throughput serving, use vLLM or SGLang.
279
+
280
+ > **Note:** Deployment of this model requires `transformers` 5.0 or later.
281
+
282
+ ```python
283
+ from transformers import AutoProcessor, AutoModelForCausalLM
284
+
285
+ MODEL_PATH = "<MODEL_PATH_OR_HF_ID>"
286
+
287
+ # 1. Setup
288
+ processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
289
+ model = AutoModelForCausalLM.from_pretrained(
290
+ MODEL_PATH,
291
+ device_map="auto",
292
+ dtype="auto",
293
+ trust_remote_code=True
294
+ )
295
+
296
+ # 2. Prepare Input
297
+ messages = [
298
+ {
299
+ "role": "user",
300
+ "content": [
301
+ {"type": "image", "url": "https://example.com/photo.jpg"},
302
+ {"type": "text", "text": "What is in this picture?"}
303
+ ]
304
+ },
305
+ ]
306
+ inputs = processor.apply_chat_template(
307
+ messages,
308
+ tokenize=True,
309
+ add_generation_prompt=True,
310
+ return_dict=True,
311
+ return_tensors="pt",
312
+ ).to(model.device)
313
+
314
+ # 3. Generate
315
+ generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
316
+ output_text = processor.decode(generated_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
317
+
318
+ print(output_text)
319
+ ```
320
+
321
+ ### 6.4 llama.cpp
322
+
323
+ **System Requirements**
324
+
325
+ GGUF Model Weights:
326
+
327
+ | Component | Quantization | File Size |
328
+ |---|---|---|
329
+ | Language Model | Q4_K_S | 111.5 GB |
330
+ | Language Model | IQ4_XS | 104.99 GB |
331
+ | Language Model | Q3_K_L | 102.5 GB |
332
+ | Multimodal Projector | FP16 | 3.97 GB |
333
+
334
+ - **Runtime Overhead:** ~7 GB
335
+ - **Minimum unified memory / VRAM:** 120 GB (e.g., Mac Studio, NVIDIA DGX Station, AMD Ryzen AI Max+ 395)
336
+ - **Recommended:** 128 GB unified memory
337
+
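To relate the table to the memory figures, the sketch below adds up the weights, the multimodal projector, and the stated runtime overhead for each quantization. It assumes the three components are simply additive and ignores KV cache growth at long context, so treat the result as a rough lower bound rather than a measured footprint.

```python
# File sizes in GB, copied from the GGUF table above.
GGUF_SIZES_GB = {"Q4_K_S": 111.5, "IQ4_XS": 104.99, "Q3_K_L": 102.5}
PROJECTOR_GB = 3.97        # multimodal projector (FP16)
RUNTIME_OVERHEAD_GB = 7.0  # approximate runtime overhead quoted above


def estimated_footprint_gb(quant):
    """Sum weights + projector + overhead (assumed additive; excludes long-context KV cache)."""
    return GGUF_SIZES_GB[quant] + PROJECTOR_GB + RUNTIME_OVERHEAD_GB


for quant in GGUF_SIZES_GB:
    print(f"{quant}: ~{estimated_footprint_gb(quant):.2f} GB")
```

Under this simple sum, Q4_K_S comes out slightly above 120 GB while IQ4_XS and Q3_K_L stay below it, which is one reason to treat 128 GB as the practical target for the largest quantization.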
338
+ **Steps**
339
+
340
+ 1. Clone StepFun's llama.cpp fork and check out the `step3.7` branch:
341
+
342
+ ```bash
343
+ git clone https://github.com/stepfun-ai/llama.cpp.git
344
+ cd llama.cpp
345
+ git checkout -b step3.7 origin/step3.7
346
+ ```
347
+
348
+ 2. Build llama.cpp on Mac:
349
+
350
+ ```bash
351
+ cmake -B build-macos -S . \
352
+ -DCMAKE_BUILD_TYPE=Release \
353
+ -DBUILD_SHARED_LIBS=ON \
354
+ -DLLAMA_BUILD_SERVER=ON \
355
+ -DLLAMA_BUILD_TESTS=ON \
356
+ -DGGML_METAL=ON \
357
+ -DGGML_METAL_EMBED_LIBRARY=ON \
358
+ -DGGML_BLAS=ON \
359
+ -DGGML_BLAS_VENDOR=Apple \
360
+ -DGGML_ACCELERATE=ON \
361
+ -DGGML_NATIVE=ON
362
+ cmake --build build-macos -j8
363
+ ```
364
+
365
+ 3. Build llama.cpp on DGX-Spark:
366
+
367
+ ```bash
368
+ cmake -S . -B build-cuda \
369
+ -DCMAKE_BUILD_TYPE=Release \
370
+ -DGGML_CUDA=ON \
371
+ -DGGML_CUDA_GRAPHS=ON \
372
+ -DGGML_CUDA_FORCE_MMQ=ON \
373
+ -DLLAMA_OPENSSL=OFF \
374
+ -DLLAMA_BUILD_COMMON=ON \
375
+ -DLLAMA_BUILD_TOOLS=ON \
376
+ -DLLAMA_BUILD_SERVER=ON \
377
+ -DLLAMA_BUILD_EXAMPLES=OFF \
378
+ -DLLAMA_BUILD_TESTS=OFF
379
+ cmake --build build-cuda -j8
380
+ ```
381
+
382
+ 4. Build llama.cpp on AMD Windows:
383
+
384
+ ```bash
385
+ cmake -S . -B build-vulkan \
386
+ -DCMAKE_BUILD_TYPE=Release \
387
+ -DGGML_VULKAN=ON \
388
+ -DGGML_NATIVE=ON \
389
+ -DLLAMA_BUILD_SERVER=ON \
390
+ -DLLAMA_BUILD_UI=OFF \
391
+ -DLLAMA_BUILD_TOOLS=ON
392
+ cmake --build build-vulkan -j8
393
+ ```
394
+
395
+ 5. Run with `llama-cli`:
396
+
397
+ ```bash
398
+ ./llama-cli -m Step3.7_Q4_K_S.gguf -b 2048 -ub 2048 -fa on --temp 1.0 -p "What's your name?"
399
+ ```
400
+
401
+ 6. Test performance with `llama-batched-bench`:
402
+
403
+ ```bash
404
+ ./llama-batched-bench -m step3.7_Q4_K_S.gguf -c 32768 -b 2048 -ub 2048 -npp 0,2048,8192,16384,32768 -ntg 128 -npl 1
405
+ ```
406
+
407
+ ## 7. Using Step 3.7 Flash on Agent Platforms
408
+
409
+ You can use Step 3.7 Flash on agent platforms such as Hermes Agent, OpenClaw, Kilo Code, and more.
410
+
411
+ ## 8. Getting in Touch
412
+
413
+ As we work to shape the future of AGI by expanding broad model capabilities, we want to ensure we are solving the right problems. We invite you to be part of this continuous feedback loop — your insights directly influence our priorities.
414
+
415
+ - **Join the Conversation:** Our [Discord](https://discord.gg/RcMJhNVAQc) community is the primary hub for brainstorming future architectures, proposing capabilities, and getting early access updates 🚀
416
+ - **Report Friction:** Encountering limitations? You can open an issue or start a discussion on GitHub / Hugging Face, or flag it directly in our Discord support channels.
417
+
418
+ ## 📄 License
419
+
420
+ This project is open-sourced under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0).
421
+
assets/benchmarks.png ADDED

Git LFS Details

  • SHA256: 3d26171162c0421a57c6c2c22074b9b276b626c5d90fe3a62e9fceb8ad988ae7
  • Pointer size: 131 Bytes
  • Size of remote file: 322 kB
chat_template.jinja ADDED
@@ -0,0 +1,89 @@
1
+ {% macro render_message_content(message) %}{% if message.content is none %}{{- '' }}{% elif message.content is string %}{{- message.content }}{% elif message.content is mapping %}{{- message.content['value'] if 'value' in message.content else message.content['text'] }}{% elif message.content is iterable %}{% set ns = namespace(needs_text_separator=false) %}{% for item in message.content %}{% if item.type == 'text' %}{% if ns.needs_text_separator %}{{- ' ' }}{% endif %}{{- item['value'] if 'value' in item else item['text'] }}{% set ns.needs_text_separator = true %}{% elif item.type == 'image' %}<im_patch>{% set ns.needs_text_separator = false %}{% endif %}{% endfor %}{% endif %}{% endmacro %}
2
+ {{bos_token}}{%- if tools %}
3
+ {{- '<|im_start|>system\n' }}
4
+ {%- if reasoning_effort is defined %}
5
+ {{- "Reasoning: " + reasoning_effort + '\n\n' }}
6
+ {%- endif %}
7
+ {%- if messages[0].role == 'system' %}
8
+ {{- render_message_content(messages[0]) + '\n\n' }}
9
+ {%- endif %}
10
+ {{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n<tools>" }}
11
+ {%- for tool in tools %}
12
+ {{- "\n" }}
13
+ {{- tool | tojson(ensure_ascii=False) }}
14
+ {%- endfor %}
15
+ {{- "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...>\n...\n</function> block must be nested within <tool_call>\n...\n</tool_call> XML tags\n- Required parameters MUST be specified\n</IMPORTANT><|im_end|>\n" }}
16
+ {%- else %}
17
+ {%- if messages[0].role == 'system' %}
18
+ {{- '<|im_start|>system\n' }}
19
+ {%- if reasoning_effort is defined %}
20
+ {{- "Reasoning: " + reasoning_effort + '\n\n' }}
21
+ {%- endif %}
22
+ {{- render_message_content(messages[0]) + '<|im_end|>\n' }}
23
+ {%- elif reasoning_effort is defined %}
24
+ {{- '<|im_start|>system\n' + "Reasoning: " + reasoning_effort + '\n\n' + '<|im_end|>\n' }}
25
+ {%- endif %}
26
+ {%- endif %}
27
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
28
+ {%- for message in messages[::-1] %}
29
+ {%- set index = (messages|length - 1) - loop.index0 %}
30
+ {%- if ns.multi_step_tool and message.role == "user" and render_message_content(message) is string and not(render_message_content(message).startswith('<tool_response>') and render_message_content(message).endswith('</tool_response>')) %}
31
+ {%- set ns.multi_step_tool = false %}
32
+ {%- set ns.last_query_index = index %}
33
+ {%- endif %}
34
+ {%- endfor %}
35
+ {%- for message in messages %}
36
+ {%- set content = render_message_content(message) %}
37
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
38
+ {%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
39
+ {{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
40
+ {%- elif message.role == "assistant" %}
41
+ {%- if message.reasoning_content is string %}
42
+ {%- set reasoning_content = message.reasoning_content %}
43
+ {%- else %}
44
+ {%- if '</think>' in content %}
45
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
46
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
47
+ {%- else %}
48
+ {%- set reasoning_content = '' %}
49
+ {%- endif %}
50
+ {%- endif %}
51
+ {%- if loop.index0 > ns.last_query_index %}
52
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
53
+ {%- else %}
54
+ {{- '<|im_start|>' + message.role + '\n' + content }}
55
+ {%- endif %}
56
+ {%- if message.tool_calls %}
57
+ {%- for tool_call in message.tool_calls %}
58
+ {%- if tool_call.function is defined %}
59
+ {%- set tool_call = tool_call.function %}
60
+ {%- endif %}
61
+ {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
62
+ {%- if tool_call.arguments is defined %}
63
+ {%- set arguments = tool_call.arguments | fromjson if tool_call.arguments is string else tool_call.arguments %}
64
+ {%- for args_name, args_value in arguments|items %}
65
+ {{- '<parameter=' + args_name + '>\n' }}
66
+ {%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
67
+ {{- args_value }}
68
+ {{- '\n</parameter>\n' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '</function>\n</tool_call>' }}
72
+ {%- endfor %}
73
+ {%- endif %}
74
+ {{- '<|im_end|>\n' }}
75
+ {%- elif message.role == "tool" %}
76
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
77
+ {{- '<|im_start|>tool_response\n' }}
78
+ {%- endif %}
79
+ {{- '<tool_response>' }}
80
+ {{- content }}
81
+ {{- '</tool_response>' }}
82
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
83
+ {{- '<|im_end|>\n' }}
84
+ {%- endif %}
85
+ {%- endif %}
86
+ {%- endfor %}
87
+ {%- if add_generation_prompt %}
88
+ {{- '<|im_start|>assistant\n<think>\n' }}
89
+ {%- endif %}
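The template above serializes assistant tool calls into a nested `<tool_call>` / `<function=...>` / `<parameter=...>` layout. The following Python sketch approximates that branch for a single call so the output format is easier to see. The helper name `render_tool_call` is hypothetical, and `json.dumps` stands in for the template's `tojson` filter, so escaping of unusual characters may differ from what the real template produces.

```python
import json


def render_tool_call(name, arguments):
    """Approximate the template's tool-call serialization for a single call."""
    # The template accepts arguments either as a JSON string or as a mapping.
    if isinstance(arguments, str):
        arguments = json.loads(arguments)
    parts = [f"<tool_call>\n<function={name}>\n"]
    for arg_name, arg_value in arguments.items():
        # Mappings and non-string sequences are emitted as JSON; everything else as plain text.
        if isinstance(arg_value, (dict, list, tuple)):
            rendered = json.dumps(arg_value, ensure_ascii=False)
        else:
            rendered = str(arg_value)
        parts.append(f"<parameter={arg_name}>\n{rendered}\n</parameter>\n")
    parts.append("</function>\n</tool_call>")
    return "".join(parts)


print(render_tool_call("get_weather", {"city": "Paris", "days": 3}))
```

Here `get_weather` and its parameters are made-up names used only for illustration. Each parameter value sits on its own lines between the opening and closing `<parameter>` tags, which is what lets a value span multiple lines, as the system prompt in the template describes.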
config.json ADDED
@@ -0,0 +1,410 @@
1
+ {
2
+ "architectures": [
3
+ "Step3p7ForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_step3p7.Step3p7Config",
7
+ "AutoModelForCausalLM": "modeling_step3p7.Step3p7ForConditionalGeneration",
8
+ "AutoProcessor": "processing_step3.Step3VLProcessor"
9
+ },
10
+ "hidden_size": 4096,
11
+ "im_end_token": "<im_end>",
12
+ "im_patch_token": "<im_patch>",
13
+ "im_start_token": "<im_start>",
14
+ "image_token_id": 128001,
15
+ "image_token_len": 169,
16
+ "max_position_embeddings": 262144,
17
+ "model_type": "step3p7",
18
+ "pad_token_id": 2,
19
+ "patch_token_len": 81,
20
+ "projector_bias": false,
21
+ "text_config": {
22
+ "architectures": [
23
+ "Step3p5ForCausalLM"
24
+ ],
25
+ "att_impl_type": "GQA",
26
+ "attention_dropout": 0.0,
27
+ "attention_other_setting": {
28
+ "attention_type": "sliding_attention",
29
+ "head_dim": 128,
30
+ "num_attention_groups": 8,
31
+ "num_attention_heads": 96,
32
+ "true_head_dim": 128
33
+ },
34
+ "bos_token_id": 0,
35
+ "torch_dtype": "bfloat16",
36
+ "eos_token_id": [
37
+ 1,
38
+ 2,
39
+ 128007
40
+ ],
41
+ "head_dim": 128,
42
+ "hidden_size": 4096,
43
+ "intermediate_size": 11264,
44
+ "layer_types": [
45
+ "full_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "sliding_attention",
49
+ "full_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "sliding_attention",
53
+ "full_attention",
54
+ "sliding_attention",
55
+ "sliding_attention",
56
+ "sliding_attention",
57
+ "full_attention",
58
+ "sliding_attention",
59
+ "sliding_attention",
60
+ "sliding_attention",
61
+ "full_attention",
62
+ "sliding_attention",
63
+ "sliding_attention",
64
+ "sliding_attention",
65
+ "full_attention",
66
+ "sliding_attention",
67
+ "sliding_attention",
68
+ "sliding_attention",
69
+ "full_attention",
70
+ "sliding_attention",
71
+ "sliding_attention",
72
+ "sliding_attention",
73
+ "full_attention",
74
+ "sliding_attention",
75
+ "sliding_attention",
76
+ "sliding_attention",
77
+ "full_attention",
78
+ "sliding_attention",
79
+ "sliding_attention",
80
+ "sliding_attention",
81
+ "full_attention",
82
+ "sliding_attention",
83
+ "sliding_attention",
84
+ "sliding_attention",
85
+ "full_attention",
86
+ "sliding_attention",
87
+ "sliding_attention",
88
+ "sliding_attention",
89
+ "full_attention",
90
+ "sliding_attention",
91
+ "sliding_attention",
92
+ "sliding_attention"
93
+ ],
94
+ "max_position_embeddings": 262144,
95
+ "max_seq_len": 262144,
96
+ "model_type": "step3p5",
97
+ "moe_every_n_layer": 1,
98
+ "moe_intermediate_size": 1280,
99
+ "moe_layer_offset": 0,
100
+ "moe_layers_enum": "3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44",
101
+ "moe_num_experts": 288,
102
+ "moe_router_activation": "sigmoid",
103
+ "moe_router_scaling_factor": 3.0,
104
+ "moe_top_k": 8,
105
+ "need_fp32_gate": true,
106
+ "norm_expert_weight": true,
107
+ "num_attention_groups": 8,
108
+ "num_attention_heads": 64,
109
+ "num_hidden_layers": 45,
110
+ "num_nextn_predict_layers": 3,
111
+ "pad_token_id": 1,
112
+ "partial_rotary_factors": [
113
+ 0.5,
114
+ 1.0,
115
+ 1.0,
116
+ 1.0,
117
+ 0.5,
118
+ 1.0,
119
+ 1.0,
120
+ 1.0,
121
+ 0.5,
122
+ 1.0,
123
+ 1.0,
124
+ 1.0,
125
+ 0.5,
126
+ 1.0,
127
+ 1.0,
128
+ 1.0,
129
+ 0.5,
130
+ 1.0,
131
+ 1.0,
132
+ 1.0,
133
+ 0.5,
134
+ 1.0,
135
+ 1.0,
136
+ 1.0,
137
+ 0.5,
138
+ 1.0,
139
+ 1.0,
140
+ 1.0,
141
+ 0.5,
142
+ 1.0,
143
+ 1.0,
144
+ 1.0,
145
+ 0.5,
146
+ 1.0,
147
+ 1.0,
148
+ 1.0,
149
+ 0.5,
150
+ 1.0,
151
+ 1.0,
152
+ 1.0,
153
+ 0.5,
154
+ 1.0,
155
+ 1.0,
156
+ 1.0,
157
+ 0.5,
158
+ 1.0,
159
+ 1.0,
160
+ 1.0
161
+ ],
162
+ "rms_norm_eps": 1e-05,
163
+ "rope_parameters": {
164
+ "factor": 2.0,
165
+ "high_freq_factor": 32.0,
166
+ "low_freq_factor": 1.0,
167
+ "original_max_position_embeddings": 131072,
168
+ "rope_theta": [
169
+ 5000000.0,
170
+ 10000.0,
171
+ 10000.0,
172
+ 10000.0,
173
+ 5000000.0,
174
+ 10000.0,
175
+ 10000.0,
176
+ 10000.0,
177
+ 5000000.0,
178
+ 10000.0,
179
+ 10000.0,
180
+ 10000.0,
181
+ 5000000.0,
182
+ 10000.0,
183
+ 10000.0,
184
+ 10000.0,
185
+ 5000000.0,
186
+ 10000.0,
187
+ 10000.0,
188
+ 10000.0,
189
+ 5000000.0,
190
+ 10000.0,
191
+ 10000.0,
192
+ 10000.0,
193
+ 5000000.0,
194
+ 10000.0,
195
+ 10000.0,
196
+ 10000.0,
197
+ 5000000.0,
198
+ 10000.0,
199
+ 10000.0,
200
+ 10000.0,
201
+ 5000000.0,
202
+ 10000.0,
203
+ 10000.0,
204
+ 10000.0,
205
+ 5000000.0,
206
+ 10000.0,
207
+ 10000.0,
208
+ 10000.0,
209
+ 5000000.0,
210
+ 10000.0,
211
+ 10000.0,
212
+ 10000.0,
213
+ 5000000.0,
214
+ 10000.0,
215
+ 10000.0,
216
+ 10000.0
217
+ ],
218
+ "rope_type": "llama3"
219
+ },
220
+ "rope_theta": [
221
+ 5000000.0,
222
+ 10000.0,
223
+ 10000.0,
224
+ 10000.0,
225
+ 5000000.0,
226
+ 10000.0,
227
+ 10000.0,
228
+ 10000.0,
229
+ 5000000.0,
230
+ 10000.0,
231
+ 10000.0,
232
+ 10000.0,
233
+ 5000000.0,
234
+ 10000.0,
235
+ 10000.0,
236
+ 10000.0,
237
+ 5000000.0,
238
+ 10000.0,
239
+ 10000.0,
240
+ 10000.0,
241
+ 5000000.0,
242
+ 10000.0,
243
+ 10000.0,
244
+ 10000.0,
245
+ 5000000.0,
246
+ 10000.0,
247
+ 10000.0,
248
+ 10000.0,
249
+ 5000000.0,
250
+ 10000.0,
251
+ 10000.0,
252
+ 10000.0,
253
+ 5000000.0,
254
+ 10000.0,
255
+ 10000.0,
256
+ 10000.0,
257
+ 5000000.0,
258
+ 10000.0,
259
+ 10000.0,
260
+ 10000.0,
261
+ 5000000.0,
262
+ 10000.0,
263
+ 10000.0,
264
+ 10000.0,
265
+ 5000000.0,
266
+ 10000.0,
267
+ 10000.0,
268
+ 10000.0
269
+ ],
270
+ "share_expert_dim": 1280,
271
+ "sink": false,
272
+ "sliding_window": 512,
273
+ "swiglu_limits": [
274
+ 0.0,
275
+ 0.0,
276
+ 0.0,
277
+ 0.0,
278
+ 0.0,
279
+ 0.0,
280
+ 0.0,
281
+ 0.0,
282
+ 0.0,
283
+ 0.0,
284
+ 0.0,
285
+ 0.0,
286
+ 0.0,
287
+ 0.0,
288
+ 0.0,
289
+ 0.0,
290
+ 0.0,
291
+ 0.0,
292
+ 0.0,
293
+ 0.0,
294
+ 0.0,
295
+ 0.0,
296
+ 0.0,
297
+ 0.0,
298
+ 0.0,
299
+ 0.0,
300
+ 0.0,
301
+ 0.0,
302
+ 0.0,
303
+ 0.0,
304
+ 0.0,
305
+ 0.0,
306
+ 0.0,
307
+ 0.0,
308
+ 0.0,
309
+ 0.0,
310
+ 0.0,
311
+ 0.0,
312
+ 0.0,
313
+ 0.0,
314
+ 0.0,
315
+ 0.0,
316
+ 0.0,
317
+ 7,
318
+ 7,
319
+ 0.0,
320
+ 0.0,
321
+ 0.0
322
+ ],
323
+ "swiglu_limits_shared": [
324
+ 0.0,
325
+ 0.0,
326
+ 0.0,
327
+ 0.0,
328
+ 0.0,
329
+ 0.0,
330
+ 0.0,
331
+ 0.0,
332
+ 0.0,
333
+ 0.0,
334
+ 0.0,
335
+ 0.0,
336
+ 0.0,
337
+ 0.0,
338
+ 0.0,
339
+ 0.0,
340
+ 0.0,
341
+ 0.0,
342
+ 0.0,
343
+ 0.0,
344
+ 0.0,
345
+ 0.0,
346
+ 0.0,
347
+ 0.0,
348
+ 0.0,
349
+ 0.0,
350
+ 0.0,
351
+ 0.0,
352
+ 0.0,
353
+ 0.0,
354
+ 0.0,
355
+ 0.0,
356
+ 0.0,
357
+ 0.0,
358
+ 0.0,
359
+ 0.0,
360
+ 0.0,
361
+ 0.0,
362
+ 0.0,
363
+ 0.0,
364
+ 0.0,
365
+ 0.0,
366
+ 0.0,
367
+ 16,
368
+ 16,
369
+ 0.0,
370
+ 0.0,
371
+ 0.0
372
+ ],
373
+ "use_head_wise_attn_gate": true,
374
+ "use_mfa": false,
375
+ "use_moe": true,
376
+ "use_moe_router_bias": true,
377
+ "use_qk_norm": false,
378
+ "use_rope_layers": [],
379
+ "vocab_size": 128896,
380
+ "yarn_only_types": [
381
+ "full_attention"
382
+ ]
383
+ },
384
+ "transformers_version": "5.10.0.dev0",
385
+ "understand_projector_stride": 2,
386
+ "unsloth_fixed": true,
387
+ "use_im_start_end": "true",
388
+ "vision_config": {
389
+ "heads": 16,
390
+ "hidden_act": "quick_gelu",
391
+ "image_size": 728,
392
+ "layer_norm_eps": 1e-05,
393
+ "layers": 47,
394
+ "ls_init_value": 0.1,
395
+ "mlp_ratio": 5.833333333333333,
396
+ "model_type": "perception_encoder",
397
+ "num_channels": 3,
398
+ "output_dim": null,
399
+ "patch_size": 14,
400
+ "pool_type": "none",
401
+ "ues_cls_token": false,
402
+ "use_abs_posemb": true,
403
+ "use_cls_token": false,
404
+ "use_ln_post": false,
405
+ "use_ln_pre": true,
406
+ "use_rope2d": true,
407
+ "width": 1536
408
+ },
409
+ "vision_select_layer": -1
410
+ }
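The `vision_config` values above imply a few derived sizes. The sketch below works them out under two assumptions that this diff does not itself confirm: the encoder cuts the image into non-overlapping square patches, and each attention head gets `width // heads` channels. The helper name `derive_vision_dims` is hypothetical and is not part of the repository.

```python
def derive_vision_dims(image_size, patch_size, width, mlp_ratio, heads):
    """Derive patch-grid and layer sizes from vision_config-style values."""
    grid = image_size // patch_size  # patches per side (assumes no overlap)
    return {
        "grid": grid,
        "num_patches": grid * grid,
        # mlp_ratio is stored as a float, so round back to an integer width.
        "mlp_hidden": round(width * mlp_ratio),
        "head_dim": width // heads,  # assumed even split across heads
    }


# Values copied from the vision_config block in config.json.
dims = derive_vision_dims(
    image_size=728,
    patch_size=14,
    width=1536,
    mlp_ratio=5.833333333333333,
    heads=16,
)
print(dims)
```

The rounded MLP width agrees with the `8960/1536` default that `configuration_step3p7.py` uses for `mlp_ratio`.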
configuration_step3p7.py ADDED
@@ -0,0 +1,207 @@
1
+ from typing import Any, Optional, Sequence, Union
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+ class StepRoboticsVisionEncoderConfig(PretrainedConfig):
6
+ model_type = "perception_encoder"
7
+
8
+ def __init__(
9
+ self,
10
+ width=1536,
11
+ layers=47,
12
+ heads=16,
13
+ num_channels=3,
14
+ image_size=728,
15
+ mlp_ratio = 8960/1536,
16
+ patch_size=14,
17
+ hidden_act="quick_gelu",
18
+ layer_norm_eps=1e-5,
19
+ ues_cls_token=False,
20
+ use_cls_token: Optional[bool] = None,
21
+ use_ln_pre=True,
22
+ use_ln_post=False,
23
+ use_abs_posemb=True,
24
+ use_rope2d=True,
25
+ ls_init_value=0.1,
26
+ **kwargs,
27
+ ):
28
+ self.width = width
29
+ self.layers = layers
30
+ self.heads = heads
31
+ self.num_channels = num_channels
32
+ self.patch_size = patch_size
33
+ self.image_size = image_size
34
+ self.mlp_ratio = mlp_ratio
35
+ self.layer_norm_eps = layer_norm_eps
36
+ self.hidden_act = hidden_act
37
+ if use_cls_token is None:
38
+ use_cls_token = ues_cls_token
39
+ self.ues_cls_token = use_cls_token
40
+ self.use_cls_token = use_cls_token
41
+ self.use_ln_pre = use_ln_pre
42
+ self.ls_init_value = ls_init_value
43
+ self.use_ln_post = use_ln_post
44
+ self.use_abs_posemb = use_abs_posemb
45
+ self.use_rope2d = use_rope2d
46
+ super().__init__(**kwargs)
47
+
48
+
49
+ class Step3p7TextConfig(PretrainedConfig):
50
+ model_type = "step3p5"
51
+ architectures = ["Step3p5ForCausalLM"]
52
+
53
+ def __init__(
54
+ self,
55
+ hidden_size: int = 4096,
56
+ intermediate_size: int = 11264,
57
+ num_attention_heads: int = 64,
58
+ num_attention_groups: int = 8,
59
+ num_hidden_layers: int = 45,
60
+ max_seq_len: int = 128000,
61
+ vocab_size: int = 128815,
62
+ rms_norm_eps: float = 1e-5,
63
+ moe_intermediate_size: int = 1280,
64
+ moe_num_experts: int = 288,
65
+ moe_top_k: int = 8,
66
+ rope_theta: float = 10000,
67
+ rope_scaling: Optional[dict[str, Any]] = None,
68
+ max_position_embeddings: int = 128000,
69
+ share_expert_dims: int = 1280,
70
+ share_expert_dim: Optional[int] = None,
71
+ head_dim: int = 128,
72
+ norm_expert_weight: bool = True,
73
+ layer_types: Optional[list[str]] = None,
74
+ sliding_window: Optional[int] = None,
75
+ pad_token_id: int = 1,
76
+ attention_dropout: float = 0.0,
77
+ use_head_wise_attn_gate: bool = False,
78
+ use_moe_router_bias: bool = False,
79
+ moe_router_activation: str = "softmax",
80
+ moe_router_scaling_factor: float = 1.0,
81
+ need_fp32_gate: bool = False,
82
+ attention_other_setting: Optional[dict[str, Any]] = None,
83
+ swiglu_limits: Optional[list[Optional[float]]] = None,
84
+ swiglu_limits_shared: Optional[list[Optional[float]]] = None,
85
+ use_rope_layers: Optional[list[bool]] = None,
86
+ yarn_only_types: Optional[list[str]] = None,
87
+ moe_layers_enum: tuple[int, ...] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
88
+ 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
89
+ 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
90
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
91
+ **kwargs,
92
+ ) -> None:
93
+ torch_dtype = kwargs.get("torch_dtype")
94
+ trim_layer_types = _normalize_per_layer_values(layer_types,
95
+ num_hidden_layers)
96
+ if isinstance(rope_scaling, dict):
97
+ rope_scaling = dict(rope_scaling)
98
+ if share_expert_dim is None:
99
+ share_expert_dim = share_expert_dims
100
+ self.hidden_size = hidden_size
101
+ self.intermediate_size = intermediate_size
102
+ self.num_attention_heads = num_attention_heads
103
+ self.num_attention_groups = num_attention_groups
104
+ self.num_hidden_layers = num_hidden_layers
105
+ self.max_seq_len = max_seq_len
106
+ self.vocab_size = vocab_size
107
+ self.rms_norm_eps = rms_norm_eps
108
+ self.moe_intermediate_size = moe_intermediate_size
109
+ self.moe_num_experts = moe_num_experts
110
+ self.moe_top_k = moe_top_k
111
+ self.rope_theta = rope_theta
112
+ self.rope_scaling = rope_scaling
113
+ self.max_position_embeddings = max_position_embeddings
114
+ self.share_expert_dim = share_expert_dim
115
+ self.head_dim = head_dim
116
+ self.norm_expert_weight = norm_expert_weight
117
+ self.moe_layers_enum = moe_layers_enum
118
+ self.layer_types = trim_layer_types
119
+ self.sliding_window = sliding_window
120
+ self.pad_token_id = pad_token_id
121
+ self.attention_dropout = attention_dropout
122
+ self.use_head_wise_attn_gate = use_head_wise_attn_gate
123
+ self.use_moe_router_bias = use_moe_router_bias
124
+ self.moe_router_activation = moe_router_activation
125
+ self.moe_router_scaling_factor = moe_router_scaling_factor
126
+ self.need_fp32_gate = need_fp32_gate
127
+ self.attention_other_setting = attention_other_setting
128
+ self.swiglu_limits = swiglu_limits
129
+ self.swiglu_limits_shared = swiglu_limits_shared
130
+ self.use_rope_layers = use_rope_layers
131
+ self.yarn_only_types = yarn_only_types
132
+ super().__init__(**kwargs)
133
+ if torch_dtype is not None:
134
+ self.torch_dtype = torch_dtype
135
+ self.layer_types = layer_types
136
+
137
+ def to_dict(self):
138
+ output = super().to_dict()
139
+ torch_dtype = getattr(self, "torch_dtype", None)
140
+ if torch_dtype is not None:
141
+ output["torch_dtype"] = torch_dtype
142
+ return output
143
+
144
+
145
+ def _normalize_per_layer_values(
146
+ values: Optional[Sequence[Any]],
147
+ num_hidden_layers: int,
148
+ ) -> Optional[list[Any]]:
149
+ if values is None:
150
+ return None
151
+ normalized = list(values)
152
+ if not normalized:
153
+ return normalized
154
+ if len(normalized) < num_hidden_layers:
155
+ normalized.extend([normalized[-1]] *
156
+ (num_hidden_layers - len(normalized)))
157
+ # Some checkpoints keep MTP/spec layer entries after the decoder layers.
158
+ # This config only builds num_hidden_layers decoder layers, and HF strict
159
+ # validation requires per-layer fields to match that decoder count.
160
+ return normalized[:num_hidden_layers]
161
+
162
+ class Step3p7Config(PretrainedConfig):
163
+ # This loader is a compatibility shim for original Step VL checkpoints
164
+ # whose top-level config model_type is `step3p7`.
165
+ model_type = "step3p7"
166
+
167
+ def __init__(
168
+ self,
169
+ vision_config: Optional[Union[dict, StepRoboticsVisionEncoderConfig]] = None,
170
+ text_config: Optional[Union[dict, Step3p7TextConfig]] = None,
171
+ understand_projector_stride: int = 2,
172
+ projector_bias: bool = False,
173
+ image_token_id: int = 151679,
174
+ **kwargs,
175
+ ) -> None:
176
+ shared_rope_scaling = kwargs.get("rope_scaling")
177
+ if isinstance(shared_rope_scaling, dict):
178
+ shared_rope_scaling = dict(shared_rope_scaling)
179
+
180
+ if vision_config is None:
181
+ vision_config = StepRoboticsVisionEncoderConfig()
182
+ elif isinstance(vision_config, dict):
183
+ vision_config = StepRoboticsVisionEncoderConfig(**vision_config)
184
+ self.vision_config = vision_config
185
+
186
+ if text_config is None:
187
+ text_config = Step3p7TextConfig(rope_scaling=shared_rope_scaling)
188
+ elif isinstance(text_config, dict):
189
+ text_config = dict(text_config)
190
+ if shared_rope_scaling is not None and "rope_scaling" not in text_config:
191
+ text_config["rope_scaling"] = shared_rope_scaling
192
+ text_config = Step3p7TextConfig(**text_config)
193
+ elif shared_rope_scaling is not None and text_config.rope_scaling is None:
194
+ text_config.rope_scaling = dict(shared_rope_scaling)
195
+ self.text_config = text_config
196
+
197
+ rope_scaling = kwargs.get("rope_scaling")
198
+ if isinstance(rope_scaling, dict):
199
+ kwargs["rope_scaling"] = dict(rope_scaling)
200
+
201
+ self.understand_projector_stride = understand_projector_stride
202
+ self.projector_bias = projector_bias
203
+ self.hidden_size = text_config.hidden_size
204
+ self.max_position_embeddings = text_config.max_position_embeddings
205
+ self.image_token_id = image_token_id
206
+ # Help Auto classes find the correct implementation when saving/loading.
207
+ super().__init__(**kwargs)
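The `_normalize_per_layer_values` helper in this file pads short per-layer lists with their last entry and truncates lists that carry extra (for example MTP/spec) entries. The following is a minimal standalone sketch of that behavior, re-implemented here so it runs without `transformers`. The function name without the leading underscore and the `"sliding_attention"` label are illustrative assumptions; only `"full_attention"` appears in the config shown in this diff.

```python
from typing import Any, Optional, Sequence


def normalize_per_layer_values(
    values: Optional[Sequence[Any]],
    num_hidden_layers: int,
) -> Optional[list]:
    """Pad or trim a per-layer list so it has exactly num_hidden_layers items."""
    if values is None:
        return None
    normalized = list(values)
    if not normalized:
        return normalized  # an empty list is passed through unchanged
    if len(normalized) < num_hidden_layers:
        # Repeat the last entry until the list is long enough.
        missing = num_hidden_layers - len(normalized)
        normalized.extend([normalized[-1]] * missing)
    # Drop any entries beyond the decoder layer count.
    return normalized[:num_hidden_layers]


# Too short: the last label is repeated.
padded = normalize_per_layer_values(["full_attention", "sliding_attention"], 4)
# Too long: trailing entries are dropped.
trimmed = normalize_per_layer_values(list(range(48)), 45)
print(padded, len(trimmed))
```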
model-00001.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a2d47133d0ffa22f50a24ad4974c559c1b31f26f5baca24fc4f4dfe198b46c6
3
+ size 924094096
model-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67c13067deed696b62763643b7d531fd2cfde4c6e81cfcaba5460551e510d0af
3
+ size 9808156008
model-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f3567584681f4d2792e4d949c9440198f792a5afd93220d3770b509728b6ef1
3
+ size 18557475928
model-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d035fb813758ed63f1d537bbf41f6cbb2c5c8eb05f187de18a448c7766a64960
3
+ size 18624846944
model-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9a2c0daa3a49fc88e53e0b6419f2e4db7e412f40760488d49ca0f834fe83725
3
+ size 18557475928
model-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fee76c5fb28547ad0d4094a0bae7755a292dd439cc23b054210a24c965b093f
3
+ size 18624846976
model-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccad5d228ec280d95419fbbcf2590f2cdfc4c932a7249a7669dc7f509dc7fe66
3
+ size 18557475968
model-00008.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d537acabde8deace533c23df8e43268f1423b41e7b6e27c79232955283f4e44
3
+ size 18624846976
model-00009.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48be665fd9bce6e2fdac06d03a1a9916794fce4231b03009e6a4cfca1055a2c9
3
+ size 18557475968
model-00010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd61c7f6d62725005a07fe778dc572b9642972054424b2a12d1494e7ca241d91
3
+ size 18624846976
model-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51c5fe0dce035dd7fc01333fe3ba0fff46e65412ad7a71c09fa8e2992b8d26a7
3
+ size 18557475968
model-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f3e890ede3949af958a72da0beb99db6834853ee22978eb7782a600d013abac
3
+ size 18624846976
model-00013.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98802ed9091498df2ef7a73b2697f5ac275a64892d984b9045a0a99f7b459c78
3
+ size 18557475968
model-00014.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:459e5814b710f888b6763385fb179d52f746f59e702dd165f0c5d5cc73417b03
3
+ size 18624846976
model-00015.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13a51f345afa384b930387d40ac79ed6614f02129d61a9714e213f726970f47c
3
+ size 18557475968
model-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3475a9dcaff31af71b6183371f8e355bdedea5f4dbb1ade6e84dcfe28ddc9517
3
+ size 18624846976
model-00017.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92917af53ef59cd99d43d49de2ffcbec3d21db7ebc59107a66aa2438da2eca14
3
+ size 18557475968
model-00018.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aba73fb3d39556bba83fe864f7a7b60e8b2085204b074101500531e69525ee4f
3
+ size 18624846976
model-00019.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:617c98c96871403936caa0dcea602e7650cb947493555c142dc80e6c991adad8
3
+ size 18557475968
model-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ccea8f04adaeeb446b8def20c6042c96f6da4eb68da6bf2a76bacf65350e4e9
3
+ size 18624846976
model-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af8c9ca65f1830163f6d5741569b4dd4c62468a1c21556e7b760e303bc3b7818
3
+ size 18557475968
model-00022.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cc5137141b5e2522fd3e69a4c828a0dbb602569ab8a0afcce5151b06800339f
3
+ size 18624846976
model-00023.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05c2c2a08df421f617794e137429246a6ea60dd908fc691263242a12325dae7f
3
+ size 9245052456
model-00024.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7688adfc7748c12fdc8504187c57fe6ec6005798a02defc0d3372f921b1400a1
3
+ size 6968188464
model-vit-00001.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22aa3f3679feffb57c2fb0bc885db0f5613db3536efef5d4b0984e8d769f6017
3
+ size 1613990904
model-vit-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f63ca4700a4184459d3ddb3a86c54a62914d359cedfddcfc14739ae782be082
3
+ size 2348122376
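Each `.safetensors` entry above is a three-line Git LFS pointer (`version`, `oid`, `size`) rather than the weights themselves. As a rough sketch, such a pointer can be read with plain string handling. `parse_lfs_pointer` is a hypothetical helper written for illustration and is not part of Git LFS or `huggingface_hub`.

```python
def parse_lfs_pointer(text):
    """Parse the key/value lines of a Git LFS pointer file into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        # Every pointer line has the form "<key> <value>".
        key, _, value = line.partition(" ")
        fields[key] = value
    # The oid has the form "<hash algorithm>:<hex digest>".
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "digest": digest,
        "size": int(fields["size"]),  # size of the real file in bytes
    }


# Pointer text copied from the model-vit-00002.safetensors entry.
pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:1f63ca4700a4184459d3ddb3a86c54a62914d359cedfddcfc14739ae782be082
size 2348122376
"""
info = parse_lfs_pointer(pointer)
print(info["hash_algo"], info["size"])
```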
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_step3p7.py ADDED
@@ -0,0 +1,1405 @@
1
+ # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import copy
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Callable, Literal, Optional, Tuple, TypedDict, Union
19
+
20
+ from PIL import Image
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from transformers.activations import ACT2FN
26
+ from transformers.cache_utils import Cache, DynamicCache
27
+ from transformers.generation import GenerationMixin
28
+ from transformers.masking_utils import (
29
+ create_causal_mask,
30
+ create_sliding_window_causal_mask,
31
+ )
32
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
33
+ from transformers.modeling_layers import GradientCheckpointingLayer
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
35
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
36
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
+ from transformers.processing_utils import Unpack
38
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
39
+ from .configuration_step3p7 import Step3p7Config, Step3p7TextConfig
40
+ from .vision_encoder import StepRoboticsVisionEncoder
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+ _MASK_INPUT_EMBEDS_ARG = (
45
+ "inputs_embeds"
46
+ if "inputs_embeds" in inspect.signature(create_causal_mask).parameters
47
+ else "input_embeds"
48
+ )
49
+
50
+ __all__ = [
51
+ "Step3p7Model",
52
+ ]
53
+
54
+
55
+ class StepVLImagePixelInputs(TypedDict):
56
+ type: Literal["pixel_values"]
57
+ pixel_values: torch.Tensor
58
+ patch_pixel_values: Optional[torch.Tensor]
59
+ num_patches: list[int]
60
+
61
+
62
+ class StepVLImageEmbeddingInputs(TypedDict):
63
+ type: Literal["image_embeds"]
64
+ image_embeds: torch.Tensor
65
+
66
+
67
+ StepVLImageInputs = Union[StepVLImagePixelInputs, StepVLImageEmbeddingInputs]
68
+
69
+
70
+ def _flatten_embeddings(embeddings) -> torch.Tensor:
71
+ """
72
+ Recursively flattens and concatenates NestedTensors on all but the last
73
+ dimension.
74
+ """
75
+
76
+ if isinstance(embeddings, torch.Tensor):
77
+ # Flatten all but the last dimension.
78
+ return embeddings.flatten(0, -2)
79
+
80
+ return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
81
+
82
+ def _embedding_count_expression(embeddings) -> str:
83
+ """
84
+ Constructs a debugging representation of the number of embeddings in the
85
+ NestedTensors.
86
+ """
87
+
88
+ if isinstance(embeddings, torch.Tensor):
89
+ return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
90
+
91
+ return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
92
+
93
+
94
+ def _merge_multimodal_embeddings(
95
+ inputs_embeds: torch.Tensor,
96
+ is_multimodal: torch.Tensor,
97
+ multimodal_embeddings,
98
+ ) -> torch.Tensor:
99
+ """
100
+ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
101
+ positions in ``inputs_embeds`` corresponding to placeholder tokens in
102
+ ``input_ids``.
103
+ Note:
104
+ This updates ``inputs_embeds`` in place.
105
+ """
106
+ num_expected_tokens = is_multimodal.sum().item()
107
+ assert isinstance(num_expected_tokens, int)
108
+
109
+ flattened = _flatten_embeddings(multimodal_embeddings)
110
+ if flattened.shape[0] != num_expected_tokens:
111
+ expr = _embedding_count_expression(multimodal_embeddings)
112
+ raise ValueError(
113
+ f"Attempted to assign {expr} = {flattened.shape[0]} "
114
+ f"multimodal tokens to {num_expected_tokens} placeholders"
115
+ )
116
+
117
+ is_multimodal = is_multimodal.to(inputs_embeds.device)
118
+ flattened = flattened.to(inputs_embeds.device)
119
+ inputs_embeds[is_multimodal] = flattened
120
+ return inputs_embeds
121
+
122
+ def merge_multimodal_embeddings(
123
+ input_ids: torch.Tensor,
124
+ inputs_embeds: torch.Tensor,
125
+ multimodal_embeddings,
126
+ placeholder_token_id: Union[int, list[int]],
127
+ ) -> torch.Tensor:
128
+ """
129
+ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
130
+ positions in ``inputs_embeds`` corresponding to placeholder tokens in
131
+ ``input_ids``.
132
+
133
+ ``placeholder_token_id`` can be a list of token ids (e.g., token ids
134
+ of img_start, img_break, and img_end tokens) when needed: This means
135
+ the order of these tokens in the ``input_ids`` MUST MATCH the order of
136
+ their embeddings in ``multimodal_embeddings`` since we need to
137
+ slice-merge instead of individually scattering.
138
+ For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
139
+ - T is text token
140
+ - S is image start token
141
+ - I is image embedding token
142
+ - B is image break token
143
+ - E is image end token.
144
+
145
+ Then the image embeddings (that correspond to I's) from vision encoder
146
+ must be padded with embeddings of S, B, and E in the same order of
147
+ input_ids for a correct embedding merge.
148
+ Note:
149
+ This updates ``inputs_embeds`` in place.
150
+ """
151
+ if isinstance(placeholder_token_id, list):
152
+ placeholder_token_id = torch.tensor(
153
+ placeholder_token_id, device=input_ids.device
154
+ )
155
+ return _merge_multimodal_embeddings(
156
+ inputs_embeds,
157
+ torch.isin(input_ids, placeholder_token_id),
158
+ multimodal_embeddings,
159
+ )
160
+
161
+ return _merge_multimodal_embeddings(
162
+ inputs_embeds,
163
+ (input_ids == placeholder_token_id),
164
+ multimodal_embeddings,
165
+ )
166
+
167
+
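The ordering requirement described in the `merge_multimodal_embeddings` docstring can be illustrated without tensors. The sketch below uses single characters for tokens and strings for embeddings, and it returns a new list instead of updating `inputs_embeds` in place as the original does. `merge_by_placeholder` and all sample values are hypothetical.

```python
def merge_by_placeholder(tokens, text_embeds, multimodal_embeds, placeholders):
    """Overwrite the embeddings at placeholder positions, in order."""
    positions = [i for i, tok in enumerate(tokens) if tok in placeholders]
    if len(positions) != len(multimodal_embeds):
        raise ValueError(
            f"Attempted to assign {len(multimodal_embeds)} multimodal tokens "
            f"to {len(positions)} placeholders"
        )
    merged = list(text_embeds)  # copy; the original mutates inputs_embeds
    for pos, emb in zip(positions, multimodal_embeds):
        merged[pos] = emb
    return merged


# T = text, S = image start, I = image embedding, B = break, E = image end.
tokens = list("TTSIIBIIETT")
text_embeds = [f"t{i}" for i in range(len(tokens))]
# The multimodal embeddings must follow the order of S/I/B/E in the tokens.
multimodal = ["s", "i0", "i1", "b", "i2", "i3", "e"]
merged = merge_by_placeholder(tokens, text_embeds, multimodal, set("SIBE"))
print(merged)
```

Because positions are filled strictly left to right, supplying only the `I` embeddings while also listing `S`, `B`, and `E` as placeholders would raise the count-mismatch error, which mirrors the check in `_merge_multimodal_embeddings`.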
168
+ class Step3p7PreTrainedModel(PreTrainedModel):
169
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
170
+ # can load the config instead of failing with a NoneType error.
171
+ config_class = Step3p7Config
172
+ supports_gradient_checkpointing = True
173
+ _skip_keys_device_placement = ["past_key_values"]
174
+ _keys_to_ignore_on_load_unexpected = [
175
+ r"model\.layers\.45\.*",
176
+ r"model\.layers\.46\.*",
177
+ r"model\.layers\.47\.*",
178
+ ]
179
+ _supports_flash_attn = False
180
+ _supports_sdpa = True
181
+ _supports_flex_attn = True
182
+ _supports_static_cache = True
183
+ _supports_attention_backend = True
184
+
185
+ @classmethod
186
+ def from_pretrained(
187
+ cls, pretrained_model_name_or_path, *model_args, **kwargs
188
+ ):
189
+ key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
190
+ if key_mapping is not None and kwargs.get("key_mapping") is None:
191
+ # Transformers only applies checkpoint renaming when key_mapping is
192
+ # passed explicitly; inheriting the class attribute alone is not enough.
193
+ kwargs["key_mapping"] = copy.deepcopy(key_mapping)
194
+ return super().from_pretrained(
195
+ pretrained_model_name_or_path, *model_args, **kwargs
196
+ )
197
+
198
+
199
+ class Step3p7RotaryEmbedding(nn.Module):
200
+ def __init__(self, config: Step3p7TextConfig, device=None, layer_idx=None):
201
+ super().__init__()
202
+ self.layer_idx = layer_idx
203
+ self.max_seq_len_cached = config.max_position_embeddings
204
+ self.original_max_seq_len = config.max_position_embeddings
205
+
206
+ rope_theta = config.rope_theta
207
+ if isinstance(rope_theta, list):
208
+ rope_theta = rope_theta[0 if layer_idx is None else layer_idx]
209
+
210
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
211
+ partial_rotary_factors = getattr(config, "partial_rotary_factors", None)
212
+ if partial_rotary_factors is not None:
213
+ partial_rotary_factor = partial_rotary_factors[
214
+ 0 if layer_idx is None else layer_idx
215
+ ]
216
+
217
+ self.rope_theta = rope_theta
218
+ self.partial_rotary_factor = partial_rotary_factor
219
+
220
+ self.config = copy.copy(config)
221
+ self.config.rope_theta = rope_theta
222
+ self.config.partial_rotary_factor = partial_rotary_factor
223
+
224
+ if config.rope_parameters is not None:
225
+ self.config.rope_parameters = copy.deepcopy(config.rope_parameters)
226
+ self.config.rope_parameters["rope_theta"] = rope_theta
227
+ self.config.rope_parameters["partial_rotary_factor"] = (
228
+ partial_rotary_factor
229
+ )
230
+ self.rope_type = self.config.rope_parameters.get(
231
+ "rope_type", self.config.rope_parameters.get("type")
232
+ )
233
+ else:
234
+ self.rope_type = "default"
235
+
236
+ self.rope_init_fn = self.compute_default_rope_parameters
237
+ if self.rope_type != "default":
238
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
239
+ inv_freq, self.attention_scaling = self.rope_init_fn(
240
+ self.config, device
241
+ )
242
+
243
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
244
+ self.original_inv_freq = self.inv_freq
245
+
246
+ @torch.no_grad()
247
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
248
+ def forward(self, x, position_ids):
249
+ inv_freq_expanded = (
250
+ self.inv_freq[None, :, None]
251
+ .float()
252
+ .expand(position_ids.shape[0], -1, 1)
253
+ .to(x.device)
254
+ )
255
+ position_ids_expanded = position_ids[:, None, :].float().to(x.device)
256
+
257
+ device_type = (
258
+ x.device.type
259
+ if isinstance(x.device.type, str) and x.device.type != "mps"
260
+ else "cpu"
261
+ )
262
+ with torch.autocast(
263
+ device_type=device_type, enabled=False
264
+ ): # Force float32
265
+ freqs = (
266
+ inv_freq_expanded.float() @ position_ids_expanded.float()
267
+ ).transpose(1, 2)
268
+ emb = torch.cat((freqs, freqs), dim=-1)
269
+ cos = emb.cos() * self.attention_scaling
270
+ sin = emb.sin() * self.attention_scaling
271
+
272
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
273
+
274
+ @staticmethod
275
+ def compute_default_rope_parameters(
276
+ config: Step3p7TextConfig | None = None,
277
+ device: Optional["torch.device"] = None,
278
+ ) -> tuple["torch.Tensor", float]:
279
+ """
280
+ Computes the inverse frequencies according to the original RoPE implementation
281
+ Args:
282
+ config ([`~transformers.PreTrainedConfig`]):
283
+ The model configuration.
284
+ device (`torch.device`):
285
+ The device to use for initialization of the inverse frequencies.
286
+ seq_len (`int`, *optional*):
287
+ The current sequence length. Unused for this type of RoPE.
288
+ Returns:
289
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
290
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
291
+ """
292
+ base = config.rope_theta
293
+ partial_rotary_factor = getattr(
294
+ config, "partial_rotary_factor", 1.0
295
+ )
296
+ head_dim = (
297
+ getattr(config, "head_dim", None)
298
+ or config.hidden_size // config.num_attention_heads
299
+ )
300
+ dim = int(head_dim * partial_rotary_factor)
301
+
302
+ attention_factor = 1.0 # Unused in this type of RoPE
303
+
304
+ # Compute the inverse frequencies
305
+ inv_freq = 1.0 / (
306
+ base
307
+ ** (
308
+ torch.arange(0, dim, 2, dtype=torch.int64).to(
309
+ device=device, dtype=torch.float
310
+ )
311
+ / dim
312
+ )
313
+ )
314
+ return inv_freq, attention_factor
315
+
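The inverse-frequency formula in `compute_default_rope_parameters` can be reproduced with the standard library alone. This is a minimal sketch rather than the module's code path: `default_inv_freq` is a hypothetical name, the `head_dim` of 128 and the two `rope_theta` values come from the config in this commit, and the `partial_rotary_factor` of 0.5 is an assumed value chosen only to show the effect on the rotary dimension.

```python
def default_inv_freq(head_dim, rope_theta, partial_rotary_factor=1.0):
    """Return 1 / theta^(i / dim) for i = 0, 2, 4, ... below dim."""
    dim = int(head_dim * partial_rotary_factor)  # rotated part of each head
    return [1.0 / (rope_theta ** (i / dim)) for i in range(0, dim, 2)]


# Full rotary dimension with the smaller base found in config.json.
full = default_inv_freq(head_dim=128, rope_theta=10000.0)
# Larger base, with an assumed partial_rotary_factor of 0.5.
half = default_inv_freq(head_dim=128, rope_theta=5000000.0, partial_rotary_factor=0.5)
print(len(full), len(half), full[0])
```

Each list holds `dim // 2` frequencies because `forward` later concatenates the frequencies with themselves to cover the full rotary dimension.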
316
+ def rotate_half(x):
317
+ """Rotates half the hidden dims of the input."""
318
+ x1 = x[..., :x.shape[-1] // 2]
319
+ x2 = x[..., x.shape[-1] // 2:]
320
+ return torch.cat((-x2, x1), dim=-1)
321
+
322
+
323
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
324
+ """Applies Rotary Position Embedding to the query and key tensors.
325
+
326
+ Args:
327
+ q (`torch.Tensor`): The query tensor.
328
+ k (`torch.Tensor`): The key tensor.
329
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
330
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
331
+ position_ids (`torch.Tensor`, *optional*):
332
+ Deprecated and unused.
333
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
334
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
335
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
336
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
337
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
338
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
339
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
340
+ Returns:
341
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
342
+ """
343
+ rotary_dim = cos.shape[-1]
344
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
345
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
346
+
347
+ # Apply rotary embeddings on the first half or full tensor
348
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
349
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
350
+
351
+ # Concatenate back to full shape
352
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
353
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
354
+ return q_embed, k_embed
355
+
356
+
357
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
358
+ """
359
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
360
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
361
+ """
362
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
363
+ if n_rep == 1:
364
+ return hidden_states
365
+ hidden_states = hidden_states[:, :, None, :, :].expand(
366
+ batch, num_key_value_heads, n_rep, slen, head_dim
367
+ )
368
+ return hidden_states.reshape(
369
+ batch, num_key_value_heads * n_rep, slen, head_dim
370
+ )
371
+
372
+
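The head ordering produced by `repeat_kv` is easiest to see on a plain list. The sketch below stands in for the tensor version, with each string representing one key/value head. `repeat_kv_heads` is a hypothetical helper, and treating `num_attention_groups` as the number of key/value heads is an assumption based on the config defaults (64 attention heads, 8 groups).

```python
def repeat_kv_heads(kv_heads, n_rep):
    """Repeat each key/value head n_rep times, keeping copies adjacent."""
    if n_rep == 1:
        return list(kv_heads)
    # Same ordering as repeat_interleave: h0, h0, ..., h1, h1, ...
    return [head for head in kv_heads for _ in range(n_rep)]


expanded = repeat_kv_heads(["kv0", "kv1"], 4)
print(expanded)

# With the config defaults, 8 groups repeated 8 times give 64 heads.
n_rep = 64 // 8
full = repeat_kv_heads(list(range(8)), n_rep)
print(len(full))
```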
373
+ # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward.
374
+ # Llama4 does not cast attention weights to fp32 here.
375
+ def eager_attention_forward(
376
+ module: nn.Module,
377
+ query: torch.Tensor,
378
+ key: torch.Tensor,
379
+ value: torch.Tensor,
380
+ attention_mask: Optional[torch.Tensor],
381
+ scaling: float,
382
+ dropout: float = 0.0,
383
+ **kwargs,
384
+ ):
385
+ key_states = repeat_kv(key, module.num_key_value_groups)
386
+ value_states = repeat_kv(value, module.num_key_value_groups)
387
+ # breakpoint()
388
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
389
+ if attention_mask is not None:
390
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
391
+ attn_weights = attn_weights + causal_mask
392
+
393
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
394
+ attn_weights = nn.functional.dropout(
395
+ attn_weights, p=dropout, training=module.training
396
+ )
397
+ attn_output = torch.matmul(attn_weights, value_states)
398
+ attn_output = attn_output.transpose(1, 2).contiguous()
399
+
400
+ return attn_output, attn_weights
401
+
402
+
403
+ @dataclass
404
+ class Step3p7CausalLMOutputWithPast(ModelOutput):
405
+ r"""
406
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
407
+ Language modeling loss (for next-token prediction).
408
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
409
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
410
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
411
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
412
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
413
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
414
+ `past_key_values` input) to speed up sequential decoding.
415
+ """
416
+
417
+ loss: Optional[torch.FloatTensor] = None
418
+ last_hidden_state: Optional[torch.FloatTensor] = None
419
+ logits: torch.FloatTensor = None
420
+ past_key_values: Optional[list[torch.FloatTensor]] = None
421
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
422
+ attentions: Optional[tuple[torch.FloatTensor]] = None
423
+
424
+
425
+ class Step3p7MLP(nn.Module):
426
+ def __init__(self, config, intermediate_size=None, swiglu_limit=None):
427
+ super().__init__()
428
+ self.config = config
429
+ self.hidden_size = config.hidden_size
430
+ self.intermediate_size = (
431
+ intermediate_size
432
+ if intermediate_size is not None
433
+ else config.intermediate_size
434
+ )
435
+ self.gate_proj = nn.Linear(self.hidden_size,
436
+ self.intermediate_size,
437
+ bias=False)
438
+ self.up_proj = nn.Linear(self.hidden_size,
439
+ self.intermediate_size,
440
+ bias=False)
441
+ self.down_proj = nn.Linear(self.intermediate_size,
442
+ self.hidden_size,
443
+ bias=False)
444
+ self.act_fn = ACT2FN["silu"]
445
+ self.limit = swiglu_limit
446
+
447
+ def forward(self, x):
448
+ up = self.up_proj(x)
449
+ gate = self.act_fn(self.gate_proj(x))
450
+ if self.limit is not None:
451
+ gate = gate.clamp(min=None, max=self.limit)
452
+ up = up.clamp(min=-self.limit, max=self.limit)
453
+
454
+ return self.down_proj(gate * up)
455
+
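The `swiglu_limit` clamps in `Step3p7MLP.forward` are asymmetric: the activated gate is only bounded from above, while the up projection is bounded on both sides. A scalar sketch of the elementwise part (the projections are omitted, and `clamped_swiglu` is a hypothetical name):

```python
import math

def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def clamped_swiglu(gate_pre, up_pre, limit=None):
    """Elementwise SiLU(gate) * up with the optional clamps from forward()."""
    gate = silu(gate_pre)
    up = up_pre
    if limit is not None:
        gate = min(gate, limit)             # gate: upper bound only
        up = max(-limit, min(up, limit))    # up: symmetric bound
    return gate * up

unclamped = clamped_swiglu(20.0, -50.0)
clamped = clamped_swiglu(20.0, -50.0, limit=7.0)
```

With `limit=7.0`, the product is bounded in magnitude by `limit ** 2` whenever the gate is positive, which is presumably the point of the clamp: keeping outlier activations from dominating the down projection.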
456
+
457
+ def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
458
+ renormalize: bool):
459
+ gating_output = gating_output.float()
460
+ gate_prob = torch.sigmoid(gating_output)
461
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
462
+ topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
463
+ expert_topk_weight = topk_prob
464
+ if renormalize:
465
+ expert_topk_weight = expert_topk_weight / torch.sum(
466
+ expert_topk_weight, dim=-1, keepdim=True)
467
+ return expert_topk_weight, indices
468
+
469
+
470
+ def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
471
+ renormalize: bool):
472
+ gating_output = gating_output.float()
473
+ gate_prob = torch.softmax(gating_output, dim=-1)
474
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
475
+ topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
476
+ expert_topk_weight = topk_prob
477
+ if renormalize:
478
+ expert_topk_weight = expert_topk_weight / torch.sum(
479
+ expert_topk_weight, dim=-1, keepdim=True)
480
+ return expert_topk_weight, indices.to(torch.int32)
481
+
482
+
483
+ class MoELinear(nn.Module):
484
+
485
+ def __init__(self, num_experts, in_features, out_features):
486
+ super().__init__()
487
+ self.num_experts = num_experts
488
+ self.in_features = in_features
489
+ self.out_features = out_features
490
+ self.weight = nn.Parameter(
491
+ torch.empty(num_experts, out_features, in_features))
492
+
493
+ def forward(self, x, expert_id):
494
+ x = F.linear(x.float(), self.weight[expert_id].float())
495
+ return x
496
+
497
+
498
+ class Step3p7MoEMLP(nn.Module):
499
+
500
+ def __init__(self, config, swiglu_limit=None):
501
+ super().__init__()
502
+ self.num_experts = config.moe_num_experts
503
+ self.top_k = config.moe_top_k
504
+ self.hidden_size = config.hidden_size
505
+ self.moe_intermediate_size = config.moe_intermediate_size
506
+
507
+ self.use_moe_router_bias = config.use_moe_router_bias
508
+ if self.use_moe_router_bias:
509
+ self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
510
+ dtype=torch.float32),
511
+ requires_grad=False)
512
+ self.custom_routing_function = self.router_bias_func
513
+ elif config.moe_router_activation == "sigmoid":
514
+ self.custom_routing_function = sigmoid_routing_function
515
+ else:
516
+ self.custom_routing_function = None
517
+ self.need_fp32_gate = config.need_fp32_gate
518
+ self.routed_scaling_factor = getattr(config,
519
+ "moe_router_scaling_factor", 1.0)
520
+
521
+ # gating
522
+ self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
523
+
524
+ self.act_fn = ACT2FN["silu"]
525
+ self.limit = swiglu_limit
526
+
527
+ self.up_proj = MoELinear(self.num_experts, self.hidden_size,
528
+ self.moe_intermediate_size)
529
+ self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
530
+ self.moe_intermediate_size)
531
+ self.down_proj = MoELinear(self.num_experts,
532
+ self.moe_intermediate_size,
533
+ self.hidden_size)
534
+
535
+ def router_bias_func(self, gating_output: torch.Tensor, topk: int,
536
+ renormalize: bool):
537
+ gate_prob = torch.sigmoid(gating_output.float())
538
+ gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
539
+ _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
540
+ topk_prob = torch.gather(gate_prob, 1, indices)
541
+ expert_topk_weight = topk_prob
542
+ if renormalize:
543
+ expert_topk_weight = expert_topk_weight / (
544
+ torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
545
+ return expert_topk_weight, indices
546
+
547
+ def get_expert_output(self, inputs: torch.Tensor, expert_id):
548
+ #if self.limit is None:
549
+ up = self.up_proj(inputs, expert_id)
550
+ gate = self.act_fn(self.gate_proj(inputs, expert_id))
551
+ if self.limit is not None:
552
+ gate = gate.clamp(min=None, max=self.limit)
553
+ up = up.clamp(min=-self.limit, max=self.limit)
554
+
555
+ return self.down_proj(gate * up, expert_id)
556
+
557
+ def forward(self, hidden_states):
558
+ """Route each token to its top-k experts and sum the weighted expert outputs."""
559
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
560
+ hidden_states = hidden_states.view(-1, hidden_dim)
561
+ if self.need_fp32_gate:
562
+ router_logits = torch.matmul(
563
+ hidden_states.to(torch.float32),
564
+ self.gate.weight.t().to(torch.float32),
565
+ )
566
+ else:
567
+ # router_logits: (batch * sequence_length, n_experts)
568
+ router_logits = self.gate(hidden_states)
569
+
570
+ if self.custom_routing_function:
571
+ routing_weights, selected_experts = self.custom_routing_function(
572
+ router_logits, self.top_k, renormalize=True)
573
+ else:
574
+ routing_weights = F.softmax(router_logits,
575
+ dim=1,
576
+ dtype=torch.float)
577
+ routing_weights, selected_experts = torch.topk(routing_weights,
578
+ self.top_k,
579
+ dim=-1)
580
+
581
+ routing_weights = routing_weights * self.routed_scaling_factor
582
+
583
+ final_hidden_states = torch.zeros(
584
+ (batch_size * sequence_length, hidden_dim),
585
+ dtype=hidden_states.dtype,
586
+ device=hidden_states.device)
587
+
588
+ # One hot encode the selected experts to create an expert mask
589
+ # this will be used to easily index which expert is going to be solicited
590
+ expert_mask = torch.nn.functional.one_hot(
591
+ selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
592
+
593
+ # Loop over all available experts in the model and perform the computation on each expert
594
+ for expert_idx in range(self.num_experts):
595
+ idx, top_x = torch.where(expert_mask[expert_idx])
596
+
597
+ # Index the correct hidden states and compute the expert hidden state for
598
+ # the current expert. We need to make sure to multiply the output hidden
599
+ # states by `routing_weights` for the corresponding tokens (one weight per top-k slot)
600
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
601
+ current_hidden_states = (
602
+ self.get_expert_output(current_state, expert_idx) *
603
+ routing_weights[top_x, idx, None])
604
+
605
+ # `index_add_` only supports tensor indices, so we use
606
+ # the `top_x` tensor here.
607
+ final_hidden_states.index_add_(
608
+ 0, top_x, current_hidden_states.to(hidden_states.dtype))
609
+ final_hidden_states = final_hidden_states.reshape(
610
+ batch_size, sequence_length, hidden_dim)
611
+ return final_hidden_states
612
+
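In `router_bias_func` above, the bias only influences which experts are selected; the mixing weights are gathered from the unbiased sigmoid probabilities and then renormalized. A pure-Python sketch for one token, with made-up logits and a hypothetical bias vector:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def route_with_bias(logits, bias, top_k, renormalize=True):
    """Pick experts by biased score; weight them by unbiased sigmoid prob."""
    probs = [sigmoid(v) for v in logits]
    biased = [p + b for p, b in zip(probs, bias)]
    # Indices of the top_k largest biased scores (used for selection only).
    order = sorted(range(len(biased)), key=lambda i: biased[i], reverse=True)
    chosen = order[:top_k]
    weights = [probs[i] for i in chosen]
    if renormalize:
        total = sum(weights) + 1e-20
        weights = [w / total for w in weights]
    return weights, chosen

logits = [2.0, 0.0, -1.0, 1.0]
bias = [0.0, 0.0, 0.9, 0.0]    # hypothetical bias that promotes expert 2
weights, chosen = route_with_bias(logits, bias, top_k=2)
_, chosen_no_bias = route_with_bias(logits, [0.0] * 4, top_k=2)
```

Here the bias pulls expert 2 into the selection, but its mixing weight stays small because it is computed from the unbiased probability. In `forward`, these weights are additionally multiplied by `routed_scaling_factor`.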
613
+
614
+ class Step3p7RMSNorm(nn.Module):
615
+
616
+ def __init__(
617
+ self,
618
+ hidden_size: int,
619
+ eps: float = 1e-5,
620
+ ) -> None:
621
+ super().__init__()
622
+ self.weight = nn.Parameter(torch.ones(hidden_size))
623
+ self.variance_epsilon = eps
624
+
625
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
626
+ dtype = x.dtype
627
+ x = x.float()
628
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
629
+ normed = x * torch.rsqrt(variance + self.variance_epsilon)
630
+ normed = normed * (self.weight.float() + 1)
631
+ return normed.to(dtype)
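Note that `Step3p7RMSNorm.forward` scales by `weight + 1` rather than by `weight`, so a stored weight of zero corresponds to an identity scale. A list-based sketch of the same arithmetic (no dtype handling; `rms_norm` is an illustrative name):

```python
import math

def rms_norm(x, weight, eps=1e-5):
    """RMS-normalize x, then scale by (weight + 1) as in the forward above."""
    variance = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [v * inv_rms * (w + 1.0) for v, w in zip(x, weight)]

x = [3.0, 4.0]
identity = rms_norm(x, weight=[0.0, 0.0], eps=0.0)   # scale factor 1
doubled = rms_norm(x, weight=[1.0, 1.0], eps=0.0)    # scale factor 2
```

After normalization with zero weights, the mean of the squared outputs is 1, which is the defining property of RMS normalization.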
632
+ class Step3p7Attention(nn.Module):
633
+
634
+ def __init__(self, config: Step3p7TextConfig, layer_idx):
635
+ super().__init__()
636
+ self.config = config
637
+ self.layer_idx = layer_idx
638
+ self.num_attention_heads = config.num_attention_heads
639
+ self.num_key_value_heads = config.num_attention_groups
640
+
641
+ layer_types = getattr(config, "layer_types", [])
642
+ if layer_types:
643
+ enable_sliding_window = layer_types[
644
+ self.layer_idx] == "sliding_attention"
645
+ else:
646
+ enable_sliding_window = self.layer_idx % 2 == 0
647
+
648
+ yarn_only_types = getattr(config, "yarn_only_types", None)
649
+ if yarn_only_types and layer_types[
650
+ self.layer_idx] not in yarn_only_types:
651
+ config.rope_parameters = None
652
+ else:
653
+ config.rope_parameters = getattr(config, "rope_scaling", None)
654
+
655
+ self.sliding_window = config.sliding_window
656
+ if enable_sliding_window:
657
+ self.num_attention_heads = config.attention_other_setting[
658
+ "num_attention_heads"]
659
+ self.num_key_value_heads = config.attention_other_setting[
660
+ "num_attention_groups"]
661
+
662
+ if self.sliding_window is not None and enable_sliding_window:
663
+ self.sliding_window = (self.sliding_window)
664
+ else:
665
+ self.sliding_window = None
666
+ self.head_dim = getattr(config, "head_dim",
667
+ config.hidden_size // self.num_attention_heads)
668
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
669
+
670
+ self.rotary_emb = Step3p7RotaryEmbedding(config, layer_idx=layer_idx)
671
+
672
+ self.q_size = self.num_attention_heads * self.head_dim
673
+ self.kv_size = self.num_key_value_heads * self.head_dim
674
+ self.scaling = self.head_dim**-0.5
675
+
676
+ self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
677
+ self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
678
+ self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
679
+ self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
680
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)
681
+ self.q_norm = Step3p7RMSNorm(self.head_dim,
682
+ eps=config.rms_norm_eps)
683
+ self.k_norm = Step3p7RMSNorm(self.head_dim,
684
+ eps=config.rms_norm_eps)
685
+
686
+ self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
687
+ if self.use_head_wise_attn_gate:
688
+ self.g_proj = nn.Linear(config.hidden_size,
689
+ self.num_attention_heads,
690
+ bias=False)
691
+
692
+ self.use_rope = True
693
+ use_rope_layers = getattr(config, "use_rope_layers", None)
694
+ if use_rope_layers:
695
+ self.use_rope = use_rope_layers[self.layer_idx]
696
+
697
+ def forward(
698
+ self,
699
+ hidden_states: torch.Tensor,
700
+ attention_mask: Optional[torch.Tensor],
701
+ past_key_value: Optional[Cache] = None,
702
+ cache_position: Optional[torch.LongTensor] = None,
703
+ position_ids: Optional[torch.LongTensor] = None,
704
+ **kwargs: Unpack[FlashAttentionKwargs],
705
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
706
+ Optional[Tuple[torch.Tensor]]]:
707
+ input_shape = hidden_states.shape[:-1]
708
+ hidden_shape = (*input_shape, -1, self.head_dim)
709
+
710
+ query_states = self.q_norm(
711
+ self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
712
+ key_states = self.k_norm(
713
+ self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
714
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
715
+ 1, 2)
716
+ if self.use_head_wise_attn_gate:
717
+ gate_states = self.g_proj(hidden_states)
718
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
719
+
720
+ # cos, sin = position_embeddings
721
+ query_states, key_states = apply_rotary_pos_emb(
722
+ query_states, key_states, cos, sin)
723
+
724
+ # query_states, key_states = apply_rotary_pos_emb(query_norm_states, key_norm_states, cos, sin)
725
+ if past_key_value is not None:
726
+ # sin and cos are specific to RoPE models; cache_position is needed for the static cache
727
+ cache_kwargs = {
728
+ "sin": sin,
729
+ "cos": cos,
730
+ "cache_position": cache_position
731
+ }
732
+ key_states, value_states = past_key_value.update(
733
+ key_states, value_states, self.layer_idx, cache_kwargs)
734
+
735
+ attention_interface: Callable = eager_attention_forward
736
+ # TODO: considering FP8;
737
+ # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
738
+ # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
739
+ if self.config._attn_implementation != "eager":
740
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
741
+ self.config._attn_implementation]
742
+
743
+ attn_output, attn_weights = attention_interface(
744
+ self,
745
+ query_states,
746
+ key_states,
747
+ value_states,
748
+ attention_mask,
749
+ dropout=0.0 if not self.training else self.attention_dropout,
750
+ scaling=self.scaling,
751
+ sliding_window=self.sliding_window, # main diff with Llama
752
+ **kwargs,
753
+ )
754
+ attn_output = attn_output.reshape(*input_shape, -1)
755
+ if self.use_head_wise_attn_gate:
756
+ output = attn_output.view(
757
+ *attn_output.shape[:-1], self.num_attention_heads,
758
+ self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
759
+ attn_output = output.view(*attn_output.shape)
760
+ attn_output = self.o_proj(attn_output)
761
+
762
+ return attn_output, attn_weights
763
+
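When `use_head_wise_attn_gate` is set, the attention output of each head is multiplied by a sigmoid gate with one scalar per head (from `g_proj`), before `o_proj` is applied. A sketch for a single token, with the heads as nested lists (`gate_heads` is a hypothetical name):

```python
import math

def gate_heads(attn_output, gate_logits):
    """Scale every channel of head h by sigmoid(gate_logits[h])."""
    return [[v / (1.0 + math.exp(-g)) for v in head]
            for head, g in zip(attn_output, gate_logits)]

heads = [[1.0, 2.0], [3.0, 4.0]]          # 2 heads, head_dim = 2
gated = gate_heads(heads, [0.0, 100.0])   # gates of 0.5 and ~1.0
```

A gate logit of 0 halves the head's contribution, while a large positive logit lets it pass essentially unchanged.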
764
+
765
+ class Step3p7DecoderLayer(GradientCheckpointingLayer):
766
+
767
+ def __init__(self, config, layer_idx):
768
+ super().__init__()
769
+ self.hidden_size = config.hidden_size
770
+ self.layer_idx = layer_idx
771
+ self.self_attn = Step3p7Attention(config, layer_idx)
772
+ layer_types = getattr(config, "layer_types", None) or []
773
+ if layer_types:
774
+ self.attention_type = layer_types[layer_idx]
775
+ else:
776
+ self.attention_type = (
777
+ "sliding_attention" if layer_idx % 2 == 0 else "full_attention"
778
+ )
779
+
780
+ moe_layers_enum = getattr(config, "moe_layers_enum", None)
781
+ if moe_layers_enum is not None:
782
+ if isinstance(moe_layers_enum, str):
783
+ moe_layers_idx = [
784
+ int(i) for i in moe_layers_enum.split(',') if i.strip()
785
+ ]
786
+ else:
787
+ moe_layers_idx = [int(i) for i in moe_layers_enum]
788
+ else:
789
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
790
+ self.is_moe_layer = layer_idx in moe_layers_idx
791
+ self.use_moe = False
792
+
793
+ if config.swiglu_limits_shared and config.swiglu_limits_shared[
794
+ layer_idx] is not None and config.swiglu_limits_shared[
795
+ layer_idx] != 0:
796
+ swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
797
+ else:
798
+ swiglu_limit_shared = None
799
+ if config.swiglu_limits and config.swiglu_limits[
800
+ layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
801
+ swiglu_limit = config.swiglu_limits[layer_idx]
802
+ else:
803
+ swiglu_limit = None
804
+ if self.is_moe_layer:
805
+ self.moe = Step3p7MoEMLP(config, swiglu_limit=swiglu_limit) #
806
+ self.share_expert = Step3p7MLP(
807
+ config,
808
+ intermediate_size=config.share_expert_dim,
809
+ swiglu_limit=swiglu_limit_shared)
810
+ self.use_moe = True
811
+ else:
812
+ self.mlp = Step3p7MLP(config,
813
+ intermediate_size=config.intermediate_size,
814
+ swiglu_limit=swiglu_limit_shared)
815
+
816
+ self.input_layernorm = Step3p7RMSNorm(
817
+ config.hidden_size,
818
+ eps=config.rms_norm_eps)
819
+ self.post_attention_layernorm = Step3p7RMSNorm(
820
+ config.hidden_size,
821
+ eps=config.rms_norm_eps)
822
+
823
+ def forward(
824
+ self,
825
+ hidden_states: torch.Tensor,
826
+ attention_mask: Optional[torch.Tensor] = None,
827
+ position_ids: Optional[torch.LongTensor] = None,
828
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
829
+ cache_position: Optional[torch.LongTensor] = None,
830
+ **kwargs: Unpack[FlashAttentionKwargs],
831
+ ) -> torch.FloatTensor:
832
+ residual = hidden_states
833
+ hidden_states = self.input_layernorm(hidden_states)
834
+ hidden_states, _ = self.self_attn(
835
+ hidden_states=hidden_states,
836
+ attention_mask=attention_mask,
837
+ position_ids=position_ids,
838
+ past_key_value=past_key_value,
839
+ cache_position=cache_position,
840
+ **kwargs,
841
+ )
842
+ hidden_states = residual + hidden_states
843
+
844
+ # Fully Connected
845
+ residual = hidden_states
846
+ hidden_states = self.post_attention_layernorm(hidden_states)
847
+ if self.use_moe:
848
+ share_output = self.share_expert(hidden_states)
849
+ moe_output = self.moe(hidden_states)
850
+ ffn_output = moe_output + share_output
851
+ else:
852
+ ffn_output = self.mlp(hidden_states)
853
+ if isinstance(ffn_output, tuple):
854
+ hidden_states, _ = ffn_output
855
+ else:
856
+ hidden_states = ffn_output
857
+
858
+ hidden_states = residual + hidden_states
859
+ return hidden_states
860
+
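The `moe_layers_enum` handling in `Step3p7DecoderLayer.__init__` accepts a comma-separated string, an iterable of indices, or nothing at all, in which case every layer except layer 0 is treated as a MoE layer. A standalone sketch of that resolution logic (`resolve_moe_layers` is an illustrative helper, not part of the model code):

```python
def resolve_moe_layers(moe_layers_enum, num_hidden_layers):
    """Return the list of layer indices that should use the MoE block."""
    if moe_layers_enum is None:
        # Default: every layer except the first one.
        return list(range(1, num_hidden_layers))
    if isinstance(moe_layers_enum, str):
        # Comma-separated string; empty pieces are skipped.
        return [int(i) for i in moe_layers_enum.split(",") if i.strip()]
    return [int(i) for i in moe_layers_enum]

default_layers = resolve_moe_layers(None, 4)
from_string = resolve_moe_layers("1, 3,", 4)
from_tuple = resolve_moe_layers((2,), 4)
```

A layer whose index is in the resulting list gets both the routed `moe` block and the dense `share_expert`; any other layer gets a plain `mlp`.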
861
+
862
+ class Step3p7TextPreTrainedModel(PreTrainedModel):
863
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
864
+ # can load the config instead of failing with a NoneType error.
865
+ config_class = Step3p7TextConfig
866
+ supports_gradient_checkpointing = True
867
+ _skip_keys_device_placement = ["past_key_values"]
868
+ _keys_to_ignore_on_load_unexpected = [
869
+ r"model\.layers\.45\.*",
870
+ r"model\.layers\.46\.*",
871
+ r"model\.layers\.47\.*",
872
+ ]
873
+ _supports_flash_attn = False
874
+ _supports_sdpa = True
875
+ _supports_flex_attn = True
876
+ _supports_static_cache = True
877
+ _supports_attention_backend = True
878
+
879
+
880
+ class Step3p7TextModel(Step3p7TextPreTrainedModel, GenerationMixin):
881
+ _no_split_modules = ["Step3p7DecoderLayer"]
882
+ base_model_prefix = "model"
883
+ _tied_weights_keys = ["lm_head.weight"]
884
+ config: Step3p7TextConfig
885
+
886
+ def __init__(self, config: Step3p7TextConfig):
887
+ super().__init__(config)
888
+ self.padding_idx = config.pad_token_id
889
+ self.vocab_size = config.vocab_size
890
+
891
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
892
+ self.padding_idx)
893
+ self.layers = nn.ModuleList([
894
+ Step3p7DecoderLayer(config, layer_idx)
895
+ for layer_idx in range(config.num_hidden_layers)
896
+ ])
897
+ self.norm = Step3p7RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
898
+ self.gradient_checkpointing = False
899
+ layer_types = self.config.layer_types or []
900
+ self.has_sliding_layers = (not layer_types or
901
+ "sliding_attention" in layer_types)
902
+
903
+ # Initialize weights and apply final processing
904
+ self.post_init()
905
+
906
+
907
+ def get_input_embeddings(self, input_ids):
908
+ return self.embed_tokens(input_ids)
909
+
910
+ @can_return_tuple
911
+ def forward(
912
+ self,
913
+ input_ids: torch.LongTensor = None,
914
+ attention_mask: Optional[torch.Tensor] = None,
915
+ position_ids: Optional[torch.LongTensor] = None,
916
+ past_key_values: Optional[Cache] = None,
917
+ inputs_embeds: Optional[torch.FloatTensor] = None,
918
+ use_cache: Optional[bool] = None,
919
+ output_attentions: Optional[bool] = None,
920
+ output_hidden_states: Optional[bool] = None,
921
+ return_dict: Optional[bool] = None,
922
+ cache_position: Optional[torch.LongTensor] = None,
923
+ **kwargs: Unpack[TransformersKwargs],
924
+ ) -> Union[tuple, BaseModelOutputWithPast]:
925
+ output_attentions = (
926
+ output_attentions
927
+ if output_attentions is not None
928
+ else self.config.output_attentions
929
+ )
930
+ output_hidden_states = (
931
+ output_hidden_states
932
+ if output_hidden_states is not None
933
+ else self.config.output_hidden_states
934
+ )
935
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
936
+ return_dict = (
937
+ return_dict
938
+ if return_dict is not None
939
+ else getattr(self.config, "return_dict", True)
940
+ )
941
+ if (input_ids is None) ^ (inputs_embeds is not None):
942
+ raise ValueError(
943
+ "You must specify exactly one of input_ids or inputs_embeds")
944
+
945
+ if self.gradient_checkpointing and self.training and use_cache:
946
+ logger.warning_once(
947
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
948
+ )
949
+ use_cache = False
950
+
951
+ if inputs_embeds is None:
952
+ inputs_embeds = self.embed_tokens(
953
+ input_ids.to(self.embed_tokens.weight.device))
954
+
955
+ if use_cache and past_key_values is None:
956
+ past_key_values = DynamicCache()
957
+
958
+ if cache_position is None:
959
+ past_seen_tokens = past_key_values.get_seq_length(
960
+ ) if past_key_values is not None else 0
961
+ cache_position = torch.arange(past_seen_tokens,
962
+ past_seen_tokens +
963
+ inputs_embeds.shape[1],
964
+ device=inputs_embeds.device)
965
+
966
+ if position_ids is None:
967
+ position_ids = cache_position.unsqueeze(0)
968
+
969
+ hidden_states = inputs_embeds
970
+
971
+ # It may already have been prepared by e.g. `generate`
972
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
973
+ # Prepare mask arguments
974
+ mask_kwargs = {
975
+ "config": self.config,
976
+ "attention_mask": attention_mask,
977
+ "past_key_values": past_key_values,
978
+ "position_ids": position_ids,
979
+ }
980
+ mask_kwargs[_MASK_INPUT_EMBEDS_ARG] = inputs_embeds
981
+ # Create the masks
982
+ causal_mask_mapping = {
983
+ "full_attention": create_causal_mask(**mask_kwargs),
984
+ }
985
+
986
+ # Sliding-window layers are only present for some configs, so build their mask only when needed
987
+ if self.has_sliding_layers:
988
+ causal_mask_mapping[
989
+ "sliding_attention"] = create_sliding_window_causal_mask(
990
+ **mask_kwargs)
991
+
992
+ # # create position embeddings to be shared across the decoder layers
993
+ # decoder layers
994
+ all_hidden_states = () if output_hidden_states else None
995
+ all_self_attns = () if output_attentions else None
996
+ for decoder_layer in self.layers[:self.config.num_hidden_layers]:
997
+ if output_hidden_states:
998
+ all_hidden_states += (hidden_states, )
999
+
1000
+ layer_outputs = decoder_layer(
1001
+ hidden_states,
1002
+ attention_mask=causal_mask_mapping[
1003
+ decoder_layer.attention_type],
1004
+ position_ids=position_ids,
1005
+ past_key_value=past_key_values,
1006
+ output_attentions=output_attentions,
1007
+ use_cache=use_cache,
1008
+ cache_position=cache_position,
1009
+ **kwargs,
1010
+ )
1011
+
1012
+ hidden_states = layer_outputs
1013
+
1014
+ hidden_states = self.norm(hidden_states)
1015
+
1016
+ return BaseModelOutputWithPast(
1017
+ last_hidden_state=hidden_states,
1018
+ past_key_values=past_key_values if use_cache else None,
1019
+ hidden_states=all_hidden_states,
1020
+ attentions=all_self_attns,
1021
+ )
1022
+
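When `cache_position` and `position_ids` are not supplied, the text model derives them from the number of tokens already in the cache. A list-based sketch of that default (`default_positions` is a hypothetical helper; the model itself uses `torch.arange` and `unsqueeze(0)`):

```python
def default_positions(past_seen_tokens, new_tokens):
    """Positions for new tokens, continuing after the cached ones."""
    cache_position = list(range(past_seen_tokens, past_seen_tokens + new_tokens))
    position_ids = [cache_position]   # add a leading batch dimension
    return cache_position, position_ids

prefill = default_positions(0, 4)   # first call: 4 prompt tokens, empty cache
decode = default_positions(4, 1)    # next call: 1 new token, 4 cached
```

During prefill the positions start at 0; during decoding each new token continues from the current cache length.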
1023
+
1024
+ class Step3p7Model(Step3p7PreTrainedModel, GenerationMixin):
1025
+ config: Step3p7Config
1026
+ _tied_weights_keys = ["lm_head.weight"]
1027
+ base_model_prefix = ""
1028
+
1029
+ def __init__(self, config: Step3p7Config):
1030
+ super().__init__(config)
1031
+ self.vision_model = StepRoboticsVisionEncoder(config.vision_config)
1032
+ self.language_model = Step3p7TextModel(config.text_config)
1033
+ self.vocab_size = config.text_config.vocab_size
1034
+ self.vit_large_projector = nn.Linear(
1035
+ config.vision_config.width * 4,
1036
+ config.text_config.hidden_size,
1037
+ bias=config.projector_bias)
1038
+ self.image_placeholder_token_id = config.image_token_id
1039
+
1040
+ # Initialize weights and apply final processing
1041
+ self.post_init()
1042
+
1043
+ def get_input_embeddings(
1044
+ self,
1045
+ input_ids: torch.Tensor,
1046
+ multimodal_embeddings = None,
1047
+ ) -> torch.Tensor:
1048
+ # breakpoint()
1049
+ input_ids = input_ids.squeeze(0)
1050
+ if multimodal_embeddings is None:
1051
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
1052
+ else:
1053
+ is_text = input_ids != self.config.image_token_id
1054
+ text_ids = input_ids[is_text]
1055
+ text_embeds = self.language_model.get_input_embeddings(text_ids)
1056
+
1057
+ inputs_embeds = torch.empty(input_ids.shape[0],
1058
+ text_embeds.shape[-1],
1059
+ dtype=text_embeds.dtype,
1060
+ device=text_embeds.device)
1061
+ inputs_embeds[is_text] = text_embeds
1062
+ inputs_embeds = merge_multimodal_embeddings(
1063
+ input_ids, inputs_embeds, multimodal_embeddings,
1064
+ self.config.image_token_id)
1065
+ inputs_embeds = inputs_embeds.unsqueeze(0)
1066
+ return inputs_embeds
1067
+
1068
+
1069
+ def set_input_embeddings(self, value):
1070
+ return self.language_model.set_input_embeddings(value)
1071
+
1072
+ def set_decoder(self, decoder):
1073
+ self.language_model = decoder
1074
+
1075
+ def get_decoder(self):
1076
+ return self.language_model
1077
+
1078
+ def _parse_and_validate_image_input(
1079
+ self, **kwargs: object) -> Optional[StepVLImageInputs]:
1080
+ pixel_values = kwargs.pop("pixel_values", None)
1081
+ patch_pixel_values = kwargs.pop("patch_pixel_values", None)
1082
+ num_patches = kwargs.pop("num_patches", None)
1083
+ image_embeds = kwargs.pop("image_embeds", None)
1084
+
1085
+ if pixel_values is None and image_embeds is None:
1086
+ return None
1087
+
1088
+ if pixel_values is not None:
1089
+ # pixel_values = flatten_bn(pixel_values, concat=True)
1090
+ if pixel_values.dim() >= 3:
1091
+ pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
1092
+ if patch_pixel_values is not None:
1093
+ # patch_pixel_values = flatten_bn(patch_pixel_values,
1094
+ # concat=True)
1095
+ patch_pixel_values = patch_pixel_values.view(
1096
+ -1, *patch_pixel_values.shape[-3:])
1097
+ # Handle empty patch_pixel_values by setting to None
1098
+ if patch_pixel_values.shape[0] == 0:
1099
+ patch_pixel_values = None
1100
+
1101
+ return StepVLImagePixelInputs(
1102
+ type="pixel_values",
1103
+ pixel_values=pixel_values.to(self.dtype).to(self.device),
1104
+ patch_pixel_values=patch_pixel_values.to(self.dtype).to(
1105
+ self.device) if patch_pixel_values is not None else None,
1106
+ num_patches=num_patches,
1107
+ )
1108
+
1109
+ if image_embeds is not None:
1110
+ if image_embeds.dim() >= 2:
1111
+ image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
1112
+ else:
1113
+ raise ValueError(
1114
+ f"Unexpected shape for image_embeds: {image_embeds.shape}")
1115
+
1116
+ return StepVLImageEmbeddingInputs(
1117
+ type="image_embeds",
1118
+ image_embeds=image_embeds.to(self.dtype).to(self.device),
1119
+ )
1120
+ return None
1121
+
1122
+ def _process_image_features(self,
1123
+ image_features: torch.Tensor) -> torch.Tensor:
1124
+ B, P = image_features.shape[:2]
1125
+ HW = int(P ** 0.5)
1126
+ image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
1127
+ image_features = self.vision_model.vit_downsampler1(image_features)
1128
+ image_features = self.vision_model.vit_downsampler2(image_features)
1129
+
1130
+ B, C, HW, HW = image_features.shape
1131
+ image_features = image_features.view(B, -1, HW * HW).permute(0, 2, 1)
1132
+ image_features = self.vit_large_projector(image_features)
1133
+ return image_features
1134
+
1135
+ def _get_vision_model_output(self,
1136
+ input_tensor: torch.Tensor) -> torch.Tensor:
1137
+ return self.vision_model(input_tensor)
1138
+
1139
+ def _process_image_input(
1140
+ self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
1141
+
1142
+ if image_input["type"] == "image_embeds":
1143
+ image_features = image_input["image_embeds"]
1144
+ else:
1145
+ image_features = self._get_vision_model_output(
1146
+ image_input["pixel_values"])
1147
+ patch_image_features = self._get_vision_model_output(
1148
+ image_input["patch_pixel_values"]
1149
+ ) if image_input["patch_pixel_values"] is not None else None
1150
+ num_patches = image_input["num_patches"]
1151
+
1152
+ image_features = self._process_image_features(image_features)
1153
+ patch_image_features = self._process_image_features(
1154
+ patch_image_features) if patch_image_features is not None else None
1155
+
1156
+ merged_image_features = []
1157
+ cur_patch_idx = 0
1158
+ for i, num_patch in enumerate(num_patches):
1159
+ cur_feature = []
1160
+ if num_patch > 0:
1161
+ patch_slice = patch_image_features[
1162
+ cur_patch_idx:cur_patch_idx + num_patch]
1163
+ cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
1164
+ cur_feature.append(image_features[i].view(
1165
+ -1, image_features.shape[-1]))
1166
+ cur_patch_idx += num_patch
1167
+ merged_image_features.append(
1168
+ torch.cat(cur_feature) if len(cur_feature) >
1169
+ 1 else cur_feature[0])
1170
+
1171
+ return merged_image_features
1172
+
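In the pixel-value path of `_process_image_input`, each image's final embedding sequence is the tokens of its patches (if any) followed by the tokens of its global view, with `num_patches` saying how many entries of the flat patch batch belong to each image. A sketch with string placeholders instead of tensors (`merge_image_features` and the token names are illustrative):

```python
def merge_image_features(image_features, patch_features, num_patches):
    """Per image: tokens of its patches first, then its global-view tokens."""
    merged, cursor = [], 0
    for global_tokens, n in zip(image_features, num_patches):
        tokens = []
        for patch_tokens in patch_features[cursor:cursor + n]:
            tokens.extend(patch_tokens)   # flatten this image's patches in order
        tokens.extend(global_tokens)
        cursor += n
        merged.append(tokens)
    return merged

image_features = [["A_g0", "A_g1"], ["B_g0", "B_g1"]]   # 2 images, 2 tokens each
patch_features = [["A_p0_t0"], ["A_p1_t0"]]             # both patches belong to image A
merged = merge_image_features(image_features, patch_features, [2, 0])
```

The running `cursor` plays the role of `cur_patch_idx`: it advances through the flat patch batch so that each image consumes only its own patches.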
1173
+ def get_multimodal_embeddings(self, **kwargs):
1174
+ # breakpoint()
1175
+ image_input = self._parse_and_validate_image_input(**kwargs)
1176
+ if image_input is None:
1177
+ return None
1178
+ vision_embeddings = self._process_image_input(image_input)
1179
+ return vision_embeddings
1180
+
1181
+ @can_return_tuple
1182
+ def forward(
1183
+ self,
1184
+ input_ids: torch.LongTensor = None,
1185
+ attention_mask: Optional[torch.Tensor] = None,
1186
+ position_ids: Optional[torch.LongTensor] = None,
1187
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
1188
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1189
+ labels: Optional[torch.LongTensor] = None,
1190
+ use_cache: Optional[bool] = None,
1191
+ output_attentions: Optional[bool] = None,
1192
+ output_hidden_states: Optional[bool] = None,
1193
+ return_dict: Optional[bool] = None,
1194
+ cache_position: Optional[torch.LongTensor] = None,
1195
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1196
+ images: Optional[list[Image.Image]] = None,
1197
+ **kwargs: Unpack[TransformersKwargs],
1198
+ ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
1199
+ r"""
1200
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1201
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1202
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1203
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1204
+ Example:
1205
+ ```python
1206
+ >>> from transformers import AutoTokenizer, Llama4ForCausalLM
1207
+ >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1208
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1209
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1210
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1211
+ >>> # Generate
1212
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1213
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1214
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1215
+ ```"""
1216
+ output_attentions = (
1217
+ output_attentions
1218
+ if output_attentions is not None
1219
+ else self.config.output_attentions
1220
+ )
1221
+ output_hidden_states = (
1222
+ output_hidden_states
1223
+ if output_hidden_states is not None
1224
+ else self.config.output_hidden_states
1225
+ )
1226
+ return_dict = (
1227
+ return_dict if return_dict is not None else self.config.use_return_dict
1228
+ )
1229
+
1230
+ if inputs_embeds is None:
1231
+ input_ids = input_ids
1232
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
1233
+ inputs_embeds = self.get_input_embeddings(input_ids,
1234
+ vision_embeddings)
1235
+ input_ids = None
1236
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1237
+ outputs = self.language_model(
1238
+ input_ids=None,
1239
+ position_ids=position_ids,
1240
+ attention_mask=attention_mask,
1241
+ past_key_values=past_key_values,
1242
+ inputs_embeds=inputs_embeds,
1243
+ use_cache=use_cache,
1244
+ output_attentions=output_attentions,
1245
+ output_hidden_states=output_hidden_states,
1246
+ return_dict=True,
1247
+ cache_position=cache_position,
1248
+ **kwargs,
1249
+ )
1250
+
1251
+ output = Step3p7CausalLMOutputWithPast(
1252
+ last_hidden_state=outputs.last_hidden_state,
1253
+ past_key_values=outputs.past_key_values,
1254
+ attentions=outputs.attentions,
1255
+ )
1256
+ return output if return_dict else output.to_tuple()
1257
+
1258
+
1259
+ class Step3p7ForConditionalGeneration(Step3p7PreTrainedModel, GenerationMixin):
1260
+ _checkpoint_conversion_mapping = {
1261
+ "^vision_model": "model.vision_model",
1262
+ r"^model(?!\.(language_model|vision_model))": "model.language_model",
1263
+ "^vit_large_projector": "model.vit_large_projector",
1264
+ }
1265
+ _tied_weights_keys = ["lm_head.weight"]
1266
+ config: Step3p7Config
1267
+
1268
+ def __init__(self, config: Step3p7Config):
1269
+ super().__init__(config)
1270
+ self.model = Step3p7Model(config)
1271
+ self.lm_head = nn.Linear(config.hidden_size,
1272
+ config.text_config.vocab_size,
1273
+ bias=False)
1274
+
1275
+ self.post_init()
1276
+
1277
+ def get_input_embeddings(self):
1278
+ return self.model.get_input_embeddings()
1279
+
1280
+ def set_input_embeddings(self, value):
1281
+ self.model.set_input_embeddings(value)
1282
+
1283
+ def get_output_embeddings(self):
1284
+ return self.model.get_output_embeddings()
1285
+
1286
+ def set_output_embeddings(self, new_embeddings):
1287
+ self.model.set_output_embeddings(new_embeddings)
1288
+
1289
+ def set_decoder(self, decoder):
1290
+ self.model.set_decoder(decoder)
1291
+
1292
+ def get_decoder(self):
1293
+ return self.model.get_decoder()
1294
+
1295
+ @property
1296
+ def language_model(self):
1297
+ return self.model.language_model
1298
+
1299
+ @property
1300
+ def visual(self):
1301
+ return self.model.vision_model
1302
+
1303
+ def forward(
1304
+ self,
1305
+ input_ids: torch.LongTensor = None,
1306
+ pixel_values: Optional[torch.Tensor] = None,
1307
+ num_patches=None,
1308
+ patch_pixel_values=None,
1309
+ patch_newline_mask=None,
1310
+ image_embeds: Optional[torch.FloatTensor] = None,
1311
+ attention_mask: Optional[torch.Tensor] = None,
1312
+ position_ids: Optional[torch.LongTensor] = None,
1313
+ past_key_values: Optional[Cache] = None,
1314
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1315
+ labels: Optional[torch.LongTensor] = None,
1316
+ use_cache: Optional[bool] = None,
1317
+ output_attentions: Optional[bool] = None,
1318
+ output_hidden_states: Optional[bool] = None,
1319
+ return_dict: Optional[bool] = None,
1320
+ cache_position: Optional[torch.LongTensor] = None,
1321
+ **kwargs: Unpack[TransformersKwargs],
1322
+ ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
1323
+ output_attentions = (
1324
+ output_attentions
1325
+ if output_attentions is not None
1326
+ else self.config.output_attentions
1327
+ )
1328
+ output_hidden_states = (
1329
+ output_hidden_states
1330
+ if output_hidden_states is not None
1331
+ else self.config.output_hidden_states
1332
+ )
1333
+
1334
+ outputs = self.model(
1335
+ input_ids=input_ids,
1336
+ num_patches=num_patches,
1337
+ patch_pixel_values=patch_pixel_values,
1338
+ patch_newline_mask=patch_newline_mask,
1339
+ position_ids=position_ids,
1340
+ attention_mask=attention_mask,
1341
+ past_key_values=past_key_values,
1342
+ inputs_embeds=inputs_embeds,
1343
+ use_cache=use_cache,
1344
+ output_attentions=output_attentions,
1345
+ output_hidden_states=output_hidden_states,
1346
+ return_dict=return_dict,
1347
+ cache_position=cache_position,
1348
+ **kwargs,
1349
+ )
1350
+
1351
+ hidden_states = outputs.last_hidden_state
1352
+ logits = self.lm_head(hidden_states)
1353
+
1354
+ loss = None
1355
+ if labels is not None:
1356
+ loss = self.loss_function(
1357
+ logits=logits, labels=labels, vocab_size=self.config.vocab_size
1358
+ )
1359
+
1360
+ return Step3p7CausalLMOutputWithPast(
1361
+ logits=logits,
1362
+ )
1363
+
1364
+
1365
+ def prepare_inputs_for_generation(
1366
+ self,
1367
+ input_ids,
1368
+ past_key_values=None,
1369
+ inputs_embeds=None,
1370
+ pixel_values=None,
1371
+ patch_pixel_values=None,
1372
+ num_patches=None,
1373
+ image_embeds=None,
1374
+ attention_mask=None,
1375
+ cache_position=None,
1376
+ logits_to_keep=None,
1377
+ **kwargs,
1378
+ ):
1379
+ model_inputs = super().prepare_inputs_for_generation(
1380
+ input_ids,
1381
+ past_key_values=past_key_values,
1382
+ inputs_embeds=inputs_embeds,
1383
+ attention_mask=attention_mask,
1384
+ cache_position=cache_position,
1385
+ logits_to_keep=logits_to_keep,
1386
+ **kwargs,
1387
+ )
1388
+
1389
+ generation_cache_position = model_inputs.get("cache_position", cache_position)
1390
+ is_prefill = past_key_values is None
1391
+ if generation_cache_position is not None and generation_cache_position.numel() > 0:
1392
+ is_prefill = generation_cache_position[0].item() == 0
1393
+
1394
+ if is_prefill:
1395
+ # During cached decoding, input ids no longer contain image tokens,
1396
+ # so pixel values should only be passed at the first step.
1397
+ model_inputs["pixel_values"] = pixel_values
1398
+
1399
+ return model_inputs
1400
+
1401
+ def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
1402
+ if key.startswith("language_model."):
1403
+ return key[len("language_model.") :], True
1404
+
1405
+ return key, False
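The `_checkpoint_conversion_mapping` on `Step3p7ForConditionalGeneration` renames checkpoint keys with three regular expressions. The sketch below replays those patterns with plain `re`, assuming the loader applies the first matching pattern and then stops; that first-match behaviour, the `rename_key` helper, and the sample key names are illustrative assumptions, not taken from this commit.

```python
import re

# Patterns copied from `_checkpoint_conversion_mapping` in the diff.
CONVERSION_MAPPING = {
    "^vision_model": "model.vision_model",
    r"^model(?!\.(language_model|vision_model))": "model.language_model",
    "^vit_large_projector": "model.vit_large_projector",
}


def rename_key(key: str) -> str:
    """Hypothetical helper: apply the first pattern that matches, then stop."""
    for pattern, replacement in CONVERSION_MAPPING.items():
        new_key, count = re.subn(pattern, replacement, key)
        if count:
            return new_key
    return key  # keys such as lm_head.* match no pattern and pass through


if __name__ == "__main__":
    # Sample key names are made up for illustration.
    for old in (
        "vision_model.conv1.weight",
        "model.layers.0.mlp.up_proj.weight",
        "vit_large_projector.weight",
        "lm_head.weight",
    ):
        print(f"{old} -> {rename_key(old)}")
```

The negative lookahead in the second pattern is what keeps keys that already start with `model.language_model` or `model.vision_model` from being prefixed a second time.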
processing_step3.py ADDED
@@ -0,0 +1,475 @@
1
+ from transformers import BaseImageProcessor, ImageProcessingMixin
2
+ from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
3
+ import math
4
+ from typing import Iterable, Optional, Tuple, List, TypedDict, Literal, Union, overload
5
+
6
+ from PIL import Image
7
+ import torch
8
+ import numpy as np
9
+ import torchvision
10
+ from torch import nn
11
+ from torch.nn import functional as F, LayerNorm
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from transformers.activations import ACT2FN
14
+ from torchvision import transforms
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from transformers.feature_extraction_utils import BatchFeature, TensorType
17
+ from transformers.image_utils import ImageInput
18
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
19
+ from transformers.tokenization_utils_tokenizers import TokenizersBackend
20
+ from math import ceil
21
+ from itertools import product
22
+
23
+
24
+
25
+ MAX_IMAGE_SIZE: int = 3024
26
+
27
+ class Step3VLImagePixelInputs(TypedDict):
28
+ type: Literal["pixel_values"]
29
+ pixel_values: torch.Tensor
30
+ patch_pixel_values: Optional[torch.Tensor]
31
+ num_patches: list[int]
32
+
33
+
34
+ class Step3VLImageEmbeddingInputs(TypedDict):
35
+ type: Literal["image_embeds"]
36
+ image_embeds: torch.Tensor
37
+
38
+
39
+ ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
40
+
41
+
42
+ class GPUToTensor(torch.nn.Module):
43
+
44
+ def forward(self, raw_image: Union[np.ndarray,
45
+ Image.Image]) -> torch.Tensor:
46
+ if isinstance(raw_image, Image.Image):
47
+ return transforms.ToTensor()(raw_image)
48
+ if raw_image.ndim == 2:
49
+ raw_image = raw_image[:, :, None].repeat(3, -1)
50
+ if torch.cuda.is_available():
51
+ device = torch.device("cuda")
52
+ else:
53
+ device = torch.device("cpu")
54
+ image_tensor = torch.from_numpy(raw_image).to(device)
55
+ image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
56
+ if image_tensor.dtype == torch.uint8:
57
+ image_tensor = image_tensor.to(torch.float32).div(255)
58
+ return image_tensor
59
+
60
+ class Step3VisionProcessor(BaseImageProcessor):
61
+
62
+ def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
63
+ mean = [0.48145466, 0.4578275, 0.40821073]
64
+ std = [0.26862954, 0.26130258, 0.27577711]
65
+ patch_size = patch_size if patch_size is not None else size
66
+
67
+ self.transform = transforms.Compose([
68
+ GPUToTensor(),
69
+ transforms.Normalize(mean, std),
70
+ transforms.Resize(
71
+ (size, size),
72
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
73
+ == "bicubic" else InterpolationMode.BILINEAR,
74
+ antialias=True),
75
+ ])
76
+
77
+ self.patch_transform = transforms.Compose([
78
+ GPUToTensor(),
79
+ transforms.Normalize(mean, std),
80
+ transforms.Resize(
81
+ (patch_size, patch_size),
82
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
83
+ == "bicubic" else InterpolationMode.BILINEAR,
84
+ antialias=True),
85
+ ]) if patch_size is not None else None
86
+
87
+ def __call__(self, image, is_patch=False):
88
+ if is_patch:
89
+ return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
90
+ else:
91
+ return {"pixel_values": self.transform(image).unsqueeze(0)}
92
+
93
+ class ImagePatcher:
94
+ def determine_window_size(self, long: int, short: int) -> int:
95
+ if long <= 728:
96
+ return short if long / short > 1.5 else 0
97
+ return min(short, 504) if long / short > 4 else 504
98
+ def slide_window(
99
+ self,
100
+ width: int,
101
+ height: int,
102
+ sizes: list[tuple[int, int]],
103
+ steps: list[tuple[int, int]],
104
+ img_rate_thr: float = 0.6,
105
+ ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
106
+ assert 1 >= img_rate_thr >= 0, "The `img_rate_thr` should lie in 0~1"
107
+ windows = []
108
+ # Sliding windows.
109
+ for size, step in zip(sizes, steps):
110
+ size_w, size_h = size
111
+ step_w, step_h = step
112
+
113
+ x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
114
+ 1)
115
+ x_start = [step_w * i for i in range(x_num)]
116
+ if len(x_start) > 1 and x_start[-1] + size_w > width:
117
+ x_start[-1] = width - size_w
118
+
119
+ y_num = 1 if height <= size_h else ceil((height - size_h) /
120
+ step_h + 1)
121
+ y_start = [step_h * i for i in range(y_num)]
122
+ if len(y_start) > 1 and y_start[-1] + size_h > height:
123
+ y_start[-1] = height - size_h
124
+
125
+ start = np.array(list(product(y_start, x_start)), dtype=int)
126
+ start[:, [0, 1]] = start[:, [1, 0]]
127
+ windows.append(np.concatenate([start, start + size], axis=1))
128
+ windows = np.concatenate(windows, axis=0)
129
+
130
+ return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
131
+ int(box[3] - box[1])) for box in windows], (x_num, y_num)
132
+
133
+ def square_pad(self, img: Image.Image) -> Image.Image:
134
+ w, h = img.size
135
+ if w == h:
136
+ return img
137
+ size = max(w, h)
138
+ padded = Image.new(img.mode, (size, size), 0)
139
+ padded.paste(img, (0, 0))
140
+ return padded
141
+
142
+ def get_image_size_for_padding(self, img_width: int,
143
+ img_height: int) -> tuple[int, int]:
144
+ ratio = img_width / img_height
145
+ if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
146
+ new_size = max(img_height, img_width)
147
+ return new_size, new_size
148
+ return img_width, img_height
149
+
150
+ def get_image_size_for_preprocess(self, img_width: int,
151
+ img_height: int) -> tuple[int, int]:
152
+
153
+ if max(img_height, img_width) > MAX_IMAGE_SIZE:
154
+ scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
155
+ img_width = int(img_width * scale_factor)
156
+ img_height = int(img_height * scale_factor)
157
+ return img_width, img_height
158
+
159
+ def get_image_size_for_crop(self, img_width: int, img_height: int,
160
+ window_size: int):
161
+ w_ratio = img_width / window_size
162
+ h_ratio = img_height / window_size
163
+
164
+ if w_ratio < 1:
165
+ width_new = img_width
166
+ else:
167
+ decimal_w = w_ratio - img_width // window_size
168
+ w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
169
+ width_new = window_size * w_ratio
170
+ if h_ratio < 1:
171
+ height_new = img_height
172
+ else:
173
+ decimal_h = h_ratio - img_height // window_size
174
+ h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
175
+ height_new = window_size * h_ratio
176
+ return int(width_new), int(height_new)
177
+
178
+ def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
179
+ target = img.crop((j, i, j + tw, i + th))
180
+ return target
181
+
182
+ def get_num_patches(self, img_width: int,
183
+ img_height: int) -> tuple[int, int]:
184
+ img_width, img_height = self.get_image_size_for_padding(
185
+ img_width, img_height)
186
+ img_width, img_height = self.get_image_size_for_preprocess(
187
+ img_width, img_height)
188
+ window_size = self.determine_window_size(max(img_height, img_width),
189
+ min(img_height, img_width))
190
+ if window_size == 0:
191
+ return 0, 0
192
+ else:
193
+ img_width, img_height = self.get_image_size_for_crop(
194
+ img_width, img_height, window_size)
195
+ center_list, (x_num, y_num) = self.slide_window(
196
+ img_width, img_height, [(window_size, window_size)],
197
+ [(window_size, window_size)])
198
+ full_rows = (len(center_list) - 1) // x_num + 1
199
+ if len(center_list) > 0 and len(center_list) % x_num == 0:
200
+ full_rows -= 1
201
+ return len(center_list), full_rows
202
+
203
+ def __call__(
204
+ self, img: Image.Image
205
+ ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
206
+ img_width, img_height = img.size
207
+ new_img_width, new_img_height = self.get_image_size_for_padding(
208
+ img_width, img_height)
209
+ if new_img_width != img_width or new_img_height != img_height:
210
+ img = self.square_pad(img)
211
+ img_width, img_height = img.size
212
+
213
+ new_img_width, new_img_height = self.get_image_size_for_preprocess(
214
+ img_width, img_height)
215
+ img = img.resize((new_img_width, new_img_height),
216
+ Image.Resampling.BILINEAR)
217
+ window_size = self.determine_window_size(
218
+ max(new_img_height, new_img_width),
219
+ min(new_img_height, new_img_width))
220
+ # return img, [], None
221
+ if window_size == 0:
222
+ return img, [], None
223
+ else:
224
+ new_img_width, new_img_height = self.get_image_size_for_crop(
225
+ new_img_width, new_img_height, window_size)
226
+ if (new_img_width, new_img_height) != (img_width, img_height):
227
+ img_for_crop = img.resize((new_img_width, new_img_height),
228
+ Image.Resampling.BILINEAR)
229
+ else:
230
+ img_for_crop = img
231
+
232
+ patches = []
233
+ newlines = []
234
+ center_list, (x_num, y_num) = self.slide_window(
235
+ new_img_width, new_img_height, [(window_size, window_size)],
236
+ [(window_size, window_size)])
237
+ for patch_id, center_lf_point in enumerate(center_list):
238
+ x, y, patch_w, patch_h = center_lf_point
239
+ big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
240
+ patch_w)
241
+ patches.append(big_patch)
242
+ if (patch_id + 1) % x_num == 0:
243
+ newlines.append(patch_id)
244
+
245
+ if newlines and newlines[-1] == len(patches) - 1:
246
+ newlines.pop()
247
+
248
+ return img, patches, [i in newlines for i in range(len(patches))] if len(patches) > 0 else None
249
+
250
+
251
+
252
+
253
+ class Step3VLProcessor(ProcessorMixin):
254
+ # Align ProcessorMixin with our custom components.
255
+ # We only have an image processor (not a feature extractor) plus a tokenizer.
256
+ attributes = ["tokenizer"]
257
+ tokenizer_class = "AutoTokenizer"
258
+
259
+ @classmethod
260
+ def _load_tokenizer_from_pretrained(
261
+ cls, sub_processor_type, pretrained_model_name_or_path, subfolder="", **kwargs
262
+ ):
263
+ return TokenizersBackend.from_pretrained(
264
+ pretrained_model_name_or_path,
265
+ subfolder=subfolder,
266
+ **kwargs,
267
+ )
268
+
269
+ def __init__(
270
+ self,
271
+ tokenizer=None,
272
+ chat_template=None,
273
+ **kwargs
274
+ ) -> None:
275
+ self.image_size = 728
276
+ self.patch_size = 504
277
+
278
+ self.image_preprocessor = Step3VisionProcessor(self.image_size,
279
+ "bilinear",
280
+ self.patch_size)
281
+
282
+ self.num_image_feature_size = 169
283
+ self.num_patch_feature_size = 81
284
+ self.image_token = "<im_patch>"
285
+ self.image_feature_placeholder = (self.image_token *
286
+ self.num_image_feature_size)
287
+ self.patch_feature_placeholder = (self.image_token *
288
+ self.num_patch_feature_size)
289
+ super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
290
+ self.patcher = ImagePatcher()
291
+
292
+ @property
293
+ def image_token_id(self) -> int:
294
+ return self.tokenizer.get_vocab()[self.image_token]
295
+
296
+ def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
297
+ num_patches, num_newlines = self.patcher.get_num_patches(
298
+ img_width, img_height)
299
+
300
+ return num_patches * (
301
+ self.num_patch_feature_size +
302
+ 2) + self.num_image_feature_size + 2 + num_newlines
303
+
304
+ def _split_images(self,
305
+ images: list[Image.Image]) -> list[ImageWithPatches]:
306
+ result = []
307
+ for img in images:
308
+ result.append(self.patcher(img))
309
+ return result
310
+
311
+ def _convert_images_to_pixel_values(
312
+ self,
313
+ images: list[Image.Image],
314
+ is_patch: bool = False,
315
+ ) -> list[torch.Tensor]:
316
+ return [
317
+ self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
318
+ for img in images
319
+ ]
320
+
321
+ def _get_patch_repl(
322
+ self,
323
+ num_patches: int,
324
+ patch_newline_mask: list[bool] | None,
325
+ ) -> tuple[str, list[int]]:
326
+ text = ""
327
+ token_ids = []
328
+ for i in range(num_patches):
329
+ assert len(patch_newline_mask) == num_patches
330
+ text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
331
+ token_ids.extend(
332
+ [self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
333
+ [self.image_token_id] * self.num_patch_feature_size +
334
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")])
335
+ if patch_newline_mask and patch_newline_mask[i]:
336
+ text += "<patch_newline>"
337
+ token_ids.append(
338
+ self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
339
+ return text, token_ids
340
+
341
+ def _get_image_repl(
342
+ self,
343
+ num_images: int,
344
+ ) -> tuple[str, list[int]]:
345
+ text = f"<im_start>{self.image_feature_placeholder}<im_end>"
346
+ token_ids = [
347
+ self.tokenizer.convert_tokens_to_ids("<im_start>")
348
+ ] + [self.image_token_id] * self.num_image_feature_size + [
349
+ self.tokenizer.convert_tokens_to_ids("<im_end>")
350
+ ]
351
+ return text * num_images, token_ids * num_images
352
+
353
+ def _get_image_repl_features(
354
+ self,
355
+ num_images: int,
356
+ num_patches: int,
357
+ patch_new_line_idx: Optional[list[bool]],
358
+ ) -> tuple[str, list[int]]:
359
+ if num_patches > 0:
360
+ patch_repl, patch_repl_ids = self._get_patch_repl(
361
+ num_patches, patch_new_line_idx)
362
+ else:
363
+ patch_repl = ""
364
+ patch_repl_ids = []
365
+ image_repl, image_repl_ids = self._get_image_repl(num_images)
366
+ return patch_repl + image_repl, patch_repl_ids + image_repl_ids
367
+
368
+ def replace_placeholder(self, text: str, placeholder: str,
369
+ repls: list[str]) -> str:
370
+ parts = text.split(placeholder)
371
+
372
+ if len(parts) - 1 != len(repls):
373
+ raise ValueError(
374
+ "The number of placeholders does not match the number of replacements." # noqa: E501
375
+ )
376
+
377
+ result = [parts[0]]
378
+ for i, repl in enumerate(repls):
379
+ result.append(repl)
380
+ result.append(parts[i + 1])
381
+
382
+ return "".join(result)
383
+
384
+ def __call__(
385
+ self,
386
+ text: Optional[Union[str, list[str]]] = None,
387
+ images: ImageInput | None = None,
388
+ return_tensors: Optional[Union[str, TensorType]] = None,
389
+ **kwargs,
390
+ ) -> BatchFeature:
391
+
392
+ if images is not None:
393
+ images = self.image_preprocessor.fetch_images(images)
394
+ if text is None:
395
+ text = []
396
+ if not isinstance(text, list):
397
+ text = [text]
398
+ if images is None:
399
+ images = []
400
+ elif not isinstance(images, list):
401
+ images = [images]
402
+ elif isinstance(images[0], list):
403
+ images = images[0]
404
+
405
+ if len(images) == 0:
406
+ image_inputs = {}
407
+ text_inputs = self.tokenizer(text)
408
+ else:
409
+ splitted_images_data = self._split_images(images)
410
+ pixel_values_lst = []
411
+ patch_pixel_values_lst = []
412
+ patch_newline_mask_lst = []
413
+ image_repl_str_lst = []
414
+ image_repl_ids_lst = []
415
+ num_patches = []
416
+ for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
417
+ pixel_values_lst.extend(
418
+ self._convert_images_to_pixel_values([raw_img]))
419
+
420
+ if len(img_patches) > 0:
421
+ patch_pixel_values_lst.extend(
422
+ self._convert_images_to_pixel_values(img_patches,
423
+ is_patch=True))
424
+ num_patches.append(len(img_patches))
425
+
426
+ image_repl_str, image_repl_ids = self._get_image_repl_features(
427
+ 1, len(img_patches), patch_newline_mask)
428
+ image_repl_str_lst.append(image_repl_str)
429
+ image_repl_ids_lst.extend(image_repl_ids)
430
+
431
+ if patch_newline_mask is not None:
432
+ patch_newline_mask_lst.extend(patch_newline_mask)
433
+
434
+ image_inputs = {
435
+ "pixel_values": torch.cat(pixel_values_lst),
436
+ "num_patches": num_patches,
437
+ }
438
+ if patch_pixel_values_lst:
439
+ image_inputs["patch_pixel_values"] = torch.cat(
440
+ patch_pixel_values_lst)
441
+ if patch_newline_mask_lst:
442
+ image_inputs["patch_newline_mask"] = torch.tensor(
443
+ patch_newline_mask_lst, dtype=torch.bool)
444
+
445
+ text = [
446
+ self.replace_placeholder(t, self.image_token,
447
+ image_repl_str_lst) for t in text
448
+ ]
449
+ text_inputs = self.tokenizer(text)
450
+
451
+ return BatchFeature(
452
+ {
453
+ **text_inputs,
454
+ **image_inputs,
455
+ },
456
+ tensor_type=return_tensors,
457
+ )
458
+
459
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
460
+ def batch_decode(self, *args, **kwargs):
461
+ """
462
+ This method forwards all its arguments to the wrapped tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
463
+ refer to the docstring of this method for more information.
464
+ """
465
+ return self.tokenizer.batch_decode(*args, **kwargs)
466
+
467
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
468
+ def decode(self, *args, **kwargs):
469
+ """
470
+ This method forwards all its arguments to the wrapped tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
471
+ the docstring of this method for more information.
472
+ """
473
+ return self.tokenizer.decode(*args, **kwargs)
474
+
475
+ __all__ = ["Step3VLProcessor"]
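The number of prompt tokens an image expands to follows from `ImagePatcher.get_num_patches` and `Step3VLProcessor.get_num_image_tokens`. The sketch below restates that arithmetic without PIL or torch so it can be checked by hand. The constants are copied from the diff; the function names `window_size`, `snap` and `count_image_tokens` are hypothetical, and the row and column counts use the fact that the processor slides the window with a step equal to the window size.

```python
from math import ceil

MAX_IMAGE_SIZE = 3024  # cap on the longer side, as in processing_step3.py
IMAGE_TOKENS = 169     # <im_patch> tokens for the global view
PATCH_TOKENS = 81      # <im_patch> tokens for each cropped patch


def window_size(long_side: int, short_side: int) -> int:
    """Mirror of determine_window_size; 0 means the image is not patched."""
    if long_side <= 728:
        return short_side if long_side / short_side > 1.5 else 0
    return min(short_side, 504) if long_side / short_side > 4 else 504


def snap(length: int, window: int) -> int:
    """Round one side to a multiple of the window (up if the remainder exceeds 20%)."""
    ratio = length / window
    if ratio < 1:
        return length
    whole = length // window
    return window * (whole + 1 if ratio - whole > 0.2 else whole)


def count_image_tokens(width: int, height: int) -> int:
    # Tiny, very elongated images are padded to a square first.
    if min(width, height) < 32 and not (1 / 4 <= width / height <= 4):
        width = height = max(width, height)
    # Downscale so the longer side is at most MAX_IMAGE_SIZE.
    longest = max(width, height)
    if longest > MAX_IMAGE_SIZE:
        scale = MAX_IMAGE_SIZE / longest
        width, height = int(width * scale), int(height * scale)
    window = window_size(max(width, height), min(width, height))
    if window == 0:
        cols = rows = 0
    else:
        width, height = snap(width, window), snap(height, window)
        cols = 1 if width <= window else ceil((width - window) / window + 1)
        rows = 1 if height <= window else ceil((height - window) / window + 1)
    patches = cols * rows
    newlines = max(rows - 1, 0)  # one <patch_newline> after every row except the last
    # Each patch and the global view are wrapped in a start and an end token (+2).
    return patches * (PATCH_TOKENS + 2) + IMAGE_TOKENS + 2 + newlines


if __name__ == "__main__":
    for size in [(600, 600), (1008, 504), (1512, 1008)]:
        print(size, count_image_tokens(*size))
```

For example, a 1512x1008 image yields a 3x2 grid of 504-pixel patches, so 6 patches of 83 tokens each, 171 tokens for the global view, and one `<patch_newline>`.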
processor_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_step3.Step3VLProcessor"
4
+ },
5
+ "processor_class": "Step3VLProcessor"
6
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|begin▁of▁sentence|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|im_end|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|end▁of▁sentence|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "add_prefix_space": null,
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_step3.Step3VLProcessor"
5
+ },
6
+ "backend": "tokenizers",
7
+ "bos_token": "<|begin▁of▁sentence|>",
8
+ "clean_up_tokenization_spaces": false,
9
+ "eos_token": "<|im_end|>",
10
+ "is_local": true,
11
+ "legacy": true,
12
+ "local_files_only": false,
13
+ "model_max_length": 262144,
14
+ "pad_token": "<|▁pad▁|>",
15
+ "padding_side": "left",
16
+ "processor_class": "Step3VLProcessor",
17
+ "sp_model_kwargs": {},
18
+ "tokenizer_class": "TokenizersBackend",
19
+ "unk_token": null,
20
+ "use_default_system_prompt": false,
21
+ "chat_template": "{% macro render_message_content(message) %}{% if message.content is none %}{{- '' }}{% elif message.content is string %}{{- message.content }}{% elif message.content is mapping %}{{- message.content['value'] if 'value' in message.content else message.content['text'] }}{% elif message.content is iterable %}{% set ns = namespace(needs_text_separator=false) %}{% for item in message.content %}{% if item.type == 'text' %}{% if ns.needs_text_separator %}{{- ' ' }}{% endif %}{{- item['value'] if 'value' in item else item['text'] }}{% set ns.needs_text_separator = true %}{% elif item.type == 'image' %}<im_patch>{% set ns.needs_text_separator = false %}{% endif %}{% endfor %}{% endif %}{% endmacro %}\n{{bos_token}}{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if reasoning_effort is defined %}\n {{- \"Reasoning: \" + reasoning_effort + '\\n\\n' }}\n {%- endif %}\n {%- if messages[0].role == 'system' %}\n {{- render_message_content(messages[0]) + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou have access to the following functions in JSONSchema format:\\n\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson(ensure_ascii=False) }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nIf you choose to call a function ONLY reply in the following format with NO suffix:\\n\\n<tool_call>\\n<function=example_function_name>\\n<parameter=example_parameter_1>\\nvalue_1\\n</parameter>\\n<parameter=example_parameter_2>\\nThis is the value for the second parameter\\nthat can span\\nmultiple lines\\n</parameter>\\n</function>\\n</tool_call>\\n\\n<IMPORTANT>\\nReminder:\\n- Function calls MUST follow the specified format: an inner <function=...>\\n...\\n</function> block must be nested within <tool_call>\\n...\\n</tool_call> XML tags\\n- Required parameters MUST be specified\\n</IMPORTANT><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' }}\n {%- if reasoning_effort is defined %}\n {{- 
\"Reasoning: \" + reasoning_effort + '\\n\\n' }}\n {%- endif %}\n {{- render_message_content(messages[0]) + '<|im_end|>\\n' }}\n {%- elif reasoning_effort is defined %}\n {{- '<|im_start|>system\\n' + \"Reasoning: \" + reasoning_effort + '\\n\\n' + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and render_message_content(message) is string and not(render_message_content(message).startswith('<tool_response>') and render_message_content(message).endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- set content = render_message_content(message) %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {%- set role_name = 'observation' if (message.role == \"system\" and not loop.first and message.name == 'observation') else message.role %}\n {{- '<|im_start|>' + role_name + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- else %}\n {%- set reasoning_content = '' %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content + '\\n</think>\\n' + content }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls 
%}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n<function=' + tool_call.name + '>\\n' }}\n {%- if tool_call.arguments is defined %}\n {%- set arguments = tool_call.arguments | fromjson if tool_call.arguments is string else tool_call.arguments %}\n {%- for args_name, args_value in arguments|items %}\n {{- '<parameter=' + args_name + '>\\n' }}\n {%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}\n {{- args_value }}\n {{- '\\n</parameter>\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '</function>\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>tool_response\\n' }}\n {%- endif %}\n {{- '<tool_response>' }}\n {{- content }}\n {{- '</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n<think>\\n' }}\n{%- endif %}\n"
22
+ }
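When an assistant message has no string `reasoning_content`, the chat template recovers the reasoning from the message text by splitting on `</think>` and `<think>`. A minimal Python sketch of that same string logic is shown here for readers who want to reproduce it outside Jinja; the function name `split_reasoning` is hypothetical.

```python
def split_reasoning(content: str) -> tuple[str, str]:
    """Return (reasoning_content, visible_content), mirroring the template's split."""
    if "</think>" not in content:
        return "", content
    # Text before the first </think>, with everything up to the last <think> dropped.
    reasoning = (
        content.split("</think>")[0].rstrip("\n").split("<think>")[-1].lstrip("\n")
    )
    # Text after the last </think> becomes the visible answer.
    visible = content.split("</think>")[-1].lstrip("\n")
    return reasoning, visible


if __name__ == "__main__":
    print(split_reasoning("<think>\nadd 2 and 2\n</think>\n4"))
    print(split_reasoning("hello"))
```

The template only re-emits the `<think>` block for assistant turns after the last real user query; earlier turns are rendered with the visible content alone.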
vision_encoder.py ADDED
@@ -0,0 +1,452 @@
1
+ from typing import Literal, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers.activations import ACT2FN
7
+
8
+
9
+ from .configuration_step3p7 import StepRoboticsVisionEncoderConfig
10
+
11
+
12
+
13
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
14
+ """Rotate last dimension halves (used by RoPE)."""
15
+ x = x.reshape(*x.shape[:-1], -1, 2)
16
+ x1, x2 = x.unbind(dim=-1)
17
+ x = torch.stack((-x2, x1), dim=-1)
18
+ return x.reshape(*x.shape[:-2], -1)
19
+
20
+
21
+ def apply_rotary_emb(freqs: torch.Tensor,
22
+ t: torch.Tensor,
23
+ start_index: int = 0,
24
+ scale: float = 1.0,
25
+ seq_dim: int = -2) -> torch.Tensor:
26
+ """Apply 2D rotary embeddings to queries / keys."""
27
+ dtype = t.dtype
28
+
29
+ if t.ndim == 3:
30
+ seq_len = t.shape[seq_dim]
31
+ freqs = freqs[-seq_len:]
32
+
33
+ rot_dim = freqs.shape[-1]
34
+ end_index = start_index + rot_dim
35
+ assert rot_dim <= t.shape[-1], (
36
+ f"feature dimension {t.shape[-1]} is too small for rot_dim {rot_dim}")
37
+
38
+ t_left, t, t_right = (
39
+ t[..., :start_index],
40
+ t[..., start_index:end_index],
41
+ t[..., end_index:],
42
+ )
43
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
44
+ out = torch.cat((t_left, t, t_right), dim=-1)
45
+ return out.type(dtype)
46
+
47
+
48
+ class EncoderRope2D(nn.Module):
49
+ """Cacheable 2D rotary positional embedding."""
50
+
51
+ def __init__(
52
+ self,
53
+ dim: int,
54
+ max_grid_height: int,
55
+ max_grid_width: int,
56
+ use_cls_token: bool = False,
57
+ theta: Union[int, float] = 10000,
58
+ max_freq: int = 10,
59
+ num_freqs: int = 1,
60
+ theta_rescale_factor: float = 1.0,
61
+ ):
62
+ super().__init__()
63
+ self.dim = dim
64
+ self.max_grid_height = max_grid_height
65
+ self.max_grid_width = max_grid_width
66
+ self.use_cls_token = use_cls_token
67
+ self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
68
+ self.max_freq = max_freq
69
+ self.num_freqs = num_freqs
70
+ cache = self._compute_2d_freqs()
71
+ self.register_buffer("freqs_cache", cache, persistent=False)
72
+
73
+ def _compute_inv_freq(self, base: Union[int, float],
74
+ dim: int) -> torch.Tensor:
75
+
76
+ freqs = 1.0 / (base**(
77
+ torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
78
+ return freqs
79
+
80
+ def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
81
+ freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype),
82
+ inv_freq)
83
+ freqs = freqs.repeat_interleave(2, dim=-1)
84
+ return freqs
85
+
86
+ def _compute_2d_freqs(self) -> torch.Tensor:
87
+ grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
88
+ grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
89
+ if self.use_cls_token:
90
+ grid_h_range += 1
91
+ grid_w_range += 1
92
+ inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
93
+ freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
94
+ self.max_grid_height, self.max_grid_width, -1)
95
+ freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
96
+ self.max_grid_height, self.max_grid_width, -1)
97
+ freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
98
+ self.max_grid_height * self.max_grid_width, -1)
99
+ if self.use_cls_token:
100
+ freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
101
+ freqs = freqs[None, None, ...]
102
+ return freqs
103
+
104
+ def forward(self, q: torch.Tensor, k: torch.Tensor,
105
+ grid_hw: tuple[int, int]):
106
+ # If grid matches cached shape we reuse directly to avoid recomputation.
107
+ if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width:
108
+ rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1)
109
+ cols = torch.arange(grid_hw[1], device=q.device).view(1, -1)
110
+ positions = (rows * self.max_grid_width + cols).reshape(-1).to(
111
+ torch.long)
112
+ if self.use_cls_token:
113
+ positions = torch.cat(
114
+ [torch.zeros(1, dtype=torch.long, device=q.device), positions + 1], dim=0)
115
+ freqs = self.freqs_cache.index_select(2, positions)
116
+ else:
117
+ freqs = self.freqs_cache
118
+ q = apply_rotary_emb(freqs, q)
119
+ k = apply_rotary_emb(freqs, k)
120
+ return q, k
121
+
122
+
123
+ class EncoderLayerScale(nn.Module):
124
+ """Per-channel residual scaling used when ls_init_value is set."""
125
+
126
+ def __init__(self, dim: int, init_values: float):
127
+ super().__init__()
128
+ self.gamma = nn.Parameter(torch.full((dim,), init_values))
129
+
130
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # (B, L, D)
131
+ return hidden_states * self.gamma
132
+
133
+
134
+ class EncoderMLP(nn.Module):
135
+ """Feed-forward network used inside each transformer block."""
136
+
137
+ def __init__(self, hidden_size: int, intermediate_size: int,
138
+ hidden_act: str):
139
+ super().__init__()
140
+ self.c_fc = nn.Linear(hidden_size, intermediate_size, bias=True)
141
+ self.act_fn = ACT2FN[hidden_act]
142
+ self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
143
+
144
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
145
+
146
+ hidden_states = self.c_proj(self.act_fn(self.c_fc(hidden_states)))
147
+ return hidden_states
148
+
149
+
150
+ class EncoderVisionAttention(nn.Module):
151
+ """Multi-head self attention with optional 2D RoPE."""
152
+
153
+ def __init__(
154
+ self,
155
+ hidden_size: int,
156
+ num_heads: int,
157
+ max_grid_height: int,
158
+ max_grid_width: int,
159
+ use_cls_token: bool = False,
160
+ use_rope2d: bool = True,
161
+ rope_theta: Union[int, float] = 10000,
162
+ rope_max_freq: int = 10,
163
+ rope_num_freqs: int = 1,
164
+ rope_theta_rescale_factor: float = 1.0,
165
+ rope_freqs_for: Literal["lang", "pixel", "constant"] = "lang",
166
+ ):
167
+ super().__init__()
168
+ if hidden_size % num_heads != 0:
169
+ raise ValueError(
170
+ f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})."
171
+ )
172
+ self.num_heads = num_heads
173
+ self.head_dim = hidden_size // num_heads
174
+ self.scale = self.head_dim**-0.5
175
+ self.in_proj_weight = nn.Parameter(torch.zeros(hidden_size * 3, hidden_size))
176
+ self.in_proj_bias = nn.Parameter(torch.zeros(hidden_size * 3))
177
+ self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
178
+
179
+ self.rope = None
180
+ if use_rope2d:
181
+ self.rope = EncoderRope2D(
182
+ dim=self.head_dim,
183
+ max_grid_height=max_grid_height,
184
+ max_grid_width=max_grid_width,
185
+ use_cls_token=use_cls_token,
186
+ theta=rope_theta,
187
+ max_freq=rope_max_freq,
188
+ num_freqs=rope_num_freqs,
189
+ theta_rescale_factor=rope_theta_rescale_factor,
190
+ )
191
+
192
+ def forward(self, hidden_states: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
193
+ bsz, seq_len, _ = hidden_states.shape
194
+ qkv = F.linear(
195
+ hidden_states,
196
+ self.in_proj_weight,
197
+ self.in_proj_bias,
198
+ )
199
+ q, k, v = qkv.chunk(3, dim=-1)
200
+
201
+ q = q.view(bsz, seq_len, self.num_heads,
202
+ self.head_dim).transpose(1, 2)
203
+ k = k.view(bsz, seq_len, self.num_heads,
204
+ self.head_dim).transpose(1, 2)
205
+ if self.rope is not None:
206
+ q, k = self.rope(q, k, grid_hw=grid_hw)
207
+ v = v.view(bsz, seq_len, self.num_heads,
208
+ self.head_dim).transpose(1, 2)
209
+
210
+ attn_output = F.scaled_dot_product_attention(
211
+ q, k, v, is_causal=False, scale=self.scale)
212
+ attn_output = attn_output.transpose(1, 2).reshape(
213
+ bsz, seq_len, self.num_heads * self.head_dim)
214
+ return self.out_proj(attn_output)
215
+
216
+
217
+ class EncoderVisionBlock(nn.Module):
218
+ """A single Vision Transformer block (self-attention + MLP)."""
219
+
220
+ def __init__(
221
+ self,
222
+ hidden_size: int,
223
+ num_heads: int,
224
+ mlp_ratio: float,
225
+ hidden_act: str,
226
+ layer_norm_eps: float,
227
+ ls_init_value: Optional[float] = None,
228
+ max_grid_height: Optional[int] = None,
229
+ max_grid_width: Optional[int] = None,
230
+ use_cls_token: bool = False,
231
+ use_rope2d: bool = True,
232
+ rope_kwargs: Optional[dict] = None,
233
+ ):
234
+ super().__init__()
235
+ rope_kwargs = rope_kwargs or {}
236
+ self.attn = EncoderVisionAttention(
237
+ hidden_size,
238
+ num_heads,
239
+ max_grid_height=max_grid_height,
240
+ max_grid_width=max_grid_width,
241
+ use_cls_token=use_cls_token,
242
+ use_rope2d=use_rope2d,
243
+ **rope_kwargs,
244
+ )
245
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
246
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
247
+
248
+ intermediate = int(hidden_size * mlp_ratio)
249
+ self.mlp = EncoderMLP(hidden_size, intermediate, hidden_act)
250
+
251
+ self.ls_1 = EncoderLayerScale(hidden_size, ls_init_value) if ls_init_value is not None else nn.Identity()
252
+ self.ls_2 = EncoderLayerScale(hidden_size, ls_init_value) if ls_init_value is not None else nn.Identity()
253
+
254
+ def forward(self, hidden_states: torch.Tensor,
255
+ grid_hw: tuple[int, int]) -> torch.Tensor:
256
+ # Pre-norm self-attention with a layer-scaled residual connection.
257
+ residual = hidden_states
258
+ hidden_states = self.ln_1(hidden_states)
259
+ hidden_states = self.attn(hidden_states, grid_hw=grid_hw)
260
+ hidden_states = residual + self.ls_1(hidden_states)
261
+
262
+ residual = hidden_states
263
+ hidden_states = self.ln_2(hidden_states)
264
+ hidden_states = self.mlp(hidden_states)
265
+ hidden_states = residual + self.ls_2(hidden_states)
266
+ return hidden_states
267
+
268
+
269
+ class EncoderVisionTransformer(nn.Module):
270
+ """Stack of encoder blocks parameterised by StepRoboticsVisionEncoderConfig."""
271
+
272
+ def __init__(
273
+ self,
274
+ embed_dim: int,
275
+ depth: int,
276
+ num_heads: int,
277
+ mlp_ratio: float,
278
+ hidden_act: str,
279
+ layer_norm_eps: float,
280
+ ls_init_value: Optional[float] = None,
281
+ max_grid_height: Optional[int] = None,
282
+ max_grid_width: Optional[int] = None,
283
+ use_cls_token: bool = False,
284
+ use_rope2d: bool = True,
285
+ rope_kwargs: Optional[dict] = None,
286
+ ):
287
+ super().__init__()
288
+ self.layers = depth
289
+ rope_kwargs = rope_kwargs or {}
290
+ self.resblocks = nn.ModuleList([
291
+ EncoderVisionBlock(embed_dim, num_heads, mlp_ratio, hidden_act,
292
+ layer_norm_eps,
293
+ max_grid_height=max_grid_height,
294
+ max_grid_width=max_grid_width,
295
+ use_cls_token=use_cls_token,
296
+ use_rope2d=use_rope2d,
297
+ ls_init_value=ls_init_value,
298
+ rope_kwargs=rope_kwargs)
299
+ for _ in range(depth)
300
+ ])
301
+
302
+ def forward(self,
303
+ hidden_states: torch.Tensor,
304
+ grid_hw: tuple[int, int]) -> torch.Tensor:
305
+ for block in self.resblocks:
306
+ hidden_states = block(hidden_states, grid_hw=grid_hw)
307
+ return hidden_states
308
+
309
+
310
+ class StepRoboticsVisionEncoder(nn.Module):
311
+ """
312
+ Vision encoder built from StepRoboticsVisionEncoderConfig.
313
+
314
+ The encoder performs patch embedding followed by a stack of transformer
315
+ blocks. Only the config fields defined in StepRoboticsVisionEncoderConfig (and
316
+ StepRoboticVLConfig.vision_config) are expected.
317
+ """
318
+
319
+ def __init__(self, config: StepRoboticsVisionEncoderConfig):
320
+ super().__init__()
321
+ self.config = config
322
+
323
+ # Align commonly used attributes so downstream code (e.g. StepRoboticVL)
324
+ # can access them without extra renaming.
325
+ self.hidden_size = config.width
326
+ self.num_heads = config.heads
327
+ self.num_hidden_layers = config.layers
328
+ self.patch_size = config.patch_size
329
+ self.image_size = config.image_size
330
+ self.use_cls_token = getattr(config, "use_cls_token", False)
331
+ self.use_rope2d = getattr(config, "use_rope2d", True)
332
+ self.use_abs_posemb = getattr(config, "use_abs_posemb", True)
333
+ self.layer_norm_eps = config.layer_norm_eps
334
+ self.mlp_ratio = getattr(config, "mlp_ratio", 8960 / 1536)
335
+ self.ls_init_value = getattr(config, "ls_init_value", None)
336
+ self.hidden_act = config.hidden_act
337
+ self.use_ln_pre = getattr(config, "use_ln_pre", False)
338
+ self.use_ln_post = getattr(config, "use_ln_post", True)
339
+
340
+ # Patch embedding.
341
+ self.conv1 = nn.Conv2d(in_channels=config.num_channels,
342
+ out_channels=self.hidden_size,
343
+ kernel_size=self.patch_size,
344
+ stride=self.patch_size,
345
+ bias=False)
346
+
347
+ self.ln_pre = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_pre else nn.Identity()
348
+ self.ln_post = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_post else nn.Identity()
349
+
350
+ grid_size = self.image_size // self.patch_size
351
+ self.base_grid = (grid_size, grid_size)
352
+
353
+ if self.use_cls_token:
354
+ self.class_embedding = nn.Parameter(
355
+ torch.randn(self.hidden_size) * (self.hidden_size**-0.5))
356
+ else:
357
+ self.class_embedding = None
358
+
359
+ if self.use_abs_posemb:
360
+ self.posemb_grid_size = self.image_size // self.patch_size
361
+ self.positional_embedding = nn.Parameter(
362
+ (self.hidden_size**-0.5) * torch.randn(
363
+ int(self.use_cls_token) + self.posemb_grid_size**2,
364
+ self.hidden_size,
365
+ ))
366
+
367
+ self.transformer = EncoderVisionTransformer(
368
+ embed_dim=self.hidden_size,
369
+ depth=self.num_hidden_layers,
370
+ num_heads=self.num_heads,
371
+ mlp_ratio=self.mlp_ratio,
372
+ hidden_act=self.hidden_act,
373
+ layer_norm_eps=self.layer_norm_eps,
374
+ ls_init_value=self.ls_init_value,
375
+ max_grid_height=self.base_grid[0],
376
+ max_grid_width=self.base_grid[1],
377
+ use_cls_token=self.use_cls_token,
378
+ use_rope2d=self.use_rope2d,
379
+ rope_kwargs={
380
+ "rope_theta": getattr(config, "rope_theta", 10000),
381
+ "rope_max_freq": getattr(config, "rope_max_freq", 10),
382
+ "rope_num_freqs": getattr(config, "rope_num_freqs", 1),
383
+ "rope_theta_rescale_factor":
384
+ getattr(config, "rope_theta_rescale_factor", 1.0),
385
+ "rope_freqs_for": getattr(config, "rope_freqs_for", "lang"),
386
+ },
387
+ )
388
+ self.vit_downsampler1 = nn.Conv2d(self.hidden_size,
389
+ self.hidden_size * 2,
390
+ kernel_size=3,
391
+ stride=2,
392
+ padding=1)
393
+ self.vit_downsampler2 = nn.Conv2d(self.hidden_size * 2,
394
+ self.hidden_size * 4,
395
+ kernel_size=3,
396
+ stride=2,
397
+ padding=1)
398
+
399
+
400
+ def sample_abs_posemb(self, grid_h: int, grid_w: int):
401
+ if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
402
+ return self.positional_embedding[None, ...]
403
+
404
+ pos_embed = self.positional_embedding
405
+ if self.use_cls_token:
406
+ cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
407
+
408
+ pos_embed = (pos_embed.reshape(1, self.posemb_grid_size,
409
+ self.posemb_grid_size,
410
+ -1).permute(0, 3, 1, 2).contiguous())
411
+ pos_embed = F.interpolate(pos_embed,
412
+ size=(grid_h, grid_w),
413
+ mode="bilinear",
414
+ align_corners=False)
415
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.hidden_size)
416
+
417
+ if self.use_cls_token:
418
+ pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
419
+
420
+ return pos_embed[None, ...]
421
+
422
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
423
+ """
424
+ Args:
425
+ pixel_values: Image tensor of shape (B, C, H, W).
426
+ Returns:
427
+ Patch features of shape (B, Gh*Gw, D); the cls token, if used, is removed.
428
+ """
429
+ bsz, _, height, width = pixel_values.shape
430
+ grid_h, grid_w = height // self.patch_size, width // self.patch_size
431
+
432
+ hidden_state = self.conv1(pixel_values) # (B, D, Gh, Gw)
433
+ hidden_state = hidden_state.flatten(2).transpose(1, 2) # (B, Gh*Gw, D)
434
+
435
+ if self.use_cls_token:
436
+ cls_token = self.class_embedding.view(1, 1,
437
+ -1).expand(bsz, -1, -1)
438
+ hidden_state = torch.cat([cls_token, hidden_state], dim=1)
439
+
440
+ if self.use_abs_posemb:
441
+ pos_emb = self.sample_abs_posemb(grid_h, grid_w)
442
+ hidden_state = hidden_state + pos_emb
443
+ hidden_state = self.ln_pre(hidden_state)
444
+ hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w))
445
+
446
+ if self.use_ln_post:
447
+ hidden_state = self.ln_post(hidden_state)
448
+
449
+ if self.use_cls_token:
450
+ hidden_state = hidden_state[:, 1:, :]
451
+
452
+ return hidden_state
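
The pair-wise rotation in `rotate_half` / `apply_rotary_emb` and the sub-grid lookup in `EncoderRope2D.forward` can be checked without tensors. Below is a minimal pure-Python sketch, not the file's implementation: the helper names `rotate_half_pairs`, `apply_rotary` and `subgrid_positions` are hypothetical, and plain lists stand in for tensors. It assumes the same interleaved pairing as the diff (each adjacent pair shares one angle) and a row-major frequency cache built for the maximum grid.

```python
import math


def rotate_half_pairs(x):
    """List analogue of rotate_half: (x0, x1, x2, x3) -> (-x1, x0, -x3, x2)."""
    out = []
    for i in range(0, len(x), 2):
        out.extend([-x[i + 1], x[i]])
    return out


def apply_rotary(x, angles):
    """Rotate each adjacent pair of x; angles holds one value per element,
    repeated within a pair (mirrors repeat_interleave(2) in the diff)."""
    rot = rotate_half_pairs(x)
    return [v * math.cos(a) + r * math.sin(a) for v, r, a in zip(x, rot, angles)]


def subgrid_positions(grid_h, grid_w, max_grid_width, use_cls_token=False):
    """Flat indices of a (grid_h, grid_w) sub-grid inside a row-major cache
    laid out for max_grid_width columns; slot 0 is reserved for the cls token."""
    pos = [r * max_grid_width + c for r in range(grid_h) for c in range(grid_w)]
    if use_cls_token:
        pos = [0] + [p + 1 for p in pos]
    return pos


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0, 4.0]
    y = apply_rotary(x, [0.3, 0.3, 1.1, 1.1])
    # A rotation should leave the squared norm unchanged (up to rounding).
    print(sum(v * v for v in x), sum(v * v for v in y))
    print(subgrid_positions(2, 3, max_grid_width=4, use_cls_token=True))
```

The indices stay integers throughout, which matters in the tensor version because `index_select` only accepts an integer index tensor.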