luotingdan committed
Commit 75b1727 · 1 Parent(s): 4d7b6d7

add files
README.md CHANGED
@@ -1,3 +1,326 @@
- ---
- license: apache-2.0
- ---
---
license: apache-2.0
base_model:
- stepfun-ai/Step3-VL-10B-Base
tags:
- quantization
- fp8
pipeline_tag: image-text-to-text
---

<div align="center">

<div align="center" style="display: flex; justify-content: center; align-items: center;">
  <img src="figures/stepfun.svg" width="25" style="margin-right: 10px;"/>
  <h1 style="margin: 0; border-bottom: none;">STEP3-VL-10B</h1>
</div>

[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20HF-StepFun/STEP3--VL--10B-blue)](https://huggingface.co/collections/stepfun-ai/step3-vl-10b)
[![ModelScope](https://img.shields.io/badge/ModelScope-StepFun/STEP3--VL--10B-624aff)](https://modelscope.cn/collections/stepfun-ai/Step3-VL-10B)
[![Paper](https://img.shields.io/badge/Paper-Arxiv-red)](https://arxiv.org/abs/2601.09668)
[![License](https://img.shields.io/badge/License-Apache%202.0-green)]()

</div>

## 📢 News & Updates

- 🚀 **Online Demo**: Explore Step3-VL-10B on [Hugging Face Spaces](https://huggingface.co/spaces/stepfun-ai/Step3-VL-10B)!
- 📢 **[Notice] vLLM Support:** vLLM integration is now officially supported! (PR [#32329](https://github.com/vllm-project/vllm/pull/32329))
- ✅ **[Fixed] HF Inference:** Resolved the `eos_token_id` misconfiguration in `config.json` that caused infinite generation loops. (Commit [abdf3](https://huggingface.co/stepfun-ai/Step3-VL-10B/commit/abdf3618e914a9e3de0ad74efacc8b7a10f06c10))
- ✅ **[Fixing] Metric Correction:** We sincerely apologize for inaccuracies in the Qwen3VL-8B benchmarks (e.g., AIME, HMMT, LCB). The errors were caused by an incorrect `max_tokens` setting (mistakenly set to 32k) during our large-scale evaluation process. We are re-running the tests and will provide corrected numbers in the next version of the technical report.

## 🚀 Introduction

**STEP3-VL-10B** is a lightweight open-source foundation model designed to redefine the trade-off between compact efficiency and frontier-level multimodal intelligence. Despite its compact **10B parameter footprint**, STEP3-VL-10B excels in **visual perception**, **complex reasoning**, and **human-centric alignment**. It consistently outperforms models under the 10B scale and rivals or surpasses significantly larger open-weight models (**10×–20× its size**), such as GLM-4.6V (106B-A12B) and Qwen3-VL-Thinking (235B-A22B), as well as top-tier proprietary flagships like Gemini 2.5 Pro and Seed-1.5-VL.

<div align="center">
  <img src="figures/performance.png" alt="Performance Comparison" width="800"/>
  <p><i>Figure 1: Performance comparison of STEP3-VL-10B against SOTA multimodal foundation models. SeRe: Sequential Reasoning; PaCoRe: Parallel Coordinated Reasoning.</i></p>
</div>

The success of STEP3-VL-10B is driven by two key strategic designs:

1. **Unified Pre-training on a High-Quality Multimodal Corpus:** A single-stage, fully unfrozen training strategy on a 1.2T-token multimodal corpus, focusing on two foundational capabilities: **reasoning** (e.g., general knowledge and education-centric tasks) and **perception** (e.g., grounding, counting, OCR, and GUI interactions). By jointly optimizing the Perception Encoder and the Qwen3-8B decoder, STEP3-VL-10B establishes intrinsic vision-language synergy.
2. **Scaled Multimodal Reinforcement Learning and Parallel Reasoning:** Frontier capabilities are unlocked through a rigorous post-training pipeline comprising two-stage supervised finetuning (SFT) and **over 1,400 iterations of RL** with both verifiable rewards (RLVR) and human feedback (RLHF). Beyond sequential reasoning, we adopt **Parallel Coordinated Reasoning (PaCoRe)**, which allocates test-time compute to aggregate evidence from parallel visual exploration.

## 📥 Model Zoo

| Model Name | Type | Hugging Face | ModelScope |
| :-------------------- | :--- | :----------------------------------------------------------------: | :----------------------------------------------------------------------: |
| **STEP3-VL-10B-Base** | Base | [🤗 Download](https://huggingface.co/stepfun-ai/Step3-VL-10B-Base) | [🤖 Download](https://modelscope.cn/models/stepfun-ai/Step3-VL-10B-Base) |
| **STEP3-VL-10B** | Chat | [🤗 Download](https://huggingface.co/stepfun-ai/Step3-VL-10B) | [🤖 Download](https://modelscope.cn/models/stepfun-ai/Step3-VL-10B) |

## 📊 Performance

STEP3-VL-10B delivers best-in-class performance across major multimodal benchmarks, establishing a new performance standard for compact models. The results demonstrate that STEP3-VL-10B is the **most powerful open-source model in the 10B parameter class**.

### Comparison with Larger Models (10×–20× Larger)

| Benchmark | STEP3-VL-10B (SeRe) | STEP3-VL-10B (PaCoRe) | GLM-4.6V (106B-A12B) | Qwen3-VL (235B-A22B) | Gemini-2.5-Pro | Seed-1.5-VL |
| :---------------- | :-----------------: | :-------------------: | :------------------: | :------------------: | :------------: | :---------: |
| **MMMU** | 78.11 | 80.11 | 75.20 | 78.70 | **83.89** | 79.11 |
| **MathVista** | 83.97 | 85.50 | 83.51 | 85.10 | 83.88 | **85.60** |
| **MathVision** | 70.81 | **75.95** | 63.50 | 72.10 | 73.30 | 68.70 |
| **MMBench (EN)** | 92.05 | 92.38 | 92.75 | 92.70 | **93.19** | 92.11 |
| **MMStar** | 77.48 | 77.64 | 75.30 | 76.80 | **79.18** | 77.91 |
| **OCRBench** | 86.75 | **89.00** | 86.20 | 87.30 | 85.90 | 85.20 |
| **AIME 2025** | 87.66 | **94.43** | 71.88 | 83.59 | 83.96 | 64.06 |
| **HMMT 2025** | 78.18 | **92.14** | 57.29 | 67.71 | 65.68 | 51.30 |
| **LiveCodeBench** | 75.77 | **76.43** | 48.71 | 69.45 | 72.01 | 57.10 |

> **Note on Inference Modes:**
>
> **SeRe (Sequential Reasoning):** The standard inference mode using sequential generation (Chain-of-Thought) with a max length of 64K tokens.
>
> **PaCoRe (Parallel Coordinated Reasoning):** An advanced mode that scales test-time compute. It aggregates evidence from **16 parallel rollouts** to synthesize a final answer, utilizing a max context length of 128K tokens.
>
> _Unless otherwise stated, scores below refer to the standard SeRe mode. Higher scores achieved via PaCoRe are explicitly marked._
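
To make the rollout-and-synthesize pattern concrete, the sketch below shows one way to approximate PaCoRe-style inference against an OpenAI-compatible endpoint (such as the vLLM server described later in this README). This is an illustrative approximation, not the official PaCoRe pipeline: the endpoint URL, the reduced rollout count, and the synthesis prompt are all assumptions.

```python
# Illustrative PaCoRe-style sketch: sample independent rollouts, then synthesize.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")  # assumed local server
MODEL = "stepfun-ai/Step3-VL-10B"
question = "What is 17 * 24? Explain briefly."

# 1) Launch N parallel rollouts (the report uses 16; 4 here to keep the example cheap).
rollouts = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": question}],
    n=4,
    temperature=1.0,
)
candidates = [choice.message.content for choice in rollouts.choices]

# 2) Ask the model to aggregate evidence from the candidates (assumed prompt format).
synthesis_prompt = (
    question
    + "\n\nCandidate solutions:\n"
    + "\n---\n".join(candidates)
    + "\n\nAggregate the evidence above and give one final answer."
)
final = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": synthesis_prompt}],
)
print(final.choices[0].message.content)
```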

### Comparison with Open-Source Models (7B–10B)

| Category | Benchmark | STEP3-VL-10B | GLM-4.6V-Flash (9B) | Qwen3-VL-Thinking (8B) | InternVL-3.5 (8B) | MiMo-VL-RL-2508 (7B) |
| :----------------- | :--------------- | :----------: | :-----------------: | :--------------------: | :---------------: | :------------------: |
| **STEM Reasoning** | MMMU | **78.11** | 71.17 | 73.53 | 71.69 | 71.14 |
| | MathVision | **70.81** | 54.05 | 59.60 | 52.05 | 59.65 |
| | MathVista | **83.97** | 82.85 | 78.50 | 76.78 | 79.86 |
| | PhyX | **59.45** | 52.28 | 57.67 | 50.51 | 56.00 |
| **Recognition** | MMBench (EN) | **92.05** | 91.04 | 90.55 | 88.20 | 89.91 |
| | MMStar | **77.48** | 74.26 | 73.58 | 69.83 | 72.93 |
| | ReMI | **67.29** | 60.75 | 57.17 | 52.65 | 63.13 |
| **OCR & Document** | OCRBench | **86.75** | 85.97 | 82.85 | 83.70 | 85.40 |
| | AI2D | **89.35** | 88.93 | 83.32 | 82.34 | 84.96 |
| **GUI Grounding** | ScreenSpot-V2 | 92.61 | 92.14 | **93.60** | 84.02 | 90.82 |
| | ScreenSpot-Pro | **51.55** | 45.68 | 46.60 | 15.39 | 34.84 |
| | OSWorld-G | **59.02** | 54.71 | 56.70 | 31.91 | 50.54 |
| **Spatial** | BLINK | **66.79** | 64.90 | 62.78 | 55.40 | 62.57 |
| | All-Angles-Bench | **57.21** | 53.24 | 45.88 | 45.29 | 51.62 |
| **Code** | HumanEval-V | **66.05** | 29.26 | 26.94 | 24.31 | 31.96 |

### Key Capabilities

- **STEM Reasoning:** Achieves **94.43%** on AIME 2025 and **75.95%** on MathVision (with PaCoRe), demonstrating exceptional complex reasoning capabilities that outperform models 10×–20× larger.
- **Visual Perception:** Records **92.05%** on MMBench and **80.11%** on MMMU (with PaCoRe), establishing strong general visual understanding and multimodal reasoning.
- **GUI & OCR:** Delivers state-of-the-art performance on ScreenSpot-V2 (**92.61%**), ScreenSpot-Pro (**51.55%**), and OCRBench (**86.75%**), optimized for agentic and document understanding tasks.
- **Spatial Understanding:** Demonstrates emergent spatial awareness with **66.79%** on BLINK and **57.21%** on All-Angles-Bench, establishing strong potential for embodied intelligence applications.

## 🏗️ Architecture & Training

### Architecture

- **Visual Encoder:** PE-lang (Language-Optimized Perception Encoder), 1.8B parameters.
- **Decoder:** Qwen3-8B.
- **Projector:** Two consecutive stride-2 layers (resulting in 16× spatial downsampling).
- **Resolution:** Multi-crop strategy consisting of a 728×728 global view and multiple 504×504 local crops; the resulting per-view token counts are worked out in the sketch below.
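
These numbers pin down the visual token counts directly: 728/14 = 52 patches per side, and the two stride-2 projector layers reduce that by 4× to 13 per side, i.e. 13² = 169 tokens for the global view; each 504×504 crop gives (504/14/4)² = 81 tokens. The snippet below just restates this arithmetic; the results match the `image_token_len` and `patch_token_len` fields in the released `config.json`.

```python
# Per-view visual token counts implied by the architecture above.
def visual_tokens(image_size: int, patch_size: int = 14, downsample: int = 4) -> int:
    """Patches per side, then 4x per-side reduction from two stride-2 projector layers."""
    side = image_size // patch_size // downsample
    return side * side

assert visual_tokens(728) == 169  # global view  -> "image_token_len" in config.json
assert visual_tokens(504) == 81   # local crop   -> "patch_token_len" in config.json
```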

### Training Pipeline

- **Pre-training:** Single-stage, fully unfrozen strategy using the AdamW optimizer (total: 1.2T tokens, 370K iterations).
  - Phase 1: 900B tokens.
  - Phase 2: 300B tokens.
- **Supervised Finetuning (SFT):** Two-stage approach (total: ~226B tokens).
  - Stage 1: 9:1 text-to-multimodal ratio (~190B tokens).
  - Stage 2: 1:1 text-to-multimodal ratio (~36B tokens).
- **Reinforcement Learning:** >1,400 iterations in total.
  - **RLVR:** 600 iterations (tasks: mathematics, geometry, physics, perception, grounding).
  - **RLHF:** 300 iterations (task: open-ended generation).
  - **PaCoRe Training:** 500 iterations (context length: 64K max sequence).

## 🛠️ Quick Start

**Deployment Resource Specifications**

- Model Weights: 14 GB
- Runtime Overhead: ~4 GB
- Minimum VRAM Required: 24 GB (e.g., RTX 4090 or A100); a back-of-envelope check follows below.
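
As a rough sanity check on the weight figure (assuming the FP8 quantization advertised in this repo's tags, with the vision encoder and `lm_head` kept in bf16 per `quantization_config.modules_to_not_convert`), the estimate below lands within a couple of GB of the stated 14 GB. The parameter counts are approximations, not official numbers.

```python
# Back-of-envelope weight-memory estimate; parameter counts are approximate.
GB = 1024**3
decoder_params = 8.2e9          # Qwen3-8B decoder (approx.), assumed FP8 (1 byte/param)
vision_params = 1.8e9           # PE-lang encoder (approx.), kept in bf16 (2 bytes/param)
lm_head_params = 151936 * 4096  # vocab_size * hidden_size, kept in bf16

weights_bytes = decoder_params * 1 + (vision_params + lm_head_params) * 2
print(f"estimated weights: ~{weights_bytes / GB:.0f} GB (README states 14 GB)")
```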

### Inference with Hugging Face Transformers

This section shows how to run the model at inference time with the `transformers` library. We recommend Python 3.10, `torch>=2.1.0`, and `transformers` 4.57.0 as the development environment. We currently only support bf16 inference, and multi-patch image preprocessing is enabled by default; this behavior is aligned with vLLM.

**Note:** If you experience infinite generation issues, please check [Discussion #9](https://huggingface.co/stepfun-ai/Step3-VL-10B/discussions/9) for the fix.

```python
from transformers import AutoProcessor, AutoModelForCausalLM

# Remap checkpoint weight prefixes onto the remote-code module layout.
key_mapping = {
    "^vision_model": "model.vision_model",
    r"^model(?!\.(language_model|vision_model))": "model.language_model",
    "vit_large_projector": "model.vit_large_projector",
}

model_path = "stepfun-ai/Step3-VL-10B"

processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "What's in this picture?"}
        ]
    },
]

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype="auto",
    key_mapping=key_mapping).eval()

# Tokenize the chat (text + image) and move the tensors to the model's device.
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device)

generate_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
# Decode only the newly generated tokens, skipping the prompt.
decoded = processor.decode(generate_ids[0, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

print(decoded)
```

## 🚀 Deployment with vLLM (OpenAI-compatible API)

For deployment, you can use vLLM to create an OpenAI-compatible API endpoint.

1. Install vLLM nightly (choose one):

   - **Python / pip**

     ```bash
     pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
     ```

     Python ≥3.10 is required. Please ensure the vLLM version is >= 0.14.0rc2.dev143+gc0a350ca7.

   - **Docker (nightly image)**

     ```bash
     docker pull vllm/vllm-openai:nightly-963dc0b865a3b6011fde7e0d938f86245dccbfac
     ```

     The tag above pins the nightly build we validated; update to the latest nightly tag if needed.

2. Launch the server:

   ```bash
   vllm serve stepfun-ai/Step3-VL-10B -tp 1 --reasoning-parser deepseek_r1 --enable-auto-tool-choice --tool-call-parser hermes --trust-remote-code
   ```

   **Crucial step:** you must append the `--trust-remote-code` flag to your deployment command. This is mandatory for models whose architecture is implemented in custom code.

3. Call the endpoint using any OpenAI-compatible SDK (example in Python):

   ```python
   from openai import OpenAI

   client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

   resp = client.chat.completions.create(
       model="stepfun-ai/Step3-VL-10B",
       messages=[{
           "role": "user",
           "content": [
               {
                   "type": "image_url",
                   "image_url": {
                       "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
                   },
               },
               {"type": "text", "text": "What's in this picture?"},
           ],
       }],
   )

   print(resp.choices[0].message.content)
   ```
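
   Because the server was launched with `--reasoning-parser deepseek_r1`, the model's `<think>` trace is parsed out of the final answer; depending on your vLLM version, it is typically exposed as a separate `reasoning_content` field on the message (an assumption to verify against your installed version):

   ```python
   # Inspect the parsed reasoning trace, if the server attached one to the message.
   print(getattr(resp.choices[0].message, "reasoning_content", None))
   ```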

## 🚀 Deployment with SGLang (OpenAI-compatible API)

1. Install the latest SGLang main branch (choose one):

   - **Python / pip**

     ```bash
     pip install "sglang @ git+https://github.com/sgl-project/sglang.git#subdirectory=python"
     pip install nvidia-cudnn-cu12==9.16.0.29
     ```

   - **Docker**

     ```bash
     docker run --gpus all \
       --shm-size 32g \
       -p 30000:30000 \
       -v ~/.cache/huggingface:/root/.cache/huggingface \
       --ipc=host \
       lmsysorg/sglang:latest \
       python3 -m sglang.launch_server --model-path stepfun-ai/Step3-VL-10B-FP8 --host 0.0.0.0 --port 30000
     ```

2. Launch the server (if you installed via pip; use port 30000 to match the client example below):

   ```bash
   sglang serve --model-path stepfun-ai/Step3-VL-10B-FP8 --trust-remote-code --port 30000 --reasoning-parser deepseek-r1 --tool-call-parser hermes
   ```

3. Call the endpoint using any OpenAI-compatible SDK (example in Python):

   ```python
   from openai import OpenAI

   port = 30000

   client = OpenAI(base_url=f"http://localhost:{port}/v1", api_key="None")

   response = client.chat.completions.create(
       model="stepfun-ai/Step3-VL-10B-FP8",
       messages=[
           {
               "role": "user",
               "content": [
                   {
                       "type": "text",
                       "text": "What is in this image?",
                   },
                   {
                       "type": "image_url",
                       "image_url": {
                           "url": "https://github.com/sgl-project/sglang/blob/main/examples/assets/example_image.png?raw=true"
                       },
                   },
               ],
           }
       ],
   )

   print(response.choices[0].message.content)
   ```

## 📜 Citation

If you find this project useful in your research, please cite our technical report:

```tex
@misc{huang2026step3vl10btechnicalreport,
      title={STEP3-VL-10B Technical Report},
      author={Ailin Huang and Chengyuan Yao and Chunrui Han and Fanqi Wan and Hangyu Guo and Haoran Lv and Hongyu Zhou and Jia Wang and Jian Zhou and Jianjian Sun and Jingcheng Hu and Kangheng Lin and Liang Zhao and Mitt Huang and Song Yuan and Wenwen Qu and Xiangfeng Wang and Yanlin Lai and Yingxiu Zhao and Yinmin Zhang and Yukang Shi and Yuyang Chen and Zejia Weng and Ziyang Meng and Ang Li and Aobo Kong and Bo Dong and Changyi Wan and David Wang and Di Qi and Dingming Li and En Yu and Guopeng Li and Haiquan Yin and Han Zhou and Hanshan Zhang and Haolong Yan and Hebin Zhou and Hongbo Peng and Jiaran Zhang and Jiashu Lv and Jiayi Fu and Jie Cheng and Jie Zhou and Jisheng Yin and Jingjing Xie and Jingwei Wu and Jun Zhang and Junfeng Liu and Kaijun Tan and Kaiwen Yan and Liangyu Chen and Lina Chen and Mingliang Li and Qian Zhao and Quan Sun and Shaoliang Pang and Shengjie Fan and Shijie Shang and Siyuan Zhang and Tianhao You and Wei Ji and Wuxun Xie and Xiaobo Yang and Xiaojie Hou and Xiaoran Jiao and Xiaoxiao Ren and Xiangwen Kong and Xin Huang and Xin Wu and Xing Chen and Xinran Wang and Xuelin Zhang and Yana Wei and Yang Li and Yanming Xu and Yeqing Shen and Yuang Peng and Yue Peng and Yu Zhou and Yusheng Li and Yuxiang Yang and Yuyang Zhang and Zhe Xie and Zhewei Huang and Zhenyi Lu and Zhimin Fan and Zihui Cheng and Daxin Jiang and Qi Han and Xiangyu Zhang and Yibo Zhu and Zheng Ge},
      year={2026},
      eprint={2601.09668},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2601.09668},
}
```

## 📄 License

This project is open-sourced under the [Apache 2.0 License](LICENSE).

added_tokens.json ADDED
@@ -0,0 +1,51 @@
{
  "</think>": 151668,
  "</tool_call>": 151658,
  "</tool_calls>": 151670,
  "</tool_response>": 151666,
  "<dream>": 151682,
  "<dream_end>": 151684,
  "<dream_start>": 151683,
  "<im_end>": 151681,
  "<im_patch>": 151679,
  "<im_start>": 151680,
  "<patch_end>": 151690,
  "<patch_newline>": 151691,
  "<patch_start>": 151689,
  "<think>": 151667,
  "<tool_call>": 151657,
  "<tool_calls>": 151669,
  "<tool_response>": 151665,
  "<video_end>": 151688,
  "<video_start>": 151687,
  "<|BOT|>": 151672,
  "<|CALL_END|>": 151674,
  "<|CALL_START|>": 151673,
  "<|EOT|>": 151671,
  "<|IMG_END|>": 151678,
  "<|IMG_START|>": 151677,
  "<|MASK_1e69f|>": 151685,
  "<|THINK_END|>": 151676,
  "<|THINK_START|>": 151675,
  "<|UNMASK_1e69f|>": 151686,
  "<|box_end|>": 151649,
  "<|box_start|>": 151648,
  "<|endoftext|>": 151643,
  "<|file_sep|>": 151664,
  "<|fim_middle|>": 151660,
  "<|fim_pad|>": 151662,
  "<|fim_prefix|>": 151659,
  "<|fim_suffix|>": 151661,
  "<|im_end|>": 151645,
  "<|im_start|>": 151644,
  "<|image_pad|>": 151655,
  "<|object_ref_end|>": 151647,
  "<|object_ref_start|>": 151646,
  "<|quad_end|>": 151651,
  "<|quad_start|>": 151650,
  "<|repo_name|>": 151663,
  "<|video_pad|>": 151656,
  "<|vision_end|>": 151653,
  "<|vision_pad|>": 151654,
  "<|vision_start|>": 151652
}
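
These IDs extend the Qwen3 tokenizer with the model's special tokens. A quick way to sanity-check them (a minimal sketch, assuming the tokenizer loads from this repo):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("stepfun-ai/Step3-VL-10B", trust_remote_code=True)
assert tok.convert_tokens_to_ids("<im_patch>") == 151679  # matches "image_token_id" in config.json
assert tok.convert_tokens_to_ids("<think>") == 151667
```
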
chat_template.jinja ADDED
@@ -0,0 +1,81 @@
{% macro render_content(content) %}{% if content is none %}{{- '' }}{% elif content is string %}{{- content }}{% elif content is mapping %}{{- content['value'] if 'value' in content else content['text'] }}{% elif content is iterable %}{% for item in content %}{% if item.type == 'text' %}{{- item['value'] if 'value' in item else item['text'] }}{% elif item.type == 'image' %}<im_patch>{% endif %}{% endfor %}{% endif %}{% endmacro %}
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- render_content(messages[0].content) + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nAlways adhere to this exact format for tool use:\n<tool_calls>\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>\n{additional_tool_calls}</tool_calls>\n\nNote:\n- For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags.\n- `<function-name>` must be an exact match to one of the available tools.\n- `<args-json-object>` must be valid JSON that strictly follows the tool's parameters schema.<|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + render_content(messages[0].content) + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and render_content(message.content) is string and not(render_content(message.content).startswith('<tool_response>') and render_content(message.content).endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- set content = render_content(message.content) %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
{{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = render_content(message.reasoning_content) %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- else %}
{%- set reasoning_content = '' %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{{- '\n<tool_calls>' }}
{%- for tool_call in message.tool_calls %}
{{- '\n' }}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{{- '\n</tool_calls>' }}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>tool_response' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n<think>\n' }}
{%- endif %}
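
To see what this template produces, you can render it through the chat-template API (a minimal sketch, assuming the repo's processor carries this template):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("stepfun-ai/Step3-VL-10B", trust_remote_code=True)
text = processor.apply_chat_template(
    [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}],
    tokenize=False, add_generation_prompt=True,
)
# Per the template above, expect:
# "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n<think>\n"
print(text)
```
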
config.json ADDED
@@ -0,0 +1,805 @@
{
  "architectures": [
    "StepVLForConditionalGeneration"
  ],
  "auto_map": {
    "AutoConfig": "configuration_step_vl.StepRoboticsConfig",
    "AutoModelForCausalLM": "modeling_step_vl.Step3VL10BForCausalLM"
  },
  "model_type": "step_robotics",
  "im_end_token": "<im_end>",
  "im_patch_token": "<im_patch>",
  "im_start_token": "<im_start>",
  "image_token_len": 169,
  "patch_token_len": 81,
  "image_token_id": 151679,
  "understand_projector_stride": 2,
  "use_im_start_end": "true",
  "vision_select_layer": -1,
  "projector_bias": false,
  "vision_config": {
    "image_size": 728,
    "patch_size": 14,
    "width": 1536,
    "layers": 47,
    "heads": 16,
    "pool_type": "none",
    "output_dim": null,
    "use_cls_token": false,
    "ls_init_value": 0.1,
    "use_ln_post": false,
    "hidden_act": "quick_gelu"
  },
  "text_config": {
    "architectures": [
      "Qwen3ForCausalLM"
    ],
    "attention_bias": false,
    "attention_dropout": 0.0,
    "bos_token_id": 151643,
    "eos_token_id": [
      151643,
      151679
    ],
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 12288,
    "max_position_embeddings": 65536,
    "max_window_layers": 36,
    "model_type": "qwen3",
    "num_attention_heads": 32,
    "num_hidden_layers": 36,
    "num_key_value_heads": 8,
    "rms_norm_eps": 1e-06,
    "rope_scaling": null,
    "rope_theta": 1000000,
    "sliding_window": null,
    "tie_word_embeddings": false,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.51.0",
    "use_cache": true,
    "use_sliding_window": false,
    "vocab_size": 151936
  },
  "quantization_config": {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "fmt": "e4m3",
    "weight_block_size": [
      128,
      128
    ],
    "modules_to_not_convert": [
      "lm_head",
      "model.visual"
    ],
    "ignored_layers": [
      "lm_head",
      "model.layers.0.mlp.down_proj",
      "model.layers.0.mlp.gate_proj",
      "model.layers.0.mlp.up_proj",
      "model.layers.0.self_attn.k_proj",
      "model.layers.0.self_attn.o_proj",
      "model.layers.0.self_attn.q_proj",
      "model.layers.0.self_attn.v_proj",
      "model.layers.1.self_attn.k_proj",
      "model.layers.1.self_attn.o_proj",
      "model.layers.1.self_attn.q_proj",
      "model.layers.1.self_attn.v_proj",
      "model.layers.10.self_attn.k_proj",
      "model.layers.10.self_attn.o_proj",
      "model.layers.10.self_attn.q_proj",
      "model.layers.10.self_attn.v_proj",
      "model.layers.11.self_attn.k_proj",
      "model.layers.11.self_attn.o_proj",
      "model.layers.11.self_attn.q_proj",
      "model.layers.11.self_attn.v_proj",
      "model.layers.12.self_attn.k_proj",
      "model.layers.12.self_attn.o_proj",
      "model.layers.12.self_attn.q_proj",
      "model.layers.12.self_attn.v_proj",
      "model.layers.13.self_attn.k_proj",
      "model.layers.13.self_attn.o_proj",
      "model.layers.13.self_attn.q_proj",
      "model.layers.13.self_attn.v_proj",
      "model.layers.14.self_attn.k_proj",
      "model.layers.14.self_attn.o_proj",
      "model.layers.14.self_attn.q_proj",
      "model.layers.14.self_attn.v_proj",
      "model.layers.15.self_attn.k_proj",
      "model.layers.15.self_attn.o_proj",
      "model.layers.15.self_attn.q_proj",
      "model.layers.15.self_attn.v_proj",
      "model.layers.16.self_attn.k_proj",
      "model.layers.16.self_attn.o_proj",
      "model.layers.16.self_attn.q_proj",
      "model.layers.16.self_attn.v_proj",
      "model.layers.17.self_attn.k_proj",
      "model.layers.17.self_attn.o_proj",
      "model.layers.17.self_attn.q_proj",
      "model.layers.17.self_attn.v_proj",
      "model.layers.18.self_attn.k_proj",
      "model.layers.18.self_attn.o_proj",
      "model.layers.18.self_attn.q_proj",
      "model.layers.18.self_attn.v_proj",
      "model.layers.19.self_attn.k_proj",
      "model.layers.19.self_attn.o_proj",
      "model.layers.19.self_attn.q_proj",
      "model.layers.19.self_attn.v_proj",
      "model.layers.2.self_attn.k_proj",
      "model.layers.2.self_attn.o_proj",
      "model.layers.2.self_attn.q_proj",
      "model.layers.2.self_attn.v_proj",
      "model.layers.20.self_attn.k_proj",
      "model.layers.20.self_attn.o_proj",
      "model.layers.20.self_attn.q_proj",
      "model.layers.20.self_attn.v_proj",
      "model.layers.21.self_attn.k_proj",
      "model.layers.21.self_attn.o_proj",
      "model.layers.21.self_attn.q_proj",
      "model.layers.21.self_attn.v_proj",
      "model.layers.22.self_attn.k_proj",
      "model.layers.22.self_attn.o_proj",
      "model.layers.22.self_attn.q_proj",
      "model.layers.22.self_attn.v_proj",
      "model.layers.23.self_attn.k_proj",
      "model.layers.23.self_attn.o_proj",
      "model.layers.23.self_attn.q_proj",
      "model.layers.23.self_attn.v_proj",
      "model.layers.24.self_attn.k_proj",
      "model.layers.24.self_attn.o_proj",
      "model.layers.24.self_attn.q_proj",
      "model.layers.24.self_attn.v_proj",
      "model.layers.25.self_attn.k_proj",
      "model.layers.25.self_attn.o_proj",
      "model.layers.25.self_attn.q_proj",
      "model.layers.25.self_attn.v_proj",
      "model.layers.26.self_attn.k_proj",
      "model.layers.26.self_attn.o_proj",
      "model.layers.26.self_attn.q_proj",
      "model.layers.26.self_attn.v_proj",
      "model.layers.27.self_attn.k_proj",
      "model.layers.27.self_attn.o_proj",
      "model.layers.27.self_attn.q_proj",
      "model.layers.27.self_attn.v_proj",
      "model.layers.28.self_attn.k_proj",
      "model.layers.28.self_attn.o_proj",
      "model.layers.28.self_attn.q_proj",
      "model.layers.28.self_attn.v_proj",
      "model.layers.29.self_attn.k_proj",
      "model.layers.29.self_attn.o_proj",
      "model.layers.29.self_attn.q_proj",
      "model.layers.29.self_attn.v_proj",
      "model.layers.3.self_attn.k_proj",
      "model.layers.3.self_attn.o_proj",
      "model.layers.3.self_attn.q_proj",
      "model.layers.3.self_attn.v_proj",
      "model.layers.30.self_attn.k_proj",
      "model.layers.30.self_attn.o_proj",
      "model.layers.30.self_attn.q_proj",
      "model.layers.30.self_attn.v_proj",
      "model.layers.31.self_attn.k_proj",
      "model.layers.31.self_attn.o_proj",
      "model.layers.31.self_attn.q_proj",
      "model.layers.31.self_attn.v_proj",
      "model.layers.32.self_attn.k_proj",
      "model.layers.32.self_attn.o_proj",
      "model.layers.32.self_attn.q_proj",
      "model.layers.32.self_attn.v_proj",
      "model.layers.33.self_attn.k_proj",
      "model.layers.33.self_attn.o_proj",
      "model.layers.33.self_attn.q_proj",
      "model.layers.33.self_attn.v_proj",
      "model.layers.34.self_attn.k_proj",
      "model.layers.34.self_attn.o_proj",
      "model.layers.34.self_attn.q_proj",
      "model.layers.34.self_attn.v_proj",
      "model.layers.35.self_attn.k_proj",
      "model.layers.35.self_attn.o_proj",
      "model.layers.35.self_attn.q_proj",
      "model.layers.35.self_attn.v_proj",
      "model.layers.4.self_attn.k_proj",
      "model.layers.4.self_attn.o_proj",
      "model.layers.4.self_attn.q_proj",
      "model.layers.4.self_attn.v_proj",
      "model.layers.5.self_attn.k_proj",
      "model.layers.5.self_attn.o_proj",
      "model.layers.5.self_attn.q_proj",
      "model.layers.5.self_attn.v_proj",
      "model.layers.6.self_attn.k_proj",
      "model.layers.6.self_attn.o_proj",
      "model.layers.6.self_attn.q_proj",
      "model.layers.6.self_attn.v_proj",
      "model.layers.7.self_attn.k_proj",
      "model.layers.7.self_attn.o_proj",
      "model.layers.7.self_attn.q_proj",
      "model.layers.7.self_attn.v_proj",
      "model.layers.8.self_attn.k_proj",
      "model.layers.8.self_attn.o_proj",
      "model.layers.8.self_attn.q_proj",
      "model.layers.8.self_attn.v_proj",
      "model.layers.9.self_attn.k_proj",
      "model.layers.9.self_attn.o_proj",
      "model.layers.9.self_attn.q_proj",
      "model.layers.9.self_attn.v_proj",
      "model.vision_model",
      "model.vision_model.conv1",
      "model.vision_model.ln_pre",
      "model.vision_model.transformer.resblocks.0.attn.out_proj",
      "model.vision_model.transformer.resblocks.0.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.0.ln_1",
      "model.vision_model.transformer.resblocks.0.ln_2",
      "model.vision_model.transformer.resblocks.0.mlp.c_fc",
      "model.vision_model.transformer.resblocks.0.mlp.c_proj",
      "model.vision_model.transformer.resblocks.1.attn.out_proj",
      "model.vision_model.transformer.resblocks.1.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.1.ln_1",
      "model.vision_model.transformer.resblocks.1.ln_2",
      "model.vision_model.transformer.resblocks.1.mlp.c_fc",
      "model.vision_model.transformer.resblocks.1.mlp.c_proj",
      "model.vision_model.transformer.resblocks.10.attn.out_proj",
      "model.vision_model.transformer.resblocks.10.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.10.ln_1",
      "model.vision_model.transformer.resblocks.10.ln_2",
      "model.vision_model.transformer.resblocks.10.mlp.c_fc",
      "model.vision_model.transformer.resblocks.10.mlp.c_proj",
      "model.vision_model.transformer.resblocks.11.attn.out_proj",
      "model.vision_model.transformer.resblocks.11.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.11.ln_1",
      "model.vision_model.transformer.resblocks.11.ln_2",
      "model.vision_model.transformer.resblocks.11.mlp.c_fc",
      "model.vision_model.transformer.resblocks.11.mlp.c_proj",
      "model.vision_model.transformer.resblocks.12.attn.out_proj",
      "model.vision_model.transformer.resblocks.12.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.12.ln_1",
      "model.vision_model.transformer.resblocks.12.ln_2",
      "model.vision_model.transformer.resblocks.12.mlp.c_fc",
      "model.vision_model.transformer.resblocks.12.mlp.c_proj",
      "model.vision_model.transformer.resblocks.13.attn.out_proj",
      "model.vision_model.transformer.resblocks.13.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.13.ln_1",
      "model.vision_model.transformer.resblocks.13.ln_2",
      "model.vision_model.transformer.resblocks.13.mlp.c_fc",
      "model.vision_model.transformer.resblocks.13.mlp.c_proj",
      "model.vision_model.transformer.resblocks.14.attn.out_proj",
      "model.vision_model.transformer.resblocks.14.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.14.ln_1",
      "model.vision_model.transformer.resblocks.14.ln_2",
      "model.vision_model.transformer.resblocks.14.mlp.c_fc",
      "model.vision_model.transformer.resblocks.14.mlp.c_proj",
      "model.vision_model.transformer.resblocks.15.attn.out_proj",
      "model.vision_model.transformer.resblocks.15.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.15.ln_1",
      "model.vision_model.transformer.resblocks.15.ln_2",
      "model.vision_model.transformer.resblocks.15.mlp.c_fc",
      "model.vision_model.transformer.resblocks.15.mlp.c_proj",
      "model.vision_model.transformer.resblocks.16.attn.out_proj",
      "model.vision_model.transformer.resblocks.16.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.16.ln_1",
      "model.vision_model.transformer.resblocks.16.ln_2",
      "model.vision_model.transformer.resblocks.16.mlp.c_fc",
      "model.vision_model.transformer.resblocks.16.mlp.c_proj",
      "model.vision_model.transformer.resblocks.17.attn.out_proj",
      "model.vision_model.transformer.resblocks.17.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.17.ln_1",
      "model.vision_model.transformer.resblocks.17.ln_2",
      "model.vision_model.transformer.resblocks.17.mlp.c_fc",
      "model.vision_model.transformer.resblocks.17.mlp.c_proj",
      "model.vision_model.transformer.resblocks.18.attn.out_proj",
      "model.vision_model.transformer.resblocks.18.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.18.ln_1",
      "model.vision_model.transformer.resblocks.18.ln_2",
      "model.vision_model.transformer.resblocks.18.mlp.c_fc",
      "model.vision_model.transformer.resblocks.18.mlp.c_proj",
      "model.vision_model.transformer.resblocks.19.attn.out_proj",
      "model.vision_model.transformer.resblocks.19.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.19.ln_1",
      "model.vision_model.transformer.resblocks.19.ln_2",
      "model.vision_model.transformer.resblocks.19.mlp.c_fc",
      "model.vision_model.transformer.resblocks.19.mlp.c_proj",
      "model.vision_model.transformer.resblocks.2.attn.out_proj",
      "model.vision_model.transformer.resblocks.2.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.2.ln_1",
      "model.vision_model.transformer.resblocks.2.ln_2",
      "model.vision_model.transformer.resblocks.2.mlp.c_fc",
      "model.vision_model.transformer.resblocks.2.mlp.c_proj",
      "model.vision_model.transformer.resblocks.20.attn.out_proj",
      "model.vision_model.transformer.resblocks.20.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.20.ln_1",
      "model.vision_model.transformer.resblocks.20.ln_2",
      "model.vision_model.transformer.resblocks.20.mlp.c_fc",
      "model.vision_model.transformer.resblocks.20.mlp.c_proj",
      "model.vision_model.transformer.resblocks.21.attn.out_proj",
      "model.vision_model.transformer.resblocks.21.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.21.ln_1",
      "model.vision_model.transformer.resblocks.21.ln_2",
      "model.vision_model.transformer.resblocks.21.mlp.c_fc",
      "model.vision_model.transformer.resblocks.21.mlp.c_proj",
      "model.vision_model.transformer.resblocks.22.attn.out_proj",
      "model.vision_model.transformer.resblocks.22.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.22.ln_1",
      "model.vision_model.transformer.resblocks.22.ln_2",
      "model.vision_model.transformer.resblocks.22.mlp.c_fc",
      "model.vision_model.transformer.resblocks.22.mlp.c_proj",
      "model.vision_model.transformer.resblocks.23.attn.out_proj",
      "model.vision_model.transformer.resblocks.23.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.23.ln_1",
      "model.vision_model.transformer.resblocks.23.ln_2",
      "model.vision_model.transformer.resblocks.23.mlp.c_fc",
      "model.vision_model.transformer.resblocks.23.mlp.c_proj",
      "model.vision_model.transformer.resblocks.24.attn.out_proj",
      "model.vision_model.transformer.resblocks.24.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.24.ln_1",
      "model.vision_model.transformer.resblocks.24.ln_2",
      "model.vision_model.transformer.resblocks.24.mlp.c_fc",
      "model.vision_model.transformer.resblocks.24.mlp.c_proj",
      "model.vision_model.transformer.resblocks.25.attn.out_proj",
      "model.vision_model.transformer.resblocks.25.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.25.ln_1",
      "model.vision_model.transformer.resblocks.25.ln_2",
      "model.vision_model.transformer.resblocks.25.mlp.c_fc",
      "model.vision_model.transformer.resblocks.25.mlp.c_proj",
      "model.vision_model.transformer.resblocks.26.attn.out_proj",
      "model.vision_model.transformer.resblocks.26.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.26.ln_1",
      "model.vision_model.transformer.resblocks.26.ln_2",
      "model.vision_model.transformer.resblocks.26.mlp.c_fc",
      "model.vision_model.transformer.resblocks.26.mlp.c_proj",
      "model.vision_model.transformer.resblocks.27.attn.out_proj",
      "model.vision_model.transformer.resblocks.27.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.27.ln_1",
      "model.vision_model.transformer.resblocks.27.ln_2",
      "model.vision_model.transformer.resblocks.27.mlp.c_fc",
      "model.vision_model.transformer.resblocks.27.mlp.c_proj",
      "model.vision_model.transformer.resblocks.28.attn.out_proj",
      "model.vision_model.transformer.resblocks.28.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.28.ln_1",
      "model.vision_model.transformer.resblocks.28.ln_2",
      "model.vision_model.transformer.resblocks.28.mlp.c_fc",
      "model.vision_model.transformer.resblocks.28.mlp.c_proj",
      "model.vision_model.transformer.resblocks.29.attn.out_proj",
      "model.vision_model.transformer.resblocks.29.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.29.ln_1",
      "model.vision_model.transformer.resblocks.29.ln_2",
      "model.vision_model.transformer.resblocks.29.mlp.c_fc",
      "model.vision_model.transformer.resblocks.29.mlp.c_proj",
      "model.vision_model.transformer.resblocks.3.attn.out_proj",
      "model.vision_model.transformer.resblocks.3.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.3.ln_1",
      "model.vision_model.transformer.resblocks.3.ln_2",
      "model.vision_model.transformer.resblocks.3.mlp.c_fc",
      "model.vision_model.transformer.resblocks.3.mlp.c_proj",
      "model.vision_model.transformer.resblocks.30.attn.out_proj",
      "model.vision_model.transformer.resblocks.30.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.30.ln_1",
      "model.vision_model.transformer.resblocks.30.ln_2",
      "model.vision_model.transformer.resblocks.30.mlp.c_fc",
      "model.vision_model.transformer.resblocks.30.mlp.c_proj",
      "model.vision_model.transformer.resblocks.31.attn.out_proj",
      "model.vision_model.transformer.resblocks.31.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.31.ln_1",
      "model.vision_model.transformer.resblocks.31.ln_2",
      "model.vision_model.transformer.resblocks.31.mlp.c_fc",
      "model.vision_model.transformer.resblocks.31.mlp.c_proj",
      "model.vision_model.transformer.resblocks.32.attn.out_proj",
      "model.vision_model.transformer.resblocks.32.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.32.ln_1",
      "model.vision_model.transformer.resblocks.32.ln_2",
      "model.vision_model.transformer.resblocks.32.mlp.c_fc",
      "model.vision_model.transformer.resblocks.32.mlp.c_proj",
      "model.vision_model.transformer.resblocks.33.attn.out_proj",
      "model.vision_model.transformer.resblocks.33.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.33.ln_1",
      "model.vision_model.transformer.resblocks.33.ln_2",
      "model.vision_model.transformer.resblocks.33.mlp.c_fc",
      "model.vision_model.transformer.resblocks.33.mlp.c_proj",
      "model.vision_model.transformer.resblocks.34.attn.out_proj",
      "model.vision_model.transformer.resblocks.34.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.34.ln_1",
      "model.vision_model.transformer.resblocks.34.ln_2",
      "model.vision_model.transformer.resblocks.34.mlp.c_fc",
      "model.vision_model.transformer.resblocks.34.mlp.c_proj",
      "model.vision_model.transformer.resblocks.35.attn.out_proj",
      "model.vision_model.transformer.resblocks.35.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.35.ln_1",
      "model.vision_model.transformer.resblocks.35.ln_2",
      "model.vision_model.transformer.resblocks.35.mlp.c_fc",
      "model.vision_model.transformer.resblocks.35.mlp.c_proj",
      "model.vision_model.transformer.resblocks.36.attn.out_proj",
      "model.vision_model.transformer.resblocks.36.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.36.ln_1",
      "model.vision_model.transformer.resblocks.36.ln_2",
      "model.vision_model.transformer.resblocks.36.mlp.c_fc",
      "model.vision_model.transformer.resblocks.36.mlp.c_proj",
      "model.vision_model.transformer.resblocks.37.attn.out_proj",
      "model.vision_model.transformer.resblocks.37.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.37.ln_1",
      "model.vision_model.transformer.resblocks.37.ln_2",
      "model.vision_model.transformer.resblocks.37.mlp.c_fc",
      "model.vision_model.transformer.resblocks.37.mlp.c_proj",
      "model.vision_model.transformer.resblocks.38.attn.out_proj",
      "model.vision_model.transformer.resblocks.38.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.38.ln_1",
      "model.vision_model.transformer.resblocks.38.ln_2",
      "model.vision_model.transformer.resblocks.38.mlp.c_fc",
      "model.vision_model.transformer.resblocks.38.mlp.c_proj",
      "model.vision_model.transformer.resblocks.39.attn.out_proj",
      "model.vision_model.transformer.resblocks.39.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.39.ln_1",
      "model.vision_model.transformer.resblocks.39.ln_2",
      "model.vision_model.transformer.resblocks.39.mlp.c_fc",
      "model.vision_model.transformer.resblocks.39.mlp.c_proj",
      "model.vision_model.transformer.resblocks.4.attn.out_proj",
      "model.vision_model.transformer.resblocks.4.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.4.ln_1",
      "model.vision_model.transformer.resblocks.4.ln_2",
      "model.vision_model.transformer.resblocks.4.mlp.c_fc",
      "model.vision_model.transformer.resblocks.4.mlp.c_proj",
      "model.vision_model.transformer.resblocks.40.attn.out_proj",
      "model.vision_model.transformer.resblocks.40.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.40.ln_1",
      "model.vision_model.transformer.resblocks.40.ln_2",
      "model.vision_model.transformer.resblocks.40.mlp.c_fc",
      "model.vision_model.transformer.resblocks.40.mlp.c_proj",
      "model.vision_model.transformer.resblocks.41.attn.out_proj",
      "model.vision_model.transformer.resblocks.41.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.41.ln_1",
      "model.vision_model.transformer.resblocks.41.ln_2",
      "model.vision_model.transformer.resblocks.41.mlp.c_fc",
      "model.vision_model.transformer.resblocks.41.mlp.c_proj",
      "model.vision_model.transformer.resblocks.42.attn.out_proj",
      "model.vision_model.transformer.resblocks.42.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.42.ln_1",
      "model.vision_model.transformer.resblocks.42.ln_2",
      "model.vision_model.transformer.resblocks.42.mlp.c_fc",
      "model.vision_model.transformer.resblocks.42.mlp.c_proj",
      "model.vision_model.transformer.resblocks.43.attn.out_proj",
      "model.vision_model.transformer.resblocks.43.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.43.ln_1",
      "model.vision_model.transformer.resblocks.43.ln_2",
      "model.vision_model.transformer.resblocks.43.mlp.c_fc",
      "model.vision_model.transformer.resblocks.43.mlp.c_proj",
      "model.vision_model.transformer.resblocks.44.attn.out_proj",
      "model.vision_model.transformer.resblocks.44.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.44.ln_1",
      "model.vision_model.transformer.resblocks.44.ln_2",
      "model.vision_model.transformer.resblocks.44.mlp.c_fc",
      "model.vision_model.transformer.resblocks.44.mlp.c_proj",
      "model.vision_model.transformer.resblocks.45.attn.out_proj",
      "model.vision_model.transformer.resblocks.45.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.45.ln_1",
      "model.vision_model.transformer.resblocks.45.ln_2",
      "model.vision_model.transformer.resblocks.45.mlp.c_fc",
      "model.vision_model.transformer.resblocks.45.mlp.c_proj",
      "model.vision_model.transformer.resblocks.46.attn.out_proj",
      "model.vision_model.transformer.resblocks.46.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.46.ln_1",
      "model.vision_model.transformer.resblocks.46.ln_2",
      "model.vision_model.transformer.resblocks.46.mlp.c_fc",
      "model.vision_model.transformer.resblocks.46.mlp.c_proj",
      "model.vision_model.transformer.resblocks.5.attn.out_proj",
      "model.vision_model.transformer.resblocks.5.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.5.ln_1",
      "model.vision_model.transformer.resblocks.5.ln_2",
      "model.vision_model.transformer.resblocks.5.mlp.c_fc",
      "model.vision_model.transformer.resblocks.5.mlp.c_proj",
      "model.vision_model.transformer.resblocks.6.attn.out_proj",
      "model.vision_model.transformer.resblocks.6.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.6.ln_1",
      "model.vision_model.transformer.resblocks.6.ln_2",
      "model.vision_model.transformer.resblocks.6.mlp.c_fc",
      "model.vision_model.transformer.resblocks.6.mlp.c_proj",
      "model.vision_model.transformer.resblocks.7.attn.out_proj",
      "model.vision_model.transformer.resblocks.7.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.7.ln_1",
      "model.vision_model.transformer.resblocks.7.ln_2",
      "model.vision_model.transformer.resblocks.7.mlp.c_fc",
      "model.vision_model.transformer.resblocks.7.mlp.c_proj",
      "model.vision_model.transformer.resblocks.8.attn.out_proj",
      "model.vision_model.transformer.resblocks.8.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.8.ln_1",
      "model.vision_model.transformer.resblocks.8.ln_2",
      "model.vision_model.transformer.resblocks.8.mlp.c_fc",
      "model.vision_model.transformer.resblocks.8.mlp.c_proj",
      "model.vision_model.transformer.resblocks.9.attn.out_proj",
      "model.vision_model.transformer.resblocks.9.attn.qkv_proj",
      "model.vision_model.transformer.resblocks.9.ln_1",
      "model.vision_model.transformer.resblocks.9.ln_2",
      "model.vision_model.transformer.resblocks.9.mlp.c_fc",
      "model.vision_model.transformer.resblocks.9.mlp.c_proj",
      "model.vision_model.vit_downsampler1",
      "model.vision_model.vit_downsampler2",
      "model.vit_large_projector",
      "vision_model",
      "vision_model.conv1",
      "vision_model.ln_pre",
      "vision_model.transformer.resblocks.0.attn.out_proj",
      "vision_model.transformer.resblocks.0.attn.qkv_proj",
      "vision_model.transformer.resblocks.0.ln_1",
      "vision_model.transformer.resblocks.0.ln_2",
      "vision_model.transformer.resblocks.0.mlp.c_fc",
      "vision_model.transformer.resblocks.0.mlp.c_proj",
      "vision_model.transformer.resblocks.1.attn.out_proj",
      "vision_model.transformer.resblocks.1.attn.qkv_proj",
      "vision_model.transformer.resblocks.1.ln_1",
      "vision_model.transformer.resblocks.1.ln_2",
      "vision_model.transformer.resblocks.1.mlp.c_fc",
      "vision_model.transformer.resblocks.1.mlp.c_proj",
      "vision_model.transformer.resblocks.10.attn.out_proj",
      "vision_model.transformer.resblocks.10.attn.qkv_proj",
      "vision_model.transformer.resblocks.10.ln_1",
      "vision_model.transformer.resblocks.10.ln_2",
      "vision_model.transformer.resblocks.10.mlp.c_fc",
      "vision_model.transformer.resblocks.10.mlp.c_proj",
      "vision_model.transformer.resblocks.11.attn.out_proj",
      "vision_model.transformer.resblocks.11.attn.qkv_proj",
      "vision_model.transformer.resblocks.11.ln_1",
      "vision_model.transformer.resblocks.11.ln_2",
      "vision_model.transformer.resblocks.11.mlp.c_fc",
      "vision_model.transformer.resblocks.11.mlp.c_proj",
      "vision_model.transformer.resblocks.12.attn.out_proj",
      "vision_model.transformer.resblocks.12.attn.qkv_proj",
      "vision_model.transformer.resblocks.12.ln_1",
      "vision_model.transformer.resblocks.12.ln_2",
      "vision_model.transformer.resblocks.12.mlp.c_fc",
      "vision_model.transformer.resblocks.12.mlp.c_proj",
      "vision_model.transformer.resblocks.13.attn.out_proj",
      "vision_model.transformer.resblocks.13.attn.qkv_proj",
      "vision_model.transformer.resblocks.13.ln_1",
      "vision_model.transformer.resblocks.13.ln_2",
      "vision_model.transformer.resblocks.13.mlp.c_fc",
      "vision_model.transformer.resblocks.13.mlp.c_proj",
      "vision_model.transformer.resblocks.14.attn.out_proj",
      "vision_model.transformer.resblocks.14.attn.qkv_proj",
      "vision_model.transformer.resblocks.14.ln_1",
      "vision_model.transformer.resblocks.14.ln_2",
      "vision_model.transformer.resblocks.14.mlp.c_fc",
      "vision_model.transformer.resblocks.14.mlp.c_proj",
      "vision_model.transformer.resblocks.15.attn.out_proj",
      "vision_model.transformer.resblocks.15.attn.qkv_proj",
      "vision_model.transformer.resblocks.15.ln_1",
      "vision_model.transformer.resblocks.15.ln_2",
      "vision_model.transformer.resblocks.15.mlp.c_fc",
      "vision_model.transformer.resblocks.15.mlp.c_proj",
      "vision_model.transformer.resblocks.16.attn.out_proj",
      "vision_model.transformer.resblocks.16.attn.qkv_proj",
      "vision_model.transformer.resblocks.16.ln_1",
      "vision_model.transformer.resblocks.16.ln_2",
      "vision_model.transformer.resblocks.16.mlp.c_fc",
      "vision_model.transformer.resblocks.16.mlp.c_proj",
      "vision_model.transformer.resblocks.17.attn.out_proj",
      "vision_model.transformer.resblocks.17.attn.qkv_proj",
      "vision_model.transformer.resblocks.17.ln_1",
      "vision_model.transformer.resblocks.17.ln_2",
      "vision_model.transformer.resblocks.17.mlp.c_fc",
      "vision_model.transformer.resblocks.17.mlp.c_proj",
      "vision_model.transformer.resblocks.18.attn.out_proj",
      "vision_model.transformer.resblocks.18.attn.qkv_proj",
      "vision_model.transformer.resblocks.18.ln_1",
      "vision_model.transformer.resblocks.18.ln_2",
      "vision_model.transformer.resblocks.18.mlp.c_fc",
      "vision_model.transformer.resblocks.18.mlp.c_proj",
      "vision_model.transformer.resblocks.19.attn.out_proj",
      "vision_model.transformer.resblocks.19.attn.qkv_proj",
      "vision_model.transformer.resblocks.19.ln_1",
      "vision_model.transformer.resblocks.19.ln_2",
      "vision_model.transformer.resblocks.19.mlp.c_fc",
      "vision_model.transformer.resblocks.19.mlp.c_proj",
      "vision_model.transformer.resblocks.2.attn.out_proj",
      "vision_model.transformer.resblocks.2.attn.qkv_proj",
      "vision_model.transformer.resblocks.2.ln_1",
      "vision_model.transformer.resblocks.2.ln_2",
      "vision_model.transformer.resblocks.2.mlp.c_fc",
      "vision_model.transformer.resblocks.2.mlp.c_proj",
      "vision_model.transformer.resblocks.20.attn.out_proj",
      "vision_model.transformer.resblocks.20.attn.qkv_proj",
      "vision_model.transformer.resblocks.20.ln_1",
      "vision_model.transformer.resblocks.20.ln_2",
      "vision_model.transformer.resblocks.20.mlp.c_fc",
      "vision_model.transformer.resblocks.20.mlp.c_proj",
      "vision_model.transformer.resblocks.21.attn.out_proj",
      "vision_model.transformer.resblocks.21.attn.qkv_proj",
      "vision_model.transformer.resblocks.21.ln_1",
      "vision_model.transformer.resblocks.21.ln_2",
      "vision_model.transformer.resblocks.21.mlp.c_fc",
      "vision_model.transformer.resblocks.21.mlp.c_proj",
      "vision_model.transformer.resblocks.22.attn.out_proj",
      "vision_model.transformer.resblocks.22.attn.qkv_proj",
      "vision_model.transformer.resblocks.22.ln_1",
      "vision_model.transformer.resblocks.22.ln_2",
      "vision_model.transformer.resblocks.22.mlp.c_fc",
      "vision_model.transformer.resblocks.22.mlp.c_proj",
      "vision_model.transformer.resblocks.23.attn.out_proj",
      "vision_model.transformer.resblocks.23.attn.qkv_proj",
      "vision_model.transformer.resblocks.23.ln_1",
      "vision_model.transformer.resblocks.23.ln_2",
      "vision_model.transformer.resblocks.23.mlp.c_fc",
      "vision_model.transformer.resblocks.23.mlp.c_proj",
      "vision_model.transformer.resblocks.24.attn.out_proj",
      "vision_model.transformer.resblocks.24.attn.qkv_proj",
      "vision_model.transformer.resblocks.24.ln_1",
      "vision_model.transformer.resblocks.24.ln_2",
      "vision_model.transformer.resblocks.24.mlp.c_fc",
      "vision_model.transformer.resblocks.24.mlp.c_proj",
      "vision_model.transformer.resblocks.25.attn.out_proj",
      "vision_model.transformer.resblocks.25.attn.qkv_proj",
      "vision_model.transformer.resblocks.25.ln_1",
      "vision_model.transformer.resblocks.25.ln_2",
      "vision_model.transformer.resblocks.25.mlp.c_fc",
      "vision_model.transformer.resblocks.25.mlp.c_proj",
      "vision_model.transformer.resblocks.26.attn.out_proj",
      "vision_model.transformer.resblocks.26.attn.qkv_proj",
      "vision_model.transformer.resblocks.26.ln_1",
      "vision_model.transformer.resblocks.26.ln_2",
      "vision_model.transformer.resblocks.26.mlp.c_fc",
      "vision_model.transformer.resblocks.26.mlp.c_proj",
      "vision_model.transformer.resblocks.27.attn.out_proj",
      "vision_model.transformer.resblocks.27.attn.qkv_proj",
      "vision_model.transformer.resblocks.27.ln_1",
      "vision_model.transformer.resblocks.27.ln_2",
      "vision_model.transformer.resblocks.27.mlp.c_fc",
      "vision_model.transformer.resblocks.27.mlp.c_proj",
      "vision_model.transformer.resblocks.28.attn.out_proj",
      "vision_model.transformer.resblocks.28.attn.qkv_proj",
      "vision_model.transformer.resblocks.28.ln_1",
      "vision_model.transformer.resblocks.28.ln_2",
      "vision_model.transformer.resblocks.28.mlp.c_fc",
      "vision_model.transformer.resblocks.28.mlp.c_proj",
      "vision_model.transformer.resblocks.29.attn.out_proj",
      "vision_model.transformer.resblocks.29.attn.qkv_proj",
      "vision_model.transformer.resblocks.29.ln_1",
      "vision_model.transformer.resblocks.29.ln_2",
      "vision_model.transformer.resblocks.29.mlp.c_fc",
      "vision_model.transformer.resblocks.29.mlp.c_proj",
      "vision_model.transformer.resblocks.3.attn.out_proj",
      "vision_model.transformer.resblocks.3.attn.qkv_proj",
      "vision_model.transformer.resblocks.3.ln_1",
      "vision_model.transformer.resblocks.3.ln_2",
      "vision_model.transformer.resblocks.3.mlp.c_fc",
      "vision_model.transformer.resblocks.3.mlp.c_proj",
      "vision_model.transformer.resblocks.30.attn.out_proj",
      "vision_model.transformer.resblocks.30.attn.qkv_proj",
      "vision_model.transformer.resblocks.30.ln_1",
665
+ "vision_model.transformer.resblocks.30.ln_2",
666
+ "vision_model.transformer.resblocks.30.mlp.c_fc",
667
+ "vision_model.transformer.resblocks.30.mlp.c_proj",
668
+ "vision_model.transformer.resblocks.31.attn.out_proj",
669
+ "vision_model.transformer.resblocks.31.attn.qkv_proj",
670
+ "vision_model.transformer.resblocks.31.ln_1",
671
+ "vision_model.transformer.resblocks.31.ln_2",
672
+ "vision_model.transformer.resblocks.31.mlp.c_fc",
673
+ "vision_model.transformer.resblocks.31.mlp.c_proj",
674
+ "vision_model.transformer.resblocks.32.attn.out_proj",
675
+ "vision_model.transformer.resblocks.32.attn.qkv_proj",
676
+ "vision_model.transformer.resblocks.32.ln_1",
677
+ "vision_model.transformer.resblocks.32.ln_2",
678
+ "vision_model.transformer.resblocks.32.mlp.c_fc",
679
+ "vision_model.transformer.resblocks.32.mlp.c_proj",
680
+ "vision_model.transformer.resblocks.33.attn.out_proj",
681
+ "vision_model.transformer.resblocks.33.attn.qkv_proj",
682
+ "vision_model.transformer.resblocks.33.ln_1",
683
+ "vision_model.transformer.resblocks.33.ln_2",
684
+ "vision_model.transformer.resblocks.33.mlp.c_fc",
685
+ "vision_model.transformer.resblocks.33.mlp.c_proj",
686
+ "vision_model.transformer.resblocks.34.attn.out_proj",
687
+ "vision_model.transformer.resblocks.34.attn.qkv_proj",
688
+ "vision_model.transformer.resblocks.34.ln_1",
689
+ "vision_model.transformer.resblocks.34.ln_2",
690
+ "vision_model.transformer.resblocks.34.mlp.c_fc",
691
+ "vision_model.transformer.resblocks.34.mlp.c_proj",
692
+ "vision_model.transformer.resblocks.35.attn.out_proj",
693
+ "vision_model.transformer.resblocks.35.attn.qkv_proj",
694
+ "vision_model.transformer.resblocks.35.ln_1",
695
+ "vision_model.transformer.resblocks.35.ln_2",
696
+ "vision_model.transformer.resblocks.35.mlp.c_fc",
697
+ "vision_model.transformer.resblocks.35.mlp.c_proj",
698
+ "vision_model.transformer.resblocks.36.attn.out_proj",
699
+ "vision_model.transformer.resblocks.36.attn.qkv_proj",
700
+ "vision_model.transformer.resblocks.36.ln_1",
701
+ "vision_model.transformer.resblocks.36.ln_2",
702
+ "vision_model.transformer.resblocks.36.mlp.c_fc",
703
+ "vision_model.transformer.resblocks.36.mlp.c_proj",
704
+ "vision_model.transformer.resblocks.37.attn.out_proj",
705
+ "vision_model.transformer.resblocks.37.attn.qkv_proj",
706
+ "vision_model.transformer.resblocks.37.ln_1",
707
+ "vision_model.transformer.resblocks.37.ln_2",
708
+ "vision_model.transformer.resblocks.37.mlp.c_fc",
709
+ "vision_model.transformer.resblocks.37.mlp.c_proj",
710
+ "vision_model.transformer.resblocks.38.attn.out_proj",
711
+ "vision_model.transformer.resblocks.38.attn.qkv_proj",
712
+ "vision_model.transformer.resblocks.38.ln_1",
713
+ "vision_model.transformer.resblocks.38.ln_2",
714
+ "vision_model.transformer.resblocks.38.mlp.c_fc",
715
+ "vision_model.transformer.resblocks.38.mlp.c_proj",
716
+ "vision_model.transformer.resblocks.39.attn.out_proj",
717
+ "vision_model.transformer.resblocks.39.attn.qkv_proj",
718
+ "vision_model.transformer.resblocks.39.ln_1",
719
+ "vision_model.transformer.resblocks.39.ln_2",
720
+ "vision_model.transformer.resblocks.39.mlp.c_fc",
721
+ "vision_model.transformer.resblocks.39.mlp.c_proj",
722
+ "vision_model.transformer.resblocks.4.attn.out_proj",
723
+ "vision_model.transformer.resblocks.4.attn.qkv_proj",
724
+ "vision_model.transformer.resblocks.4.ln_1",
725
+ "vision_model.transformer.resblocks.4.ln_2",
726
+ "vision_model.transformer.resblocks.4.mlp.c_fc",
727
+ "vision_model.transformer.resblocks.4.mlp.c_proj",
728
+ "vision_model.transformer.resblocks.40.attn.out_proj",
729
+ "vision_model.transformer.resblocks.40.attn.qkv_proj",
730
+ "vision_model.transformer.resblocks.40.ln_1",
731
+ "vision_model.transformer.resblocks.40.ln_2",
732
+ "vision_model.transformer.resblocks.40.mlp.c_fc",
733
+ "vision_model.transformer.resblocks.40.mlp.c_proj",
734
+ "vision_model.transformer.resblocks.41.attn.out_proj",
735
+ "vision_model.transformer.resblocks.41.attn.qkv_proj",
736
+ "vision_model.transformer.resblocks.41.ln_1",
737
+ "vision_model.transformer.resblocks.41.ln_2",
738
+ "vision_model.transformer.resblocks.41.mlp.c_fc",
739
+ "vision_model.transformer.resblocks.41.mlp.c_proj",
740
+ "vision_model.transformer.resblocks.42.attn.out_proj",
741
+ "vision_model.transformer.resblocks.42.attn.qkv_proj",
742
+ "vision_model.transformer.resblocks.42.ln_1",
743
+ "vision_model.transformer.resblocks.42.ln_2",
744
+ "vision_model.transformer.resblocks.42.mlp.c_fc",
745
+ "vision_model.transformer.resblocks.42.mlp.c_proj",
746
+ "vision_model.transformer.resblocks.43.attn.out_proj",
747
+ "vision_model.transformer.resblocks.43.attn.qkv_proj",
748
+ "vision_model.transformer.resblocks.43.ln_1",
749
+ "vision_model.transformer.resblocks.43.ln_2",
750
+ "vision_model.transformer.resblocks.43.mlp.c_fc",
751
+ "vision_model.transformer.resblocks.43.mlp.c_proj",
752
+ "vision_model.transformer.resblocks.44.attn.out_proj",
753
+ "vision_model.transformer.resblocks.44.attn.qkv_proj",
754
+ "vision_model.transformer.resblocks.44.ln_1",
755
+ "vision_model.transformer.resblocks.44.ln_2",
756
+ "vision_model.transformer.resblocks.44.mlp.c_fc",
757
+ "vision_model.transformer.resblocks.44.mlp.c_proj",
758
+ "vision_model.transformer.resblocks.45.attn.out_proj",
759
+ "vision_model.transformer.resblocks.45.attn.qkv_proj",
760
+ "vision_model.transformer.resblocks.45.ln_1",
761
+ "vision_model.transformer.resblocks.45.ln_2",
762
+ "vision_model.transformer.resblocks.45.mlp.c_fc",
763
+ "vision_model.transformer.resblocks.45.mlp.c_proj",
764
+ "vision_model.transformer.resblocks.46.attn.out_proj",
765
+ "vision_model.transformer.resblocks.46.attn.qkv_proj",
766
+ "vision_model.transformer.resblocks.46.ln_1",
767
+ "vision_model.transformer.resblocks.46.ln_2",
768
+ "vision_model.transformer.resblocks.46.mlp.c_fc",
769
+ "vision_model.transformer.resblocks.46.mlp.c_proj",
770
+ "vision_model.transformer.resblocks.5.attn.out_proj",
771
+ "vision_model.transformer.resblocks.5.attn.qkv_proj",
772
+ "vision_model.transformer.resblocks.5.ln_1",
773
+ "vision_model.transformer.resblocks.5.ln_2",
774
+ "vision_model.transformer.resblocks.5.mlp.c_fc",
775
+ "vision_model.transformer.resblocks.5.mlp.c_proj",
776
+ "vision_model.transformer.resblocks.6.attn.out_proj",
777
+ "vision_model.transformer.resblocks.6.attn.qkv_proj",
778
+ "vision_model.transformer.resblocks.6.ln_1",
779
+ "vision_model.transformer.resblocks.6.ln_2",
780
+ "vision_model.transformer.resblocks.6.mlp.c_fc",
781
+ "vision_model.transformer.resblocks.6.mlp.c_proj",
782
+ "vision_model.transformer.resblocks.7.attn.out_proj",
783
+ "vision_model.transformer.resblocks.7.attn.qkv_proj",
784
+ "vision_model.transformer.resblocks.7.ln_1",
785
+ "vision_model.transformer.resblocks.7.ln_2",
786
+ "vision_model.transformer.resblocks.7.mlp.c_fc",
787
+ "vision_model.transformer.resblocks.7.mlp.c_proj",
788
+ "vision_model.transformer.resblocks.8.attn.out_proj",
789
+ "vision_model.transformer.resblocks.8.attn.qkv_proj",
790
+ "vision_model.transformer.resblocks.8.ln_1",
791
+ "vision_model.transformer.resblocks.8.ln_2",
792
+ "vision_model.transformer.resblocks.8.mlp.c_fc",
793
+ "vision_model.transformer.resblocks.8.mlp.c_proj",
794
+ "vision_model.transformer.resblocks.9.attn.out_proj",
795
+ "vision_model.transformer.resblocks.9.attn.qkv_proj",
796
+ "vision_model.transformer.resblocks.9.ln_1",
797
+ "vision_model.transformer.resblocks.9.ln_2",
798
+ "vision_model.transformer.resblocks.9.mlp.c_fc",
799
+ "vision_model.transformer.resblocks.9.mlp.c_proj",
800
+ "vision_model.vit_downsampler1",
801
+ "vision_model.vit_downsampler2",
802
+ "vit_large_projector"
803
+ ]
804
+ }
805
+ }
configuration_step_vl.py ADDED
@@ -0,0 +1,76 @@
1
+ from typing import Any, Optional, Union
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers import Qwen3Config
5
+
6
+
7
+ class StepRoboticsVisionEncoderConfig(PretrainedConfig):
8
+
9
+ def __init__(
10
+ self,
11
+ width=1536,
12
+ layers=47,
13
+ heads=16,
14
+ num_channels=3,
15
+ image_size=728,
16
+ mlp_ratio=8960 / 1536,
17
+ patch_size=14,
18
+ hidden_act="quick_gelu",
19
+ layer_norm_eps=1e-5,
20
+ ues_cls_token=False,
21
+ use_ln_pre=True,
22
+ use_ln_post=False,
23
+ use_abs_posemb=True,
24
+ use_rope2d=True,
25
+ ls_init_value=0.1,
26
+ **kwargs,
27
+ ):
28
+ self.width = width
29
+ self.layers = layers
30
+ self.heads = heads
31
+ self.num_channels = num_channels
32
+ self.patch_size = patch_size
33
+ self.image_size = image_size
34
+ self.mlp_ratio = mlp_ratio
35
+ self.layer_norm_eps = layer_norm_eps
36
+ self.hidden_act = hidden_act
37
+ self.ues_cls_token = ues_cls_token
38
+ self.use_ln_pre = use_ln_pre
39
+ self.ls_init_value = ls_init_value
40
+ self.use_ln_post = use_ln_post
41
+ self.use_abs_posemb = use_abs_posemb
42
+ self.use_rope2d = use_rope2d
43
+ super().__init__(**kwargs)
44
+
45
+
46
+ class StepRoboticsConfig(PretrainedConfig):
47
+ model_type = "step_robotics"
48
+ architectures = ["StepVLForConditionalGeneration"]
49
+
50
+ def __init__(
51
+ self,
52
+ vision_config: Optional[Union[dict, StepRoboticsVisionEncoderConfig]] = None,
53
+ text_config: Optional[Union[dict, Qwen3Config]] = None,
54
+ understand_projector_stride: int = 2,
55
+ projector_bias: bool = False,
56
+ image_token_id: int = 151679,
57
+ **kwargs,
58
+ ) -> None:
59
+ if vision_config is None:
60
+ vision_config = StepRoboticsVisionEncoderConfig()
61
+ elif isinstance(vision_config, dict):
62
+ vision_config = StepRoboticsVisionEncoderConfig(**vision_config)
63
+ self.vision_config = vision_config
64
+
65
+ if text_config is None:
66
+ text_config = Qwen3Config()
67
+ elif isinstance(text_config, dict):
68
+ text_config = Qwen3Config(**text_config)
69
+ self.text_config = text_config
70
+
71
+ self.understand_projector_stride = understand_projector_stride
72
+ self.projector_bias = projector_bias
73
+ self.hidden_size = text_config.hidden_size
74
+ self.image_token_id = image_token_id
75
+ # Help Auto classes find the correct implementation when saving/loading.
76
+ super().__init__(**kwargs)
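+
+ # Minimal sanity sketch (names as defined above; default values only):
+ # >>> cfg = StepRoboticsConfig()
+ # >>> cfg.vision_config.layers  # 47 resblocks, matching the FP8 ignore list above
+ # >>> cfg.hidden_size == cfg.text_config.hidden_size  # mirrored for the projector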
generation_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "temperature": 1.0,
3
+ "top_p": 1.0,
4
+ "top_k": 0,
5
+ "eos_token_id":[
6
+ 151643,
7
+ 151645,
8
+ 151679
9
+ ]
10
+ }
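
A note on the defaults above: `generate()` reads this file automatically, and any of the listed `eos_token_id`s terminates decoding. A minimal sketch of inspecting them (repo id taken from the links above):

```python
from transformers import GenerationConfig

# Loads the generation_config.json shown above; per-call kwargs still override it.
gen_cfg = GenerationConfig.from_pretrained("stepfun-ai/Step3-VL-10B")
print(gen_cfg.eos_token_id)  # [151643, 151645, 151679]
```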
modeling_step_vl.py ADDED
@@ -0,0 +1,557 @@
1
+ # Copyright 2025 The STEPFUN and HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Any, Callable, Literal, Optional, Tuple, TypedDict, Union
17
+ from PIL import Image
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from transformers import Qwen3Model
23
+ from transformers.cache_utils import Cache, DynamicCache
24
+ from transformers.generation import GenerationMixin
25
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.processing_utils import Unpack
28
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
29
+
32
+ from .configuration_step_vl import StepRoboticsConfig
33
+ from .vision_encoder import StepRoboticsVisionEncoder
34
+ logger = logging.get_logger(__name__)
35
+
36
+ class StepVLImagePixelInputs(TypedDict):
37
+ type: Literal["pixel_values"]
38
+ pixel_values: torch.Tensor
39
+ patch_pixel_values: Optional[torch.Tensor]
40
+ num_patches: list[int]
41
+
42
+
43
+ class StepVLImageEmbeddingInputs(TypedDict):
44
+ type: Literal["image_embeds"]
45
+ image_embeds: torch.Tensor
46
+
47
+
48
+ StepVLImageInputs = Union[StepVLImagePixelInputs,
49
+ StepVLImageEmbeddingInputs]
50
+
51
+
52
+ @dataclass
53
+ class StepVLCausalLMOutputWithPast(ModelOutput):
54
+ r"""
55
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
56
+ Language modeling loss (for next-token prediction).
57
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
58
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
59
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
60
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
61
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
62
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
63
+ `past_key_values` input) to speed up sequential decoding.
64
+ """
65
+
66
+ loss: Optional[torch.FloatTensor] = None
67
+ last_hidden_state: Optional[torch.FloatTensor] = None
68
+ logits: torch.FloatTensor = None
69
+ past_key_values: Optional[list[torch.FloatTensor]] = None
70
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
71
+ attentions: Optional[tuple[torch.FloatTensor]] = None
72
+ image_hidden_states: Optional[torch.FloatTensor] = None
73
+
74
+ def _flatten_embeddings(embeddings) -> torch.Tensor:
75
+ """
76
+ Recursively flattens and concatenates NestedTensors on all but the last
77
+ dimension.
78
+ """
79
+
80
+ if isinstance(embeddings, torch.Tensor):
81
+ # Flatten all but the last dimension.
82
+ return embeddings.flatten(0, -2)
83
+
84
+ return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
85
+
86
+ def _embedding_count_expression(embeddings) -> str:
87
+ """
88
+ Constructs a debugging representation of the number of embeddings in the
89
+ NestedTensors.
90
+ """
91
+
92
+ if isinstance(embeddings, torch.Tensor):
93
+ return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
94
+
95
+ return " + ".join(
96
+ _embedding_count_expression(inner) for inner in embeddings)
97
+
98
+ def _merge_multimodal_embeddings(
99
+ inputs_embeds: torch.Tensor,
100
+ is_multimodal: torch.Tensor,
101
+ multimodal_embeddings,
102
+ ) -> torch.Tensor:
103
+ """
104
+ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
105
+ positions in ``inputs_embeds`` corresponding to placeholder tokens in
106
+ ``input_ids``.
107
+ Note:
108
+ This updates ``inputs_embeds`` in place.
109
+ """
110
+ num_expected_tokens = is_multimodal.sum().item()
111
+ assert isinstance(num_expected_tokens, int)
112
+
113
+ flattened = _flatten_embeddings(multimodal_embeddings)
114
+ if flattened.shape[0] != num_expected_tokens:
115
+ expr = _embedding_count_expression(multimodal_embeddings)
116
+ raise ValueError(
117
+ f"Attempted to assign {expr} = {flattened.shape[0]} "
118
+ f"multimodal tokens to {num_expected_tokens} placeholders")
119
+
120
+ is_multimodal = is_multimodal.to(inputs_embeds.device)
121
+ flattened = flattened.to(inputs_embeds.device)
122
+ inputs_embeds[is_multimodal] = flattened
123
+ return inputs_embeds
124
+
125
+ def merge_multimodal_embeddings(
126
+ input_ids: torch.Tensor,
127
+ inputs_embeds: torch.Tensor,
128
+ multimodal_embeddings,
129
+ placeholder_token_id: Union[int, list[int]],
130
+ ) -> torch.Tensor:
131
+ """
132
+ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
133
+ positions in ``inputs_embeds`` corresponding to placeholder tokens in
134
+ ``input_ids``.
135
+
136
+ ``placeholder_token_id`` can be a list of token ids (e.g., token ids
137
+ of img_start, img_break, and img_end tokens) when needed: This means
138
+ the order of these tokens in the ``input_ids`` MUST MATCH the order of
139
+ their embeddings in ``multimodal_embeddings`` since we need to
140
+ slice-merge instead of individually scattering.
141
+ For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
142
+ - T is text token
143
+ - S is image start token
144
+ - I is image embedding token
145
+ - B is image break token
146
+ - E is image end token.
147
+
148
+ Then the image embeddings (that correspond to I's) from vision encoder
149
+ must be padded with embeddings of S, B, and E in the same order of
150
+ input_ids for a correct embedding merge.
151
+ Note:
152
+ This updates ``inputs_embeds`` in place.
153
+ """
154
+ if isinstance(placeholder_token_id, list):
155
+ placeholder_token_id = torch.tensor(placeholder_token_id,
156
+ device=input_ids.device)
157
+ return _merge_multimodal_embeddings(
158
+ inputs_embeds,
159
+ torch.isin(input_ids, placeholder_token_id),
160
+ multimodal_embeddings,
161
+ )
162
+
163
+ return _merge_multimodal_embeddings(
164
+ inputs_embeds,
165
+ (input_ids == placeholder_token_id),
166
+ multimodal_embeddings,
167
+ )
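+
+ # Shape sketch: with input_ids = [t, t, IMG, IMG, t] and two rows of image
+ # features, both IMG positions of inputs_embeds are overwritten in order;
+ # any count mismatch raises the ValueError built in _merge_multimodal_embeddings.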
168
+
169
+ class StepRoboticsPreTrainedModel(PreTrainedModel):
170
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
171
+ # can load the config instead of failing with a NoneType error.
172
+ config_class = StepRoboticsConfig
173
+ supports_gradient_checkpointing = True
174
+ _skip_keys_device_placement = ["past_key_values"]
175
+ _supports_flash_attn = False
176
+ _supports_sdpa = True
177
+ _supports_flex_attn = True
178
+ _supports_static_cache = True
179
+ _supports_attention_backend = True
180
+
181
+
182
+ class StepRoboticsModel(StepRoboticsPreTrainedModel, GenerationMixin):
183
+ config: StepRoboticsConfig
184
+ base_model_prefix = ""
185
+ def __init__(self, config: StepRoboticsConfig):
186
+ super().__init__(config)
187
+ self.vision_model = StepRoboticsVisionEncoder(config.vision_config)
188
+ self.language_model = Qwen3Model(config.text_config)
189
+ self.vocab_size = config.text_config.vocab_size
190
+ self.vit_large_projector = nn.Linear(
191
+ config.vision_config.width * 4,
192
+ config.text_config.hidden_size,
193
+ bias=config.projector_bias)
194
+ self.image_placeholder_token_id = config.image_token_id
195
+
196
+ # Initialize weights and apply final processing
197
+ self.post_init()
198
+
199
+ def get_input_embeddings(
200
+ self,
201
+ input_ids: torch.Tensor,
202
+ multimodal_embeddings = None,
203
+ ) -> torch.Tensor:
204
+ input_ids = input_ids.squeeze(0)
205
+ if multimodal_embeddings is None:
206
+ inputs_embeds = self.language_model.embed_tokens(input_ids)
207
+ else:
208
+ is_text = input_ids != self.config.image_token_id
209
+ text_ids = input_ids[is_text]
210
+ text_embeds = self.language_model.embed_tokens(text_ids)
211
+
212
+ inputs_embeds = torch.empty(input_ids.shape[0],
213
+ text_embeds.shape[-1],
214
+ dtype=text_embeds.dtype,
215
+ device=text_embeds.device)
216
+ inputs_embeds[is_text] = text_embeds
217
+ inputs_embeds = merge_multimodal_embeddings(
218
+ input_ids, inputs_embeds, multimodal_embeddings,
219
+ self.config.image_token_id)
220
+ inputs_embeds = inputs_embeds.unsqueeze(0)
221
+ return inputs_embeds
222
+
223
+
224
+ def set_input_embeddings(self, value):
225
+ return self.language_model.set_input_embeddings(value)
226
+
227
+ def set_decoder(self, decoder):
228
+ self.language_model = decoder
229
+
230
+ def get_decoder(self):
231
+ return self.language_model
232
+
233
+ def _parse_and_validate_image_input(
234
+ self, **kwargs: object) -> Optional[StepVLImageInputs]:
235
+ pixel_values = kwargs.pop("pixel_values", None)
236
+ patch_pixel_values = kwargs.pop("patch_pixel_values", None)
237
+ num_patches = kwargs.pop("num_patches", None)
238
+ image_embeds = kwargs.pop("image_embeds", None)
239
+
240
+ if pixel_values is None and image_embeds is None:
241
+ return None
242
+
243
+ if pixel_values is not None:
244
+ # pixel_values = flatten_bn(pixel_values, concat=True)
245
+ if pixel_values.dim() >= 3:
246
+ pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
247
+ if patch_pixel_values is not None:
248
+ # patch_pixel_values = flatten_bn(patch_pixel_values,
249
+ # concat=True)
250
+ patch_pixel_values = patch_pixel_values.view(
251
+ -1, *patch_pixel_values.shape[-3:])
252
+ # Handle empty patch_pixel_values by setting to None
253
+ if patch_pixel_values.shape[0] == 0:
254
+ patch_pixel_values = None
255
+
256
+ return StepVLImagePixelInputs(
257
+ type="pixel_values",
258
+ pixel_values=pixel_values.to(self.dtype).to(self.device),
259
+ patch_pixel_values=patch_pixel_values.to(self.dtype).to(
260
+ self.device) if patch_pixel_values is not None else None,
261
+ num_patches=num_patches,
262
+ )
263
+
264
+ if image_embeds is not None:
265
+ if image_embeds.dim() >= 2:
266
+ image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
267
+ else:
268
+ raise ValueError(
269
+ f"Unexpected shape for image_embeds: {image_embeds.shape}")
270
+
271
+ return StepVLImageEmbeddingInputs(
272
+ type="image_embeds",
273
+ image_embeds=image_embeds.to(self.dtype).to(self.device),
274
+ )
275
+ return None
276
+
277
+ def _process_image_features(self,
278
+ image_features: torch.Tensor) -> torch.Tensor:
279
+ B, P = image_features.shape[:2]
280
+ HW = int(P ** 0.5)
281
+ image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
282
+ image_features = self.vision_model.vit_downsampler1(image_features)
283
+ image_features = self.vision_model.vit_downsampler2(image_features)
284
+
285
+ B, C, H, W = image_features.shape
+ image_features = image_features.view(B, -1, H * W).permute(0, 2, 1)
287
+ image_features = self.vit_large_projector(image_features)
288
+ return image_features
289
+
290
+ def _get_vision_model_output(self,
291
+ input_tensor: torch.Tensor) -> torch.Tensor:
292
+ return self.vision_model(input_tensor)
293
+
294
+ def _process_image_input(
295
+ self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
296
+
297
+ if image_input["type"] == "image_embeds":
298
+ image_features = image_input["image_embeds"]
299
+ else:
300
+ image_features = self._get_vision_model_output(
301
+ image_input["pixel_values"])
302
+ patch_image_features = self._get_vision_model_output(
303
+ image_input["patch_pixel_values"]
304
+ ) if image_input["patch_pixel_values"] is not None else None
305
+ num_patches = image_input["num_patches"]
306
+
307
+ image_features = self._process_image_features(image_features)
308
+ patch_image_features = self._process_image_features(
309
+ patch_image_features) if patch_image_features is not None else None
310
+
311
+ merged_image_features = []
312
+ cur_patch_idx = 0
313
+ for i, num_patch in enumerate(num_patches):
314
+ cur_feature = []
315
+ if num_patch > 0:
316
+ patch_slice = patch_image_features[
317
+ cur_patch_idx:cur_patch_idx + num_patch]
318
+ cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
319
+ cur_feature.append(image_features[i].view(
320
+ -1, image_features.shape[-1]))
321
+ cur_patch_idx += num_patch
322
+ merged_image_features.append(
323
+ torch.cat(cur_feature) if len(cur_feature) >
324
+ 1 else cur_feature[0])
325
+
326
+ return merged_image_features
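+
+ # Bookkeeping sketch: per image, crop features come first (81 projected
+ # vectors per 504px crop after downsampling), then 169 vectors for the full
+ # 728px view, matching the <patch_*>/<im_*> layout built by the processor.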
327
+
328
+ def get_multimodal_embeddings(self, **kwargs):
329
+ image_input = self._parse_and_validate_image_input(**kwargs)
330
+ if image_input is None:
331
+ return None
332
+ vision_embeddings = self._process_image_input(image_input)
333
+ return vision_embeddings
334
+
335
+ @can_return_tuple
336
+ def forward(
337
+ self,
338
+ input_ids: torch.LongTensor = None,
339
+ attention_mask: Optional[torch.Tensor] = None,
340
+ position_ids: Optional[torch.LongTensor] = None,
341
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
342
+ inputs_embeds: Optional[torch.FloatTensor] = None,
343
+ labels: Optional[torch.LongTensor] = None,
344
+ use_cache: Optional[bool] = None,
345
+ output_attentions: Optional[bool] = None,
346
+ output_hidden_states: Optional[bool] = None,
347
+ return_dict: Optional[bool] = None,
348
+ cache_position: Optional[torch.LongTensor] = None,
349
+ logits_to_keep: Union[int, torch.Tensor] = 0,
350
+ images: Optional[list[Image.Image]] = None,
351
+ **kwargs: Unpack[TransformersKwargs],
352
+ ) -> Union[tuple, StepVLCausalLMOutputWithPast]:
353
+ r"""
354
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
355
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
356
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
357
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
358
+ Example:
359
+ ```python
360
+ >>> from transformers import AutoTokenizer, Llama4ForCausalLM
361
+ >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
362
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
363
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
364
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
365
+ >>> # Generate
366
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
367
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
368
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
369
+ ```"""
370
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
371
+ output_hidden_states = (
372
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
373
+ )
374
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
375
+
376
+ if inputs_embeds is None:
378
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
379
+ inputs_embeds = self.get_input_embeddings(input_ids,
380
+ vision_embeddings)
381
+ input_ids = None
382
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
383
+ outputs = self.language_model(
384
+ input_ids=None,
385
+ position_ids=position_ids,
386
+ attention_mask=attention_mask,
387
+ past_key_values=past_key_values,
388
+ inputs_embeds=inputs_embeds,
389
+ use_cache=use_cache,
390
+ output_attentions=output_attentions,
391
+ output_hidden_states=output_hidden_states,
392
+ return_dict=True,
393
+ cache_position=cache_position,
394
+ **kwargs,
395
+ )
396
+
397
+ output = StepVLCausalLMOutputWithPast(
398
+ last_hidden_state=outputs.last_hidden_state,
399
+ past_key_values=outputs.past_key_values,
400
+ attentions=outputs.attentions,
401
+ hidden_states=outputs.hidden_states,
+ )
403
+ return output if return_dict else output.to_tuple()
404
+
405
+
406
+
407
+ class Step3VL10BForCausalLM(StepRoboticsPreTrainedModel, GenerationMixin):
408
+ _checkpoint_conversion_mapping = {
409
+ "^vision_model": "model.vision_model",
410
+ r"^model(?!\.(language_model|vision_model))": "model.language_model",
411
+ "^vit_large_projector": "model.vit_large_projector"
412
+ }
413
+ _tied_weights_keys = ["lm_head.weight"]
414
+ config: StepRoboticsConfig
415
+
416
+ def __init__(self, config: StepRoboticsConfig):
417
+ super().__init__(config)
418
+ self.model = StepRoboticsModel(config)
419
+ self.lm_head = nn.Linear(config.hidden_size, config.text_config.vocab_size, bias=False)
420
+
421
+ self.post_init()
422
+
423
+ def get_input_embeddings(self):
424
+ return self.model.language_model.embed_tokens
425
+
426
+ def set_input_embeddings(self, value):
427
+ self.model.set_input_embeddings(value)
428
+
429
+ def get_output_embeddings(self):
430
+ return self.lm_head
431
+
432
+ def set_output_embeddings(self, new_embeddings):
433
+ self.lm_head = new_embeddings
434
+
435
+ def set_decoder(self, decoder):
436
+ self.model.set_decoder(decoder)
437
+
438
+ def get_decoder(self):
439
+ return self.model.get_decoder()
440
+
441
+ @property
442
+ def language_model(self):
443
+ return self.model.language_model
444
+
445
+ @property
446
+ def visual(self):
447
+ return self.model.vision_model
448
+
449
+ def forward(
450
+ self,
451
+ input_ids: torch.LongTensor = None,
452
+ num_patches = None,
453
+ patch_pixel_values = None,
454
+ patch_newline_mask = None,
455
+ attention_mask: Optional[torch.Tensor] = None,
456
+ position_ids: Optional[torch.LongTensor] = None,
457
+ past_key_values: Optional[Cache] = None,
458
+ inputs_embeds: Optional[torch.FloatTensor] = None,
459
+ labels: Optional[torch.LongTensor] = None,
460
+ use_cache: Optional[bool] = None,
461
+ output_attentions: Optional[bool] = None,
462
+ output_hidden_states: Optional[bool] = None,
463
+ return_dict: Optional[bool] = None,
464
+ cache_position: Optional[torch.LongTensor] = None,
465
+ **kwargs: Unpack[TransformersKwargs],
466
+ ) -> Union[tuple, StepVLCausalLMOutputWithPast]:
467
+ r"""
468
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
469
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
470
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
471
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
472
+ Example:
+ ```python
+ >>> # Illustrative sketch; the processor defines the exact chat template,
+ >>> # and one <im_patch> marker is expanded per image.
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModelForCausalLM
+ >>> model = AutoModelForCausalLM.from_pretrained("stepfun-ai/Step3-VL-10B", trust_remote_code=True)
+ >>> processor = AutoProcessor.from_pretrained("stepfun-ai/Step3-VL-10B", trust_remote_code=True)
+ >>> prompt = "USER: <im_patch>\nWhat's the content of the image? ASSISTANT:"
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+ >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
+ ```"""
488
+
489
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
490
+ output_hidden_states = (
491
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
492
+ )
493
+
494
+ outputs = self.model(
495
+ input_ids=input_ids,
496
+ num_patches = num_patches,
497
+ patch_pixel_values = patch_pixel_values,
498
+ patch_newline_mask=patch_newline_mask,
499
+ position_ids=position_ids,
500
+ attention_mask=attention_mask,
501
+ past_key_values=past_key_values,
502
+ inputs_embeds=inputs_embeds,
503
+ use_cache=use_cache,
504
+ output_attentions=output_attentions,
505
+ output_hidden_states=output_hidden_states,
506
+ return_dict=return_dict,
507
+ cache_position=cache_position,
508
+ **kwargs,
509
+ )
510
+
511
+ hidden_states = outputs.last_hidden_state
512
+ logits = self.lm_head(hidden_states)
513
+
514
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return StepVLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ )
521
+
522
+ def prepare_inputs_for_generation(
523
+ self,
524
+ input_ids,
525
+ past_key_values=None,
526
+ inputs_embeds=None,
527
+ pixel_values=None,
528
+ attention_mask=None,
529
+ cache_position=None,
530
+ logits_to_keep=None,
531
+ **kwargs,
532
+ ):
533
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
534
+
535
+ model_inputs = super().prepare_inputs_for_generation(
536
+ input_ids,
537
+ past_key_values=past_key_values,
538
+ inputs_embeds=inputs_embeds,
539
+ attention_mask=attention_mask,
540
+ cache_position=cache_position,
541
+ logits_to_keep=logits_to_keep,
542
+ **kwargs,
543
+ )
544
+
545
+ if cache_position[0] == 0:
546
+ # Pixel values are only forwarded on the first (prefill) step, i.e. when
+ # cache_position starts at 0; during cached decoding the image features are
+ # already merged into the cache, so they are intentionally left out.
548
+ model_inputs["pixel_values"] = pixel_values
549
+
550
+ return model_inputs
551
+
552
+ def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
553
+ if key.startswith("language_model."):
554
+ return key[len("language_model."):], True
555
+
556
+ return key, False
557
+
processing_step3.py ADDED
@@ -0,0 +1,464 @@
1
+ from transformers import BaseImageProcessor, ImageProcessingMixin
2
+ from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
3
+ import math
4
+ from typing import Iterable, Optional, Tuple, List, TypedDict, Literal, Union, overload
5
+
6
+ from PIL import Image
7
+ import torch
8
+ import numpy as np
9
+ import torchvision
10
+ from torch import nn
11
+ from torch.nn import functional as F, LayerNorm
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from transformers.activations import ACT2FN
14
+ from torchvision import transforms
16
+ from transformers.feature_extraction_utils import BatchFeature, TensorType
17
+ from transformers.image_utils import ImageInput
19
+ from math import ceil
20
+ from itertools import product
21
+
22
+
23
+
24
+ MAX_IMAGE_SIZE: int = 3024
25
+
26
+ class Step3VLImagePixelInputs(TypedDict):
27
+ type: Literal["pixel_values"]
28
+ pixel_values: torch.Tensor
29
+ patch_pixel_values: Optional[torch.Tensor]
30
+ num_patches: list[int]
31
+
32
+
33
+ class Step3VLImageEmbeddingInputs(TypedDict):
34
+ type: Literal["image_embeds"]
35
+ image_embeds: torch.Tensor
36
+
37
+
38
+ ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]
39
+
40
+
41
+ class GPUToTensor(torch.nn.Module):
42
+
43
+ def forward(self, raw_image: Union[np.ndarray,
44
+ Image.Image]) -> torch.Tensor:
45
+ if isinstance(raw_image, Image.Image):
46
+ return transforms.ToTensor()(raw_image)
47
+ if raw_image.ndim == 2:
48
+ raw_image = raw_image[:, :, None].repeat(3, -1)
49
+ if torch.cuda.is_available():
50
+ device = torch.device("cuda")
51
+ else:
52
+ device = torch.device("cpu")
53
+ image_tensor = torch.from_numpy(raw_image).to(device)
54
+ image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
55
+ if image_tensor.dtype == torch.uint8:
56
+ image_tensor = image_tensor.to(torch.float32).div(255)
57
+ return image_tensor
58
+
59
+ class Step3VisionProcessor(BaseImageProcessor):
60
+
61
+ def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
+ super().__init__()
62
+ mean = [0.48145466, 0.4578275, 0.40821073]
63
+ std = [0.26862954, 0.26130258, 0.27577711]
64
+ patch_size = patch_size if patch_size is not None else size
65
+
66
+ self.transform = transforms.Compose([
67
+ GPUToTensor(),
68
+ transforms.Normalize(mean, std),
69
+ transforms.Resize(
70
+ (size, size),
71
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
72
+ == "bicubic" else InterpolationMode.BILINEAR,
73
+ antialias=True),
74
+ ])
75
+
76
+ self.patch_transform = transforms.Compose([
77
+ GPUToTensor(),
78
+ transforms.Normalize(mean, std),
79
+ transforms.Resize(
80
+ (patch_size, patch_size),
81
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
82
+ == "bicubic" else InterpolationMode.BILINEAR,
83
+ antialias=True),
84
+ ])  # patch_size is never None here (it defaults to `size` above)
85
+
86
+ def __call__(self, image, is_patch=False):
87
+ if is_patch:
88
+ return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
89
+ else:
90
+ return {"pixel_values": self.transform(image).unsqueeze(0)}
91
+
92
+ class ImagePatcher:
93
+ def determine_window_size(self, long: int, short: int) -> int:
94
+ if long <= 728:
95
+ return short if long / short > 1.5 else 0
96
+ return min(short, 504) if long / short > 4 else 504
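+
+ # e.g. a 1400x600 image (ratio ~2.3, long side > 728) gets 504px windows,
+ # while a 700x600 image (long side <= 728, ratio ~1.2) is not cropped at all.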
97
+ def slide_window(
98
+ self,
99
+ width: int,
100
+ height: int,
101
+ sizes: list[tuple[int, int]],
102
+ steps: list[tuple[int, int]],
103
+ img_rate_thr: float = 0.6,
104
+ ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
105
+ assert 1 >= img_rate_thr >= 0, "The `img_rate_thr` should lie in 0~1"
106
+ windows = []
107
+ # Sliding windows.
108
+ for size, step in zip(sizes, steps):
109
+ size_w, size_h = size
110
+ step_w, step_h = step
111
+
112
+ x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
113
+ 1)
114
+ x_start = [step_w * i for i in range(x_num)]
115
+ if len(x_start) > 1 and x_start[-1] + size_w > width:
116
+ x_start[-1] = width - size_w
117
+
118
+ y_num = 1 if height <= size_h else ceil((height - size_h) /
119
+ step_h + 1)
120
+ y_start = [step_h * i for i in range(y_num)]
121
+ if len(y_start) > 1 and y_start[-1] + size_h > height:
122
+ y_start[-1] = height - size_h
123
+
124
+ start = np.array(list(product(y_start, x_start)), dtype=int)
125
+ start[:, [0, 1]] = start[:, [1, 0]]
126
+ windows.append(np.concatenate([start, start + size], axis=1))
127
+ windows = np.concatenate(windows, axis=0)
128
+
129
+ return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
130
+ int(box[3] - box[1])) for box in windows], (x_num, y_num)
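+
+ # e.g. a 1008x504 canvas with a single 504px window size and step yields
+ # x_num=2, y_num=1 and (x, y, w, h) boxes (0, 0, 504, 504) and (504, 0, 504, 504).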
131
+
132
+ def square_pad(self, img: Image.Image) -> Image.Image:
133
+ w, h = img.size
134
+ if w == h:
135
+ return img
136
+ size = max(w, h)
137
+ padded = Image.new(img.mode, (size, size), 0)
138
+ padded.paste(img, (0, 0))
139
+ return padded
140
+
141
+ def get_image_size_for_padding(self, img_width: int,
142
+ img_height: int) -> tuple[int, int]:
143
+ ratio = img_width / img_height
144
+ if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
145
+ new_size = max(img_height, img_width)
146
+ return new_size, new_size
147
+ return img_width, img_height
148
+
149
+ def get_image_size_for_preprocess(self, img_width: int,
150
+ img_height: int) -> tuple[int, int]:
151
+
152
+ if max(img_height, img_width) > MAX_IMAGE_SIZE:
153
+ scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
154
+ img_width = int(img_width * scale_factor)
155
+ img_height = int(img_height * scale_factor)
156
+ return img_width, img_height
157
+
158
+ def get_image_size_for_crop(self, img_width: int, img_height: int,
159
+ window_size: int):
160
+ w_ratio = img_width / window_size
161
+ h_ratio = img_height / window_size
162
+
163
+ if w_ratio < 1:
164
+ width_new = img_width
165
+ else:
166
+ decimal_w = w_ratio - img_width // window_size
167
+ w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
168
+ width_new = window_size * w_ratio
169
+ if h_ratio < 1:
170
+ height_new = img_height
171
+ else:
172
+ decimal_h = h_ratio - img_height // window_size
173
+ h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
174
+ height_new = window_size * h_ratio
175
+ return int(width_new), int(height_new)
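+
+ # e.g. a 1100px side against a 504px window: ratio 2.18, fractional part
+ # 0.18 <= 0.2, so the side snaps down to 2 * 504 = 1008px before cropping.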
176
+
177
+ def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
178
+ target = img.crop((j, i, j + tw, i + th))
179
+ return target
180
+
181
+ def get_num_patches(self, img_width: int,
182
+ img_height: int) -> tuple[int, int]:
183
+ img_width, img_height = self.get_image_size_for_padding(
184
+ img_width, img_height)
185
+ img_width, img_height = self.get_image_size_for_preprocess(
186
+ img_width, img_height)
187
+ window_size = self.determine_window_size(max(img_height, img_width),
188
+ min(img_height, img_width))
189
+ if window_size == 0:
190
+ return 0, 0
191
+ else:
192
+ img_width, img_height = self.get_image_size_for_crop(
193
+ img_width, img_height, window_size)
194
+ center_list, (x_num, y_num) = self.slide_window(
195
+ img_width, img_height, [(window_size, window_size)],
196
+ [(window_size, window_size)])
197
+ full_rows = (len(center_list) - 1) // x_num + 1
198
+ if len(center_list) > 0 and len(center_list) % x_num == 0:
199
+ full_rows -= 1
200
+ return len(center_list), full_rows
201
+
202
+ def __call__(
203
+ self, img: Image.Image
204
+ ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
205
+ img_width, img_height = img.size
206
+ new_img_width, new_img_height = self.get_image_size_for_padding(
207
+ img_width, img_height)
208
+ if new_img_width != img_width or new_img_height != img_height:
209
+ img = self.square_pad(img)
210
+ img_width, img_height = img.size
211
+
212
+ new_img_width, new_img_height = self.get_image_size_for_preprocess(
213
+ img_width, img_height)
214
+ img = img.resize((new_img_width, new_img_height),
215
+ Image.Resampling.BILINEAR)
216
+ window_size = self.determine_window_size(
217
+ max(new_img_height, new_img_width),
218
+ min(new_img_height, new_img_width))
219
+
220
+ if window_size == 0:
221
+ return img, [], None
222
+ else:
223
+ new_img_width, new_img_height = self.get_image_size_for_crop(
224
+ new_img_width, new_img_height, window_size)
225
+ if (new_img_width, new_img_height) != (img_width, img_height):
226
+ img_for_crop = img.resize((new_img_width, new_img_height),
227
+ Image.Resampling.BILINEAR)
228
+ else:
229
+ img_for_crop = img
230
+
231
+ patches = []
232
+ newlines = []
233
+ center_list, (x_num, y_num) = self.slide_window(
234
+ new_img_width, new_img_height, [(window_size, window_size)],
235
+ [(window_size, window_size)])
236
+ for patch_id, center_lf_point in enumerate(center_list):
237
+ x, y, patch_w, patch_h = center_lf_point
238
+ big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
239
+ patch_w)
240
+ patches.append(big_patch)
241
+ if (patch_id + 1) % x_num == 0:
242
+ newlines.append(patch_id)
243
+
244
+ if newlines and newlines[-1] == len(patches) - 1:
245
+ newlines.pop()
246
+
247
+ return img, patches, [i in newlines for i in range(len(patches))] if len(patches) > 0 else None
248
+
249
+
250
+
251
+
252
+ class Step3VLProcessor(ProcessorMixin):
253
+ # Align ProcessorMixin with our custom components.
254
+ # We only have an image processor (not a feature extractor) plus a tokenizer.
255
+ attributes = ["tokenizer"]
256
+ tokenizer_class = "AutoTokenizer"
257
+
258
+ def __init__(
259
+ self,
260
+ tokenizer=None,
261
+ chat_template=None,
262
+ **kwargs
263
+ ) -> None:
264
+ self.image_size = 728
265
+ self.patch_size = 504
266
+
267
+ self.image_preprocessor = Step3VisionProcessor(self.image_size,
268
+ "bilinear",
269
+ self.patch_size)
270
+
271
+ self.num_image_feature_size = 169
272
+ self.num_patch_feature_size = 81
273
+ self.image_token = "<im_patch>"
274
+ self.image_feature_placeholder = (self.image_token *
275
+ self.num_image_feature_size)
276
+ self.patch_feature_placeholder = (self.image_token *
277
+ self.num_patch_feature_size)
278
+ super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
279
+ self.patcher = ImagePatcher()
280
+
281
+ @property
282
+ def image_token_id(self) -> int:
283
+ return self.tokenizer.get_vocab()[self.image_token]
284
+
285
+ def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
286
+ num_patches, num_newlines = self.patcher.get_num_patches(
287
+ img_width, img_height)
288
+
289
+ return num_patches * (
290
+ self.num_patch_feature_size +
291
+ 2) + self.num_image_feature_size + 2 + num_newlines
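+
+ # e.g. an uncropped image costs 169 + 2 = 171 tokens (<im_start>/<im_end>);
+ # each crop adds 81 + 2 (<patch_start>/<patch_end>) plus any <patch_newline>.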
292
+
293
+ def _split_images(self,
294
+ images: list[Image.Image]) -> list[ImageWithPatches]:
295
+ result = []
296
+ for img in images:
297
+ result.append(self.patcher(img))
298
+ return result
299
+
300
+ def _convert_images_to_pixel_values(
301
+ self,
302
+ images: list[Image.Image],
303
+ is_patch: bool = False,
304
+ ) -> list[torch.Tensor]:
305
+ return [
306
+ self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
307
+ for img in images
308
+ ]
309
+
310
+ def _get_patch_repl(
311
+ self,
312
+ num_patches: int,
313
+ patch_newline_mask: list[bool] | None,
314
+ ) -> tuple[str, list[int]]:
315
+ text = ""
316
+ token_ids = []
317
+ for i in range(num_patches):
318
+ assert len(patch_newline_mask) == num_patches
319
+ text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
320
+ token_ids.extend(
321
+ [self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
322
+ [self.image_token_id] * self.num_patch_feature_size +
323
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")])
324
+ if patch_newline_mask and patch_newline_mask[i]:
325
+ text += "<patch_newline>"
326
+ token_ids.append(
327
+ self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
328
+ return text, token_ids
329
+
330
+ def _get_image_repl(
331
+ self,
332
+ num_images: int,
333
+ ) -> tuple[str, list[int]]:
334
+ text = f"<im_start>{self.image_feature_placeholder}<im_end>"
335
+ token_ids = [
336
+ self.tokenizer.convert_tokens_to_ids("<im_start>")
337
+ ] + [self.image_token_id] * self.num_image_feature_size + [
338
+ self.tokenizer.convert_tokens_to_ids("<im_end>")
339
+ ]
340
+ return text * num_images, token_ids * num_images
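+
+ # For a single image this expands to <im_start> + 169 x <im_patch> + <im_end>,
+ # one placeholder per projected vision feature of the full-image view.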
341
+
342
+ def _get_image_repl_features(
343
+ self,
344
+ num_images: int,
345
+ num_patches: int,
346
+ patch_new_line_idx: Optional[list[bool]],
347
+ ) -> tuple[str, list[int]]:
348
+ if num_patches > 0:
349
+ patch_repl, patch_repl_ids = self._get_patch_repl(
350
+ num_patches, patch_new_line_idx)
351
+ else:
352
+ patch_repl = ""
353
+ patch_repl_ids = []
354
+ image_repl, image_repl_ids = self._get_image_repl(num_images)
355
+ return patch_repl + image_repl, patch_repl_ids + image_repl_ids
356
+
357
+ def replace_placeholder(self, text: str, placeholder: str,
358
+ repls: list[str]) -> str:
359
+ parts = text.split(placeholder)
360
+
361
+ if len(parts) - 1 != len(repls):
362
+ raise ValueError(
363
+ "The number of placeholders does not match the number of replacements." # noqa: E501
364
+ )
365
+
366
+ result = [parts[0]]
367
+ for i, repl in enumerate(repls):
368
+ result.append(repl)
369
+ result.append(parts[i + 1])
370
+
371
+ return "".join(result)
372
+
373
+ def __call__(
374
+ self,
375
+ text: Optional[Union[str, list[str]]] = None,
376
+ images: ImageInput | None = None,
377
+ return_tensors: Optional[Union[str, TensorType]] = None,
378
+ **kwargs,
379
+ ) -> BatchFeature:
380
+
381
+ if images is not None:
382
+ # NOTE: Step3VisionProcessor defines no `fetch_images`; guard the call so
+ # plain PIL.Image inputs (handled below) do not raise AttributeError.
+ if hasattr(self.image_preprocessor, "fetch_images"):
+ images = self.image_preprocessor.fetch_images(images)
383
+ if text is None:
384
+ text = []
385
+ if not isinstance(text, list):
386
+ text = [text]
387
+ if images is None:
388
+ images = []
389
+ elif not isinstance(images, list):
390
+ images = [images]
391
+ elif isinstance(images[0], list):
392
+ images = images[0]
393
+
394
+ if len(images) == 0:
395
+ image_inputs = {}
396
+ text_inputs = self.tokenizer(text)
397
+ else:
398
+ splitted_images_data = self._split_images(images)
399
+ pixel_values_lst = []
400
+ patch_pixel_values_lst = []
401
+ patch_newline_mask_lst = []
402
+ image_repl_str_lst = []
403
+ image_repl_ids_lst = []
404
+ num_patches = []
405
+ for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
406
+ pixel_values_lst.extend(
407
+ self._convert_images_to_pixel_values([raw_img]))
408
+
409
+ if len(img_patches) > 0:
410
+ patch_pixel_values_lst.extend(
411
+ self._convert_images_to_pixel_values(img_patches,
412
+ is_patch=True))
413
+ num_patches.append(len(img_patches))
414
+
415
+ image_repl_str, image_repl_ids = self._get_image_repl_features(
416
+ 1, len(img_patches), patch_newline_mask)
417
+ image_repl_str_lst.append(image_repl_str)
418
+ image_repl_ids_lst.extend(image_repl_ids)
419
+
420
+ if patch_newline_mask is not None:
421
+ patch_newline_mask_lst.extend(patch_newline_mask)
422
+
423
+ image_inputs = {
424
+ "pixel_values": torch.cat(pixel_values_lst),
425
+ "num_patches": num_patches,
426
+ }
427
+ if patch_pixel_values_lst:
428
+ image_inputs["patch_pixel_values"] = torch.cat(
429
+ patch_pixel_values_lst)
430
+ if patch_newline_mask_lst:
431
+ image_inputs["patch_newline_mask"] = torch.tensor(
432
+ patch_newline_mask_lst, dtype=torch.bool)
433
+
434
+ text = [
435
+ self.replace_placeholder(t, self.image_token,
436
+ image_repl_str_lst) for t in text
437
+ ]
438
+ text_inputs = self.tokenizer(text)
439
+
440
+ return BatchFeature(
441
+ {
442
+ **text_inputs,
443
+ **image_inputs,
444
+ },
445
+ tensor_type=return_tensors,
446
+ )
447
+
448
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode
449
+ def batch_decode(self, *args, **kwargs):
450
+ """
451
+ This method forwards all its arguments to the underlying tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
452
+ refer to the docstring of this method for more information.
453
+ """
454
+ return self.tokenizer.batch_decode(*args, **kwargs)
455
+
456
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode
457
+ def decode(self, *args, **kwargs):
458
+ """
459
+ This method forwards all its arguments to the underlying tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
460
+ the docstring of this method for more information.
461
+ """
462
+ return self.tokenizer.decode(*args, **kwargs)
463
+
464
+ __all__ = ["Step3VLProcessor"]
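
A minimal end-to-end sketch of the processor added above (the repo id comes from the links above; the plain-text prompt below is illustrative, not the official chat template):

```python
from PIL import Image
from transformers import AutoProcessor

repo = "stepfun-ai/Step3-VL-10B"
# trust_remote_code pulls in processing_step3.Step3VLProcessor via processor_config.json.
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)

image = Image.new("RGB", (1600, 900))
# One <im_patch> marker per image; the processor expands it into the full
# <im_start>/<patch_start>... token layout shown above.
inputs = processor(text="<im_patch>\nDescribe this image.", images=image, return_tensors="pt")
print(processor.get_num_image_tokens(1600, 900))  # total image token budget
```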
processor_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_step3.Step3VLProcessor"
4
+ }
5
+ }
6
+
special_tokens_map.json ADDED
@@ -0,0 +1,270 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<|im_start|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ {
11
+ "content": "<|im_end|>",
12
+ "lstrip": false,
13
+ "normalized": false,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ },
17
+ {
18
+ "content": "<|object_ref_start|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ {
25
+ "content": "<|object_ref_end|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ {
32
+ "content": "<|box_start|>",
33
+ "lstrip": false,
34
+ "normalized": false,
35
+ "rstrip": false,
36
+ "single_word": false
37
+ },
38
+ {
39
+ "content": "<|box_end|>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false
44
+ },
45
+ {
46
+ "content": "<|quad_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false
51
+ },
52
+ {
53
+ "content": "<|quad_end|>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false
58
+ },
59
+ {
60
+ "content": "<|vision_start|>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false
65
+ },
66
+ {
67
+ "content": "<|vision_end|>",
68
+ "lstrip": false,
69
+ "normalized": false,
70
+ "rstrip": false,
71
+ "single_word": false
72
+ },
73
+ {
74
+ "content": "<|vision_pad|>",
75
+ "lstrip": false,
76
+ "normalized": false,
77
+ "rstrip": false,
78
+ "single_word": false
79
+ },
80
+ {
81
+ "content": "<|image_pad|>",
82
+ "lstrip": false,
83
+ "normalized": false,
84
+ "rstrip": false,
85
+ "single_word": false
86
+ },
87
+ {
88
+ "content": "<|video_pad|>",
89
+ "lstrip": false,
90
+ "normalized": false,
91
+ "rstrip": false,
92
+ "single_word": false
93
+ },
94
+ {
95
+ "content": "<tool_calls>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": false,
99
+ "single_word": false
100
+ },
101
+ {
102
+ "content": "</tool_calls>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false
107
+ },
108
+ {
109
+ "content": "<|EOT|>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false
114
+ },
115
+ {
116
+ "content": "<|BOT|>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false
121
+ },
122
+ {
123
+ "content": "<|CALL_START|>",
124
+ "lstrip": false,
125
+ "normalized": false,
126
+ "rstrip": false,
127
+ "single_word": false
128
+ },
129
+ {
130
+ "content": "<|CALL_END|>",
131
+ "lstrip": false,
132
+ "normalized": false,
133
+ "rstrip": false,
134
+ "single_word": false
135
+ },
136
+ {
137
+ "content": "<|THINK_START|>",
138
+ "lstrip": false,
139
+ "normalized": false,
140
+ "rstrip": false,
141
+ "single_word": false
142
+ },
143
+ {
144
+ "content": "<|THINK_END|>",
145
+ "lstrip": false,
146
+ "normalized": false,
147
+ "rstrip": false,
148
+ "single_word": false
149
+ },
150
+ {
151
+ "content": "<|IMG_START|>",
152
+ "lstrip": false,
153
+ "normalized": false,
154
+ "rstrip": false,
155
+ "single_word": false
156
+ },
157
+ {
158
+ "content": "<|IMG_END|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false
163
+ },
164
+ {
165
+ "content": "<im_patch>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false
170
+ },
171
+ {
172
+ "content": "<im_start>",
173
+ "lstrip": false,
174
+ "normalized": false,
175
+ "rstrip": false,
176
+ "single_word": false
177
+ },
178
+ {
179
+ "content": "<im_end>",
180
+ "lstrip": false,
181
+ "normalized": false,
182
+ "rstrip": false,
183
+ "single_word": false
184
+ },
185
+ {
186
+ "content": "<dream>",
187
+ "lstrip": false,
188
+ "normalized": false,
189
+ "rstrip": false,
190
+ "single_word": false
191
+ },
192
+ {
193
+ "content": "<dream_start>",
194
+ "lstrip": false,
195
+ "normalized": false,
196
+ "rstrip": false,
197
+ "single_word": false
198
+ },
199
+ {
200
+ "content": "<dream_end>",
201
+ "lstrip": false,
202
+ "normalized": false,
203
+ "rstrip": false,
204
+ "single_word": false
205
+ },
206
+ {
207
+ "content": "<|MASK_1e69f|>",
208
+ "lstrip": false,
209
+ "normalized": false,
210
+ "rstrip": false,
211
+ "single_word": false
212
+ },
213
+ {
214
+ "content": "<|UNMASK_1e69f|>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false
219
+ },
220
+ {
221
+ "content": "<video_start>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false
226
+ },
227
+ {
228
+ "content": "<video_end>",
229
+ "lstrip": false,
230
+ "normalized": false,
231
+ "rstrip": false,
232
+ "single_word": false
233
+ },
234
+ {
235
+ "content": "<patch_start>",
236
+ "lstrip": false,
237
+ "normalized": false,
238
+ "rstrip": false,
239
+ "single_word": false
240
+ },
241
+ {
242
+ "content": "<patch_end>",
243
+ "lstrip": false,
244
+ "normalized": false,
245
+ "rstrip": false,
246
+ "single_word": false
247
+ },
248
+ {
249
+ "content": "<patch_newline>",
250
+ "lstrip": false,
251
+ "normalized": false,
252
+ "rstrip": false,
253
+ "single_word": false
254
+ }
255
+ ],
256
+ "eos_token": {
257
+ "content": "<|im_end|>",
258
+ "lstrip": false,
259
+ "normalized": false,
260
+ "rstrip": false,
261
+ "single_word": false
262
+ },
263
+ "pad_token": {
264
+ "content": "<|endoftext|>",
265
+ "lstrip": false,
266
+ "normalized": false,
267
+ "rstrip": false,
268
+ "single_word": false
269
+ }
270
+ }
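
A quick sanity check (illustrative) that the map above is picked up by the tokenizer: `<|im_end|>` terminates generation and `<|endoftext|>` pads:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("stepfun-ai/Step3-VL-10B")
assert tok.eos_token == "<|im_end|>"
assert tok.pad_token == "<|endoftext|>"
assert "<patch_newline>" in tok.additional_special_tokens
```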
tokenizer_config.json ADDED
@@ -0,0 +1,446 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ },
213
+ "151669": {
214
+ "content": "<tool_calls>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": true
220
+ },
221
+ "151670": {
222
+ "content": "</tool_calls>",
223
+ "lstrip": false,
224
+ "normalized": false,
225
+ "rstrip": false,
226
+ "single_word": false,
227
+ "special": true
228
+ },
229
+ "151671": {
230
+ "content": "<|EOT|>",
231
+ "lstrip": false,
232
+ "normalized": false,
233
+ "rstrip": false,
234
+ "single_word": false,
235
+ "special": true
236
+ },
237
+ "151672": {
238
+ "content": "<|BOT|>",
239
+ "lstrip": false,
240
+ "normalized": false,
241
+ "rstrip": false,
242
+ "single_word": false,
243
+ "special": true
244
+ },
245
+ "151673": {
246
+ "content": "<|CALL_START|>",
247
+ "lstrip": false,
248
+ "normalized": false,
249
+ "rstrip": false,
250
+ "single_word": false,
251
+ "special": true
252
+ },
253
+ "151674": {
254
+ "content": "<|CALL_END|>",
255
+ "lstrip": false,
256
+ "normalized": false,
257
+ "rstrip": false,
258
+ "single_word": false,
259
+ "special": true
260
+ },
261
+ "151675": {
262
+ "content": "<|THINK_START|>",
263
+ "lstrip": false,
264
+ "normalized": false,
265
+ "rstrip": false,
266
+ "single_word": false,
267
+ "special": true
268
+ },
269
+ "151676": {
270
+ "content": "<|THINK_END|>",
271
+ "lstrip": false,
272
+ "normalized": false,
273
+ "rstrip": false,
274
+ "single_word": false,
275
+ "special": true
276
+ },
277
+ "151677": {
278
+ "content": "<|IMG_START|>",
279
+ "lstrip": false,
280
+ "normalized": false,
281
+ "rstrip": false,
282
+ "single_word": false,
283
+ "special": true
284
+ },
285
+ "151678": {
286
+ "content": "<|IMG_END|>",
287
+ "lstrip": false,
288
+ "normalized": false,
289
+ "rstrip": false,
290
+ "single_word": false,
291
+ "special": true
292
+ },
293
+ "151679": {
294
+ "content": "<im_patch>",
295
+ "lstrip": false,
296
+ "normalized": false,
297
+ "rstrip": false,
298
+ "single_word": false,
299
+ "special": true
300
+ },
301
+ "151680": {
302
+ "content": "<im_start>",
303
+ "lstrip": false,
304
+ "normalized": false,
305
+ "rstrip": false,
306
+ "single_word": false,
307
+ "special": true
308
+ },
309
+ "151681": {
310
+ "content": "<im_end>",
311
+ "lstrip": false,
312
+ "normalized": false,
313
+ "rstrip": false,
314
+ "single_word": false,
315
+ "special": true
316
+ },
317
+ "151682": {
318
+ "content": "<dream>",
319
+ "lstrip": false,
320
+ "normalized": false,
321
+ "rstrip": false,
322
+ "single_word": false,
323
+ "special": true
324
+ },
325
+ "151683": {
326
+ "content": "<dream_start>",
327
+ "lstrip": false,
328
+ "normalized": false,
329
+ "rstrip": false,
330
+ "single_word": false,
331
+ "special": true
332
+ },
333
+ "151684": {
334
+ "content": "<dream_end>",
335
+ "lstrip": false,
336
+ "normalized": false,
337
+ "rstrip": false,
338
+ "single_word": false,
339
+ "special": true
340
+ },
341
+ "151685": {
342
+ "content": "<|MASK_1e69f|>",
343
+ "lstrip": false,
344
+ "normalized": false,
345
+ "rstrip": false,
346
+ "single_word": false,
347
+ "special": true
348
+ },
349
+ "151686": {
350
+ "content": "<|UNMASK_1e69f|>",
351
+ "lstrip": false,
352
+ "normalized": false,
353
+ "rstrip": false,
354
+ "single_word": false,
355
+ "special": true
356
+ },
357
+ "151687": {
358
+ "content": "<video_start>",
359
+ "lstrip": false,
360
+ "normalized": false,
361
+ "rstrip": false,
362
+ "single_word": false,
363
+ "special": true
364
+ },
365
+ "151688": {
366
+ "content": "<video_end>",
367
+ "lstrip": false,
368
+ "normalized": false,
369
+ "rstrip": false,
370
+ "single_word": false,
371
+ "special": true
372
+ },
373
+ "151689": {
374
+ "content": "<patch_start>",
375
+ "lstrip": false,
376
+ "normalized": false,
377
+ "rstrip": false,
378
+ "single_word": false,
379
+ "special": true
380
+ },
381
+ "151690": {
382
+ "content": "<patch_end>",
383
+ "lstrip": false,
384
+ "normalized": false,
385
+ "rstrip": false,
386
+ "single_word": false,
387
+ "special": true
388
+ },
389
+ "151691": {
390
+ "content": "<patch_newline>",
391
+ "lstrip": false,
392
+ "normalized": false,
393
+ "rstrip": false,
394
+ "single_word": false,
395
+ "special": true
396
+ }
397
+ },
398
+ "additional_special_tokens": [
399
+ "<|im_start|>",
400
+ "<|im_end|>",
401
+ "<|object_ref_start|>",
402
+ "<|object_ref_end|>",
403
+ "<|box_start|>",
404
+ "<|box_end|>",
405
+ "<|quad_start|>",
406
+ "<|quad_end|>",
407
+ "<|vision_start|>",
408
+ "<|vision_end|>",
409
+ "<|vision_pad|>",
410
+ "<|image_pad|>",
411
+ "<|video_pad|>",
412
+ "<tool_calls>",
413
+ "</tool_calls>",
414
+ "<|EOT|>",
415
+ "<|BOT|>",
416
+ "<|CALL_START|>",
417
+ "<|CALL_END|>",
418
+ "<|THINK_START|>",
419
+ "<|THINK_END|>",
420
+ "<|IMG_START|>",
421
+ "<|IMG_END|>",
422
+ "<im_patch>",
423
+ "<im_start>",
424
+ "<im_end>",
425
+ "<dream>",
426
+ "<dream_start>",
427
+ "<dream_end>",
428
+ "<|MASK_1e69f|>",
429
+ "<|UNMASK_1e69f|>",
430
+ "<video_start>",
431
+ "<video_end>",
432
+ "<patch_start>",
433
+ "<patch_end>",
434
+ "<patch_newline>"
435
+ ],
436
+ "bos_token": null,
437
+ "clean_up_tokenization_spaces": false,
438
+ "eos_token": "<|im_end|>",
439
+ "errors": "replace",
440
+ "extra_special_tokens": {},
441
+ "model_max_length": 131072,
442
+ "pad_token": "<|endoftext|>",
443
+ "split_special_tokens": false,
444
+ "tokenizer_class": "Qwen2Tokenizer",
445
+ "unk_token": null
446
+ }
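
`added_tokens_decoder` pins fixed vocabulary ids (151643-151691) for the chat, tool and vision tokens, so they round-trip without being split. A sketch of that behaviour, with ids read from the config above:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("stepfun-ai/Step3-VL-10B")
assert tok.convert_tokens_to_ids("<|im_end|>") == 151645
assert tok.convert_ids_to_tokens(151679) == "<im_patch>"
# Special tokens are kept atomic by the Qwen2Tokenizer:
assert tok.tokenize("<im_start><im_patch><im_end>") == [
    "<im_start>", "<im_patch>", "<im_end>"]
```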
vision_encoder.py ADDED
@@ -0,0 +1,451 @@
1
+ from typing import Literal, Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, repeat
7
+ from transformers.activations import ACT2FN
8
+
9
+ from configuration_step_vl import StepRoboticsVisionEncoderConfig
10
+
11
+
12
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
13
+ """Rotate last dimension halves (used by RoPE)."""
14
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
15
+ x1, x2 = x.unbind(dim=-1)
16
+ x = torch.stack((-x2, x1), dim=-1)
17
+ return rearrange(x, "... d r -> ... (d r)")
18
+
19
+
20
+ def apply_rotary_emb(freqs: torch.Tensor,
21
+ t: torch.Tensor,
22
+ start_index: int = 0,
23
+ scale: float = 1.0,
24
+ seq_dim: int = -2) -> torch.Tensor:
25
+ """Apply 2D rotary embeddings to queries / keys."""
26
+ dtype = t.dtype
27
+
28
+ if t.ndim == 3:
29
+ seq_len = t.shape[seq_dim]
30
+ freqs = freqs[-seq_len:]
31
+
32
+ rot_dim = freqs.shape[-1]
33
+ end_index = start_index + rot_dim
34
+ assert rot_dim <= t.shape[-1], (
35
+ f"feature dimension {t.shape[-1]} is too small for rot_dim {rot_dim}")
36
+
37
+ t_left, t, t_right = (
38
+ t[..., :start_index],
39
+ t[..., start_index:end_index],
40
+ t[..., end_index:],
41
+ )
42
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
43
+ out = torch.cat((t_left, t, t_right), dim=-1)
44
+ return out.type(dtype)
45
+
46
+
47
+ class EncoderRope2D(nn.Module):
48
+ """Cacheable 2D rotary positional embedding."""
49
+
50
+ def __init__(
51
+ self,
52
+ dim: int,
53
+ max_grid_height: int,
54
+ max_grid_width: int,
55
+ use_cls_token: bool = False,
56
+ theta: Union[int, float] = 10000,
57
+ max_freq: int = 10,
58
+ num_freqs: int = 1,
59
+ theta_rescale_factor: float = 1.0,
60
+ ):
61
+ super().__init__()
62
+ self.dim = dim
63
+ self.max_grid_height = max_grid_height
64
+ self.max_grid_width = max_grid_width
65
+ self.use_cls_token = use_cls_token
66
+ self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
67
+ self.max_freq = max_freq
68
+ self.num_freqs = num_freqs
69
+ cache = self._compute_2d_freqs()
70
+ self.register_buffer("freqs_cache", cache, persistent=False)
71
+
72
+ def _compute_inv_freq(self, base: Union[int, float],
73
+ dim: int) -> torch.Tensor:
74
+
75
+ freqs = 1.0 / (base**(
76
+ torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
77
+ return freqs
78
+
79
+ def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
80
+ freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype),
81
+ inv_freq)
82
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
83
+ return freqs
84
+
85
+ def _compute_2d_freqs(self) -> torch.Tensor:
86
+ grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
87
+ grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
88
+ if self.use_cls_token:
89
+ grid_h_range += 1
90
+ grid_w_range += 1
91
+ inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
92
+ freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
93
+ self.max_grid_height, self.max_grid_width, -1)
94
+ freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
95
+ self.max_grid_height, self.max_grid_width, -1)
96
+ freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
97
+ self.max_grid_height * self.max_grid_width, -1)
98
+ if self.use_cls_token:
99
+ freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
100
+ freqs = freqs[None, None, ...]
101
+ return freqs
102
+
103
+ def forward(self, q: torch.Tensor, k: torch.Tensor,
104
+ grid_hw: tuple[int, int]):
105
+ # Reuse the cached freqs when the grid matches; otherwise gather the needed positions.
106
+ if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width:
107
+ rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1)
108
+ cols = torch.arange(grid_hw[1], device=q.device).view(1, -1)
109
+ positions = (rows * self.max_grid_width + cols).reshape(-1).to(
110
+ torch.long)
111
+ if self.use_cls_token:
112
+ positions = torch.cat(
113
+ [torch.zeros(1, device=q.device, dtype=torch.long), positions + 1], dim=0)
114
+ freqs = self.freqs_cache.index_select(2, positions)
115
+ else:
116
+ freqs = self.freqs_cache
117
+ q = apply_rotary_emb(freqs, q)
118
+ k = apply_rotary_emb(freqs, k)
119
+ return q, k
120
+
121
+
122
+ class EncoderLayerScale(nn.Module):
123
+ """Per-channel residual scaling used when ls_init_value is set."""
124
+
125
+ def __init__(self, dim: int, init_values: float):
126
+ super().__init__()
127
+ self.gamma = nn.Parameter(torch.full((dim,), init_values))
128
+
129
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # (B, L, D)
130
+ return hidden_states * self.gamma
131
+
132
+
133
+ class EncoderMLP(nn.Module):
134
+ """Feed-forward network used inside each transformer block."""
135
+
136
+ def __init__(self, hidden_size: int, intermediate_size: int,
137
+ hidden_act: str):
138
+ super().__init__()
139
+ self.c_fc = nn.Linear(hidden_size, intermediate_size, bias=True)
140
+ self.act_fn = ACT2FN[hidden_act]
141
+ self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
142
+
143
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
144
+
145
+ hidden_states = self.c_proj(self.act_fn(self.c_fc(hidden_states)))
146
+ return hidden_states
147
+
148
+
149
+ class EncoderVisionAttention(nn.Module):
150
+ """Multi-head self attention with optional 2D RoPE."""
151
+
152
+ def __init__(
153
+ self,
154
+ hidden_size: int,
155
+ num_heads: int,
156
+ max_grid_height: int,
157
+ max_grid_width: int,
158
+ use_cls_token: bool = False,
159
+ use_rope2d: bool = True,
160
+ rope_theta: Union[int, float] = 10000,
161
+ rope_max_freq: int = 10,
162
+ rope_num_freqs: int = 1,
163
+ rope_theta_rescale_factor: float = 1.0,
164
+ rope_freqs_for: Literal["lang", "pixel", "constant"] = "lang",  # accepted for config parity; unused here
165
+ ):
166
+ super().__init__()
167
+ if hidden_size % num_heads != 0:
168
+ raise ValueError(
169
+ f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})."
170
+ )
171
+ self.num_heads = num_heads
172
+ self.head_dim = hidden_size // num_heads
173
+ self.scale = self.head_dim**-0.5
174
+ self.in_proj_weight = nn.Parameter(torch.zeros(hidden_size * 3, hidden_size))
175
+ self.in_proj_bias = nn.Parameter(torch.zeros(hidden_size * 3))
176
+ self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
177
+
178
+ self.rope = None
179
+ if use_rope2d:
180
+ self.rope = EncoderRope2D(
181
+ dim=self.head_dim,
182
+ max_grid_height=max_grid_height,
183
+ max_grid_width=max_grid_width,
184
+ use_cls_token=use_cls_token,
185
+ theta=rope_theta,
186
+ max_freq=rope_max_freq,
187
+ num_freqs=rope_num_freqs,
188
+ theta_rescale_factor=rope_theta_rescale_factor,
189
+ )
190
+
191
+ def forward(self, hidden_states: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
192
+ bsz, seq_len, _ = hidden_states.shape
193
+ qkv = F.linear(
194
+ hidden_states,
195
+ self.in_proj_weight,
196
+ self.in_proj_bias,
197
+ )
198
+ q, k, v = qkv.chunk(3, dim=-1)
199
+
200
+ q = q.view(bsz, seq_len, self.num_heads,
201
+ self.head_dim).transpose(1, 2)
202
+ k = k.view(bsz, seq_len, self.num_heads,
203
+ self.head_dim).transpose(1, 2)
204
+ if self.rope is not None:
205
+ q, k = self.rope(q, k, grid_hw=grid_hw)
206
+ v = v.view(bsz, seq_len, self.num_heads,
207
+ self.head_dim).transpose(1, 2)
208
+
209
+ attn_output = F.scaled_dot_product_attention(
210
+ q, k, v, is_causal=False, scale=self.scale)
211
+ attn_output = attn_output.transpose(1, 2).reshape(
212
+ bsz, seq_len, self.num_heads * self.head_dim)
213
+ return self.out_proj(attn_output)
214
+
215
+
216
+ class EncoderVisionBlock(nn.Module):
217
+ """A single Vision Transformer block (self-attention + MLP)."""
218
+
219
+ def __init__(
220
+ self,
221
+ hidden_size: int,
222
+ num_heads: int,
223
+ mlp_ratio: float,
224
+ hidden_act: str,
225
+ layer_norm_eps: float,
226
+ ls_init_value: Optional[float] = None,
227
+ max_grid_height: Optional[int] = None,
228
+ max_grid_width: Optional[int] = None,
229
+ use_cls_token: bool = False,
230
+ use_rope2d: bool = True,
231
+ rope_kwargs: Optional[dict] = None,
232
+ ):
233
+ super().__init__()
234
+ rope_kwargs = rope_kwargs or {}
235
+ self.attn = EncoderVisionAttention(
236
+ hidden_size,
237
+ num_heads,
238
+ max_grid_height=max_grid_height,
239
+ max_grid_width=max_grid_width,
240
+ use_cls_token=use_cls_token,
241
+ use_rope2d=use_rope2d,
242
+ **rope_kwargs,
243
+ )
244
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
245
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
246
+
247
+ intermediate = int(hidden_size * mlp_ratio)
248
+ self.mlp = EncoderMLP(hidden_size, intermediate, hidden_act)
249
+
250
+ self.ls_1 = EncoderLayerScale(hidden_size, ls_init_value) if ls_init_value is not None else nn.Identity()
251
+ self.ls_2 = EncoderLayerScale(hidden_size, ls_init_value) if ls_init_value is not None else nn.Identity()
252
+
253
+ def forward(self, hidden_states: torch.Tensor,
254
+ grid_hw: tuple[int, int]) -> torch.Tensor:
256
+ residual = hidden_states
257
+ hidden_states = self.ln_1(hidden_states)
258
+ hidden_states = self.attn(hidden_states, grid_hw=grid_hw)
259
+ hidden_states = residual + self.ls_1(hidden_states)
260
+
261
+ residual = hidden_states
262
+ hidden_states = self.ln_2(hidden_states)
263
+ hidden_states = self.mlp(hidden_states)
264
+ hidden_states = residual + self.ls_2(hidden_states)
265
+ return hidden_states
266
+
267
+
268
+ class EncoderVisionTransformer(nn.Module):
269
+ """Stack of encoder blocks parameterised by Step35VisionEncoderConfig."""
270
+
271
+ def __init__(
272
+ self,
273
+ embed_dim: int,
274
+ depth: int,
275
+ num_heads: int,
276
+ mlp_ratio: float,
277
+ hidden_act: str,
278
+ layer_norm_eps: float,
279
+ ls_init_value: Optional[float] = None,
280
+ max_grid_height: Optional[int] = None,
281
+ max_grid_width: Optional[int] = None,
282
+ use_cls_token: bool = False,
283
+ use_rope2d: bool = True,
284
+ rope_kwargs: Optional[dict] = None,
285
+ ):
286
+ super().__init__()
287
+ self.layers = depth
288
+ rope_kwargs = rope_kwargs or {}
289
+ self.resblocks = nn.ModuleList([
290
+ EncoderVisionBlock(embed_dim, num_heads, mlp_ratio, hidden_act,
291
+ layer_norm_eps,
292
+ max_grid_height=max_grid_height,
293
+ max_grid_width=max_grid_width,
294
+ use_cls_token=use_cls_token,
295
+ use_rope2d=use_rope2d,
296
+ ls_init_value=ls_init_value,
297
+ rope_kwargs=rope_kwargs)
298
+ for _ in range(depth)
299
+ ])
300
+
301
+ def forward(self,
302
+ hidden_states: torch.Tensor,
303
+ grid_hw: tuple[int, int]) -> torch.Tensor:
304
+ for block in self.resblocks:
305
+ hidden_states = block(hidden_states, grid_hw=grid_hw)
306
+ return hidden_states
307
+
308
+
309
+ class StepRoboticsVisionEncoder(nn.Module):
310
+ """
311
+ Vision encoder built from StepRoboticsVisionEncoderConfig.
312
+
313
+ The encoder performs patch embedding followed by a stack of transformer
314
+ blocks. Only the config fields defined in StepRoboticsVisionEncoderConfig (and
315
+ StepRoboticVLConfig.vision_config) are expected.
316
+ """
317
+
318
+ def __init__(self, config: StepRoboticsVisionEncoderConfig):
319
+ super().__init__()
320
+ self.config = config
321
+
322
+ # Align commonly used attributes so downstream code (e.g. StepRoboticVL)
323
+ # can access them without extra renaming.
324
+ self.hidden_size = config.width
325
+ self.num_heads = config.heads
326
+ self.num_hidden_layers = config.layers
327
+ self.patch_size = config.patch_size
328
+ self.image_size = config.image_size
329
+ self.use_cls_token = getattr(config, "use_cls_token", False)
330
+ self.use_rope2d = getattr(config, "use_rope2d", True)
331
+ self.use_abs_posemb = getattr(config, "use_abs_posemb", True)
332
+ self.layer_norm_eps = config.layer_norm_eps
333
+ self.mlp_ratio = getattr(config, "mlp_ratio", 8960 / 1536)
334
+ self.ls_init_value = getattr(config, "ls_init_value", None)
335
+ self.hidden_act = config.hidden_act
336
+ self.use_ln_pre = getattr(config, "use_ln_pre", False)
337
+ self.use_ln_post = getattr(config, "use_ln_post", True)
338
+
339
+ # Patch embedding.
340
+ self.conv1 = nn.Conv2d(in_channels=config.num_channels,
341
+ out_channels=self.hidden_size,
342
+ kernel_size=self.patch_size,
343
+ stride=self.patch_size,
344
+ bias=False)
345
+
346
+ self.ln_pre = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_pre else nn.Identity()
347
+ self.ln_post = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_post else nn.Identity()
348
+
349
+ grid_size = self.image_size // self.patch_size
350
+ self.base_grid = (grid_size, grid_size)
351
+
352
+ if self.use_cls_token:
353
+ self.class_embedding = nn.Parameter(
354
+ torch.randn(self.hidden_size) * (self.hidden_size**-0.5))
355
+ else:
356
+ self.class_embedding = None
357
+
358
+ if self.use_abs_posemb:
359
+ self.posemb_grid_size = self.image_size // self.patch_size
360
+ self.positional_embedding = nn.Parameter(
361
+ (self.hidden_size**-0.5) * torch.randn(
362
+ int(self.use_cls_token) + self.posemb_grid_size**2,
363
+ self.hidden_size,
364
+ ))
365
+
366
+ self.transformer = EncoderVisionTransformer(
367
+ embed_dim=self.hidden_size,
368
+ depth=self.num_hidden_layers,
369
+ num_heads=self.num_heads,
370
+ mlp_ratio=self.mlp_ratio,
371
+ hidden_act=self.hidden_act,
372
+ layer_norm_eps=self.layer_norm_eps,
373
+ ls_init_value=self.ls_init_value,
374
+ max_grid_height=self.base_grid[0],
375
+ max_grid_width=self.base_grid[1],
376
+ use_cls_token=self.use_cls_token,
377
+ use_rope2d=self.use_rope2d,
378
+ rope_kwargs={
379
+ "rope_theta": getattr(config, "rope_theta", 10000),
380
+ "rope_max_freq": getattr(config, "rope_max_freq", 10),
381
+ "rope_num_freqs": getattr(config, "rope_num_freqs", 1),
382
+ "rope_theta_rescale_factor":
383
+ getattr(config, "rope_theta_rescale_factor", 1.0),
384
+ "rope_freqs_for": getattr(config, "rope_freqs_for", "lang"),
385
+ },
386
+ )
387
+ self.vit_downsampler1 = nn.Conv2d(self.hidden_size,
388
+ self.hidden_size * 2,
389
+ kernel_size=3,
390
+ stride=2,
391
+ padding=1)
392
+ self.vit_downsampler2 = nn.Conv2d(self.hidden_size * 2,
393
+ self.hidden_size * 4,
394
+ kernel_size=3,
395
+ stride=2,
396
+ padding=1)
397
+
398
+
399
+ def sample_abs_posemb(self, grid_h: int, grid_w: int):
400
+ if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
401
+ return self.positional_embedding[None, ...]
402
+
403
+ pos_embed = self.positional_embedding
404
+ if self.use_cls_token:
405
+ cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
406
+
407
+ pos_embed = (pos_embed.reshape(1, self.posemb_grid_size,
408
+ self.posemb_grid_size,
409
+ -1).permute(0, 3, 1, 2).contiguous())
410
+ pos_embed = F.interpolate(pos_embed,
411
+ size=(grid_h, grid_w),
412
+ mode="bilinear",
413
+ align_corners=False)
414
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.hidden_size)
415
+
416
+ if self.use_cls_token:
417
+ pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
418
+
419
+ return pos_embed[None, ...]
420
+
421
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
422
+ """
423
+ Args:
424
+ pixel_values: Image tensor of shape (B, C, H, W).
425
+ Returns:
426
+ Patch embeddings of shape (B, grid_h * grid_w, hidden_size); the cls token (if used) is stripped.
427
+ """
428
+ bsz, _, height, width = pixel_values.shape
429
+ grid_h, grid_w = height // self.patch_size, width // self.patch_size
430
+
431
+ hidden_state = self.conv1(pixel_values) # (B, D, Gh, Gw)
432
+ hidden_state = hidden_state.flatten(2).transpose(1, 2) # (B, Gh*Gw, D)
433
+
434
+ if self.use_cls_token:
435
+ cls_token = self.class_embedding.view(1, 1,
436
+ -1).expand(bsz, -1, -1)
437
+ hidden_state = torch.cat([cls_token, hidden_state], dim=1)
438
+
439
+ if self.use_abs_posemb:
440
+ pos_emb = self.sample_abs_posemb(grid_h, grid_w)
441
+ hidden_state = hidden_state + pos_emb
442
+ hidden_state = self.ln_pre(hidden_state)
443
+ hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w))
444
+
445
+ if self.use_ln_post:
446
+ hidden_state = self.ln_post(hidden_state)
447
+
448
+ if self.use_cls_token:
449
+ hidden_state = hidden_state[:, 1:, :]
450
+
451
+ return hidden_state
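
A self-contained smoke test for the encoder above (a sketch, not the shipped configuration: the config values below are illustrative stand-ins, and it assumes `configuration_step_vl` is importable so that `vision_encoder.py` loads):

```python
from types import SimpleNamespace

import torch

from vision_encoder import StepRoboticsVisionEncoder

cfg = SimpleNamespace(
    width=64, heads=4, layers=2,                 # tiny sizes for a quick CPU run
    patch_size=14, image_size=224, num_channels=3,
    layer_norm_eps=1e-5, hidden_act="gelu",
)   # every field not set here falls back to the getattr defaults above

enc = StepRoboticsVisionEncoder(cfg).eval()
with torch.no_grad():
    feats = enc(torch.randn(1, 3, 224, 224))
print(feats.shape)                               # (1, 256, 64): one token per 14x14 patch

# The stride-2 downsamplers are defined but not called in forward(); applying
# them by hand (presumably what the VL wrapper does) shrinks the grid 4x:
grid = feats.transpose(1, 2).reshape(1, 64, 16, 16)
down = enc.vit_downsampler2(enc.vit_downsampler1(grid))
print(down.shape)                                # (1, 256, 4, 4)
```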